# Path: blob/master/Part 5 - Association Rule Learning/Apriori/apyori.py
# 1339 views
#!/usr/bin/env python

"""
A simple implementation of the Apriori algorithm in Python.
"""

import sys
import csv
import argparse
import json
import os
from collections import namedtuple
from itertools import combinations
from itertools import chain


# Meta informations.
__version__ = '1.1.1'
__author__ = 'Yu Mochizuki'
# NOTE(review): the value below looks like an e-mail obfuscation placeholder
# left by the page this file was copied from -- restore the real address.
__author_email__ = '[email protected]'


################################################################################
# Data structures.
################################################################################
class TransactionManager(object):
    """
    Transaction managers.

    Keeps, for every distinct item, the set of indexes of the transactions
    that contain it, so that the support of an item set can be computed by
    intersecting those index sets.
    """

    def __init__(self, transactions):
        """
        Initialize.

        Arguments:
            transactions -- A transaction iterable object
                            (eg. [['A', 'B'], ['B', 'C']]).
        """
        # Number of transactions added so far; also the index that the next
        # added transaction will receive.
        self.__num_transaction = 0
        # Distinct items in first-seen order.
        self.__items = []
        # item -> set of indexes of the transactions containing the item.
        self.__transaction_index_map = {}

        for transaction in transactions:
            self.add_transaction(transaction)

    def add_transaction(self, transaction):
        """
        Add a transaction.

        Arguments:
            transaction -- A transaction as an iterable object (eg. ['A', 'B']).
        """
        for item in transaction:
            if item not in self.__transaction_index_map:
                # First occurrence of this item: register it.
                self.__items.append(item)
                self.__transaction_index_map[item] = set()
            self.__transaction_index_map[item].add(self.__num_transaction)
        self.__num_transaction += 1

    def calc_support(self, items):
        """
        Returns a support for items.

        The support is the number of transactions containing every given
        item, divided by the total number of transactions.

        Arguments:
            items -- Items as an iterable object (eg. ['A', 'B']).
        """
        # Empty items is supported by all transactions.
        if not items:
            return 1.0

        # Empty transactions supports no items.
        if not self.num_transaction:
            return 0.0

        # Create the transaction index intersection.
        sum_indexes = None
        for item in items:
            indexes = self.__transaction_index_map.get(item)
            if indexes is None:
                # No support for any set that contains a not existing item.
                return 0.0

            if sum_indexes is None:
                # Assign the indexes on the first time.
                sum_indexes = indexes
            else:
                # Calculate the intersection on not the first time.
                # intersection() builds a new set, so the stored index sets
                # are never mutated.
                sum_indexes = sum_indexes.intersection(indexes)

        # Calculate and return the support.
        return float(len(sum_indexes)) / self.__num_transaction

    def initial_candidates(self):
        """
        Returns the initial candidates.

        Each candidate is a one-element frozenset, one per distinct item,
        in sorted item order.
        """
        return [frozenset([item]) for item in self.items]

    @property
    def num_transaction(self):
        """
        Returns the number of transactions.
        """
        return self.__num_transaction

    @property
    def items(self):
        """
        Returns the item list that the transaction is consisted of.

        The list is sorted and freshly built on every access.
        """
        return sorted(self.__items)

    @staticmethod
    def create(transactions):
        """
        Create the TransactionManager with a transaction instance.
        If the given instance is a TransactionManager, this returns itself.
        """
        if isinstance(transactions, TransactionManager):
            return transactions
        return TransactionManager(transactions)


# Ignore name errors because these names are namedtuples.
SupportRecord = namedtuple(  # pylint: disable=C0103
    'SupportRecord', ('items', 'support'))
RelationRecord = namedtuple(  # pylint: disable=C0103
    'RelationRecord', SupportRecord._fields + ('ordered_statistics',))
OrderedStatistic = namedtuple(  # pylint: disable=C0103
    'OrderedStatistic', ('items_base', 'items_add', 'confidence', 'lift',))


################################################################################
# Inner functions.
################################################################################
def create_next_candidates(prev_candidates, length):
    """
    Returns the apriori candidates as a list.

    Arguments:
        prev_candidates -- Previous candidates as a list.
        length -- The lengths of the next candidates.
    """
    # Solve the items: the sorted union of all items in the previous
    # candidates.
    items = sorted(set(chain.from_iterable(prev_candidates)))

    # Create the temporary candidates. These will be filtered below.
    tmp_next_candidates = (frozenset(x) for x in combinations(items, length))

    # Return all the candidates if the length of the next candidates is 2
    # because their subsets are the same as items.
    if length < 3:
        return list(tmp_next_candidates)

    # Filter candidates that all of their subsets are
    # in the previous candidates.
    return [
        candidate for candidate in tmp_next_candidates
        if all(
            frozenset(x) in prev_candidates
            for x in combinations(candidate, length - 1))
    ]


def gen_support_records(transaction_manager, min_support, **kwargs):
    """
    Returns a generator of support records with given transactions.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        min_support -- A minimum support (float).

    Keyword arguments:
        max_length -- The maximum length of relations (integer).
    """
    # Parse arguments.
    max_length = kwargs.get('max_length')

    # For testing.
    _create_next_candidates = kwargs.get(
        '_create_next_candidates', create_next_candidates)

    # Process.
    candidates = transaction_manager.initial_candidates()
    length = 1
    while candidates:
        # Candidates of the current length that reach the minimum support.
        relations = set()
        for relation_candidate in candidates:
            support = transaction_manager.calc_support(relation_candidate)
            if support < min_support:
                continue
            candidate_set = frozenset(relation_candidate)
            relations.add(candidate_set)
            yield SupportRecord(candidate_set, support)
        length += 1
        # A falsy max_length (None or 0) means no length limit.
        if max_length and length > max_length:
            break
        candidates = _create_next_candidates(relations, length)


def gen_ordered_statistics(transaction_manager, record):
    """
    Returns a generator of ordered statistics as OrderedStatistic instances.

    Only splits whose base holds all but one of the record's items are
    generated; for a single-item record the base is the empty set.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        record -- A support record as a SupportRecord instance.
    """
    items = record.items
    for combination_set in combinations(sorted(items), len(items) - 1):
        items_base = frozenset(combination_set)
        items_add = frozenset(items.difference(items_base))
        confidence = (
            record.support / transaction_manager.calc_support(items_base))
        lift = confidence / transaction_manager.calc_support(items_add)
        yield OrderedStatistic(items_base, items_add, confidence, lift)


def filter_ordered_statistics(ordered_statistics, **kwargs):
    """
    Filter OrderedStatistic objects.

    Yields, in input order, the statistics whose confidence and lift are
    both at least the given minimums.

    Arguments:
        ordered_statistics -- A OrderedStatistic iterable object.

    Keyword arguments:
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
    """
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)

    for ordered_statistic in ordered_statistics:
        if ordered_statistic.confidence < min_confidence:
            continue
        if ordered_statistic.lift < min_lift:
            continue
        yield ordered_statistic


################################################################################
# API function.
################################################################################
def apriori(transactions, **kwargs):
    """
    Executes Apriori algorithm and returns a RelationRecord generator.

    This is a generator function, so nothing runs (including the argument
    check) until the result is first iterated.

    Arguments:
        transactions -- A transaction iterable object
                        (eg. [['A', 'B'], ['B', 'C']]).

    Keyword arguments:
        min_support -- The minimum support of relations (float).
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
        max_length -- The maximum length of the relation (integer).

    Raises:
        ValueError -- If min_support is not greater than 0.
    """
    # Parse the arguments.
    min_support = kwargs.get('min_support', 0.1)
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)
    max_length = kwargs.get('max_length', None)

    # Check arguments.
    if min_support <= 0:
        raise ValueError('minimum support must be > 0')

    # For testing.
    _gen_support_records = kwargs.get(
        '_gen_support_records', gen_support_records)
    _gen_ordered_statistics = kwargs.get(
        '_gen_ordered_statistics', gen_ordered_statistics)
    _filter_ordered_statistics = kwargs.get(
        '_filter_ordered_statistics', filter_ordered_statistics)

    # Calculate supports.
    transaction_manager = TransactionManager.create(transactions)
    support_records = _gen_support_records(
        transaction_manager, min_support, max_length=max_length)

    # Calculate ordered stats.
    for support_record in support_records:
        ordered_statistics = list(
            _filter_ordered_statistics(
                _gen_ordered_statistics(transaction_manager, support_record),
                min_confidence=min_confidence,
                min_lift=min_lift,
            )
        )
        # Records with no statistic left after filtering are dropped.
        if not ordered_statistics:
            continue
        yield RelationRecord(
            support_record.items, support_record.support, ordered_statistics)


################################################################################
# Application functions.
################################################################################
def parse_args(argv):
    """
    Parse commandline arguments.

    The returned namespace additionally carries ``output_func``, the dump
    function selected by the ``--out-format`` option.

    Arguments:
        argv -- An argument list without the program name.
    """
    output_funcs = {
        'json': dump_as_json,
        'tsv': dump_as_two_item_tsv,
    }
    default_output_func_key = 'json'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v', '--version', action='version',
        version='%(prog)s {0}'.format(__version__))
    parser.add_argument(
        'input', metavar='inpath', nargs='*',
        help='Input transaction file (default: stdin).',
        type=argparse.FileType('r'), default=[sys.stdin])
    parser.add_argument(
        '-o', '--output', metavar='outpath',
        help='Output file (default: stdout).',
        type=argparse.FileType('w'), default=sys.stdout)
    parser.add_argument(
        '-l', '--max-length', metavar='int',
        help='Max length of relations (default: infinite).',
        type=int, default=None)
    parser.add_argument(
        '-s', '--min-support', metavar='float',
        help='Minimum support ratio (must be > 0, default: 0.1).',
        type=float, default=0.1)
    parser.add_argument(
        '-c', '--min-confidence', metavar='float',
        help='Minimum confidence (default: 0.5).',
        type=float, default=0.5)
    parser.add_argument(
        '-t', '--min-lift', metavar='float',
        help='Minimum lift (default: 0.0).',
        type=float, default=0.0)
    parser.add_argument(
        '-d', '--delimiter', metavar='str',
        help='Delimiter for items of transactions (default: tab).',
        type=str, default='\t')
    parser.add_argument(
        '-f', '--out-format', metavar='str',
        help='Output format ({0}; default: {1}).'.format(
            ', '.join(output_funcs.keys()), default_output_func_key),
        type=str, choices=output_funcs.keys(), default=default_output_func_key)
    args = parser.parse_args(argv)

    args.output_func = output_funcs[args.out_format]
    return args


def load_transactions(input_file, **kwargs):
    """
    Load transactions and returns a generator for transactions.

    Arguments:
        input_file -- An input file.

    Keyword arguments:
        delimiter -- The delimiter of the transaction.
    """
    delimiter = kwargs.get('delimiter', '\t')
    for transaction in csv.reader(input_file, delimiter=delimiter):
        # An empty row becomes a transaction holding one empty-string item,
        # so it is still counted as a transaction.
        yield transaction if transaction else ['']


def dump_as_json(record, output_file):
    """
    Dump an relation record as a json value.

    Writes one JSON object followed by a newline.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    def default_func(value):
        """
        Default conversion for JSON value.
        """
        if isinstance(value, frozenset):
            return sorted(value)
        raise TypeError(repr(value) + " is not JSON serializable")

    converted_record = record._replace(
        ordered_statistics=[x._asdict() for x in record.ordered_statistics])
    json.dump(
        converted_record._asdict(), output_file,
        default=default_func, ensure_ascii=False)
    # Write '\n', not os.linesep: a text-mode file already translates '\n'
    # to the platform line ending, so os.linesep would be doubled to
    # '\r\r\n' on Windows.
    output_file.write('\n')


def dump_as_two_item_tsv(record, output_file):
    """
    Dump a relation record as TSV only for 2 item relations.

    Each written line holds: base item, added item, support, confidence
    and lift, tab separated. Statistics whose base or added set does not
    hold exactly one item are skipped.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    for ordered_stats in record.ordered_statistics:
        if len(ordered_stats.items_base) != 1:
            continue
        if len(ordered_stats.items_add) != 1:
            continue
        # '\n' rather than os.linesep: text-mode files translate newlines
        # themselves.
        output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
            list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
            record.support, ordered_stats.confidence, ordered_stats.lift))


def main(**kwargs):
    """
    Executes Apriori algorithm and print its result.

    Keyword arguments (for tests):
        _parse_args -- Replacement for parse_args.
        _load_transactions -- Replacement for load_transactions.
        _apriori -- Replacement for apriori.
    """
    # For tests.
    _parse_args = kwargs.get('_parse_args', parse_args)
    _load_transactions = kwargs.get('_load_transactions', load_transactions)
    _apriori = kwargs.get('_apriori', apriori)

    args = _parse_args(sys.argv[1:])
    # All input files are read as one concatenated stream of lines.
    transactions = _load_transactions(
        chain(*args.input), delimiter=args.delimiter)
    result = _apriori(
        transactions,
        max_length=args.max_length,
        min_support=args.min_support,
        min_confidence=args.min_confidence,
        # Forward the -t/--min-lift option; without this it is parsed but
        # has no effect.
        min_lift=args.min_lift)
    for record in result:
        args.output_func(record, args.output)


if __name__ == '__main__':
    main()