Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
debakarr
GitHub Repository: debakarr/machinelearning
Path: blob/master/Part 5 - Association Rule Learning/Apriori/apyori.py
1339 views
1
#!/usr/bin/env python
2
3
"""
4
A simple implementation of the Apriori algorithm in Python.
5
"""
6
7
import sys
8
import csv
9
import argparse
10
import json
11
import os
12
from collections import namedtuple
13
from itertools import combinations
14
from itertools import chain
15
16
17
# Meta information.
18
# Package metadata; __version__ is what the command line's --version prints.
__version__ = '1.1.1'
__author__ = 'Yu Mochizuki'
# NOTE(review): this address looks like it was replaced by the hosting page's
# e-mail protection -- restore the real value from the upstream package.
__author_email__ = '[email protected]'
21
22
23
################################################################################
24
# Data structures.
25
################################################################################
26
class TransactionManager(object):
    """
    Transaction managers.

    Keeps, for every item, the set of indexes of the transactions that
    contain it, so the support of an item set can be computed by
    intersecting those index sets.
    """

    def __init__(self, transactions):
        """
        Initialize.

        Arguments:
            transactions -- A transaction iterable object
                            (eg. [['A', 'B'], ['B', 'C']]).
        """
        # Number of transactions added so far; also the next index to assign.
        self.__num_transaction = 0
        # Distinct items in first-seen order.
        self.__items = []
        # item -> set of indexes of the transactions containing that item.
        self.__transaction_index_map = {}

        for transaction in transactions:
            self.add_transaction(transaction)

    def add_transaction(self, transaction):
        """
        Add a transaction.

        Arguments:
            transaction -- A transaction as an iterable object (eg. ['A', 'B']).
        """
        for item in transaction:
            if item not in self.__transaction_index_map:
                self.__items.append(item)
                self.__transaction_index_map[item] = set()
            self.__transaction_index_map[item].add(self.__num_transaction)
        # An empty transaction registers no item but is still counted.
        self.__num_transaction += 1

    def calc_support(self, items):
        """
        Returns a support for items.

        The support is the fraction of the transactions that contain every
        one of the given items.

        Arguments:
            items -- Items as an iterable object (eg. ['A', 'B']).
        """
        # Empty items is supported by all transactions.
        if not items:
            return 1.0

        # Empty transactions supports no items.
        if not self.num_transaction:
            return 0.0

        # Create the transaction index intersection.
        sum_indexes = None
        for item in items:
            indexes = self.__transaction_index_map.get(item)
            if indexes is None:
                # No support for any set that contains a not existing item.
                return 0.0

            if sum_indexes is None:
                # Assign the indexes on the first time.
                sum_indexes = indexes
            else:
                # Calculate the intersection on not the first time.
                # intersection() builds a new set, so the stored index set
                # aliased above is never modified.
                sum_indexes = sum_indexes.intersection(indexes)

        # Calculate and return the support.
        return float(len(sum_indexes)) / self.__num_transaction

    def initial_candidates(self):
        """
        Returns the initial candidates.

        Each candidate is a one-item frozenset; they follow the sorted item
        order given by the items property.
        """
        return [frozenset([item]) for item in self.items]

    @property
    def num_transaction(self):
        """
        Returns the number of transactions.
        """
        return self.__num_transaction

    @property
    def items(self):
        """
        Returns the item list that the transactions consist of.

        A newly sorted list is built on every access.
        """
        return sorted(self.__items)

    @staticmethod
    def create(transactions):
        """
        Create the TransactionManager with a transaction instance.
        If the given instance is a TransactionManager, this returns itself.
        """
        if isinstance(transactions, TransactionManager):
            return transactions
        return TransactionManager(transactions)
122
123
124
# Ignore name errors because these names are namedtuples.
# SupportRecord: an item set together with its support.
SupportRecord = namedtuple(  # pylint: disable=C0103
    'SupportRecord', ('items', 'support'))
# RelationRecord: a SupportRecord extended with its ordered statistics.
RelationRecord = namedtuple(  # pylint: disable=C0103
    'RelationRecord', SupportRecord._fields + ('ordered_statistics',))
# OrderedStatistic: one rule items_base -> items_add with its measures.
OrderedStatistic = namedtuple(  # pylint: disable=C0103
    'OrderedStatistic', ('items_base', 'items_add', 'confidence', 'lift',))
131
132
133
################################################################################
134
# Inner functions.
135
################################################################################
136
def create_next_candidates(prev_candidates, length):
    """
    Returns the apriori candidates as a list.

    Arguments:
        prev_candidates -- Previous candidates as a list.
        length -- The lengths of the next candidates.

    Returns:
        A list of frozensets of the given length, built from the items that
        occur in prev_candidates. For lengths of 3 or more, only candidates
        whose every (length - 1)-item subset is a previous candidate are kept.
    """
    # Solve the items.
    items = sorted(
        {item for candidate in prev_candidates for item in candidate})

    # Create the temporary candidates. These will be filtered below.
    tmp_next_candidates = (frozenset(x) for x in combinations(items, length))

    # Return all the candidates if the length of the next candidates is 2
    # because their subsets are the same as items.
    if length < 3:
        return list(tmp_next_candidates)

    # Build the lookup set once so that each subset test below is a hash
    # lookup rather than a scan of the previous candidate list.
    known_candidates = {frozenset(x) for x in prev_candidates}

    # Filter candidates that all of their subsets are
    # in the previous candidates.
    return [
        candidate for candidate in tmp_next_candidates
        if all(
            frozenset(x) in known_candidates
            for x in combinations(candidate, length - 1))
    ]
168
169
170
def gen_support_records(transaction_manager, min_support, **kwargs):
    """
    Returns a generator of support records with given transactions.

    Candidates are evaluated level by level, starting from single items.
    Every candidate whose support is at least min_support is yielded as a
    SupportRecord and is used to build the candidates of the next length.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        min_support -- A minimum support (float).

    Keyword arguments:
        max_length -- The maximum length of relations (integer).
                      A falsy value (None or 0) means no limit.
        _create_next_candidates -- Replacement for create_next_candidates
                                   (for testing).
    """
    # Parse arguments.
    max_length = kwargs.get('max_length')

    # For testing.
    _create_next_candidates = kwargs.get(
        '_create_next_candidates', create_next_candidates)

    # Process.
    candidates = transaction_manager.initial_candidates()
    length = 1
    while candidates:
        # Item sets of the current length that reach the minimum support.
        relations = set()
        for relation_candidate in candidates:
            support = transaction_manager.calc_support(relation_candidate)
            if support < min_support:
                continue
            candidate_set = frozenset(relation_candidate)
            relations.add(candidate_set)
            yield SupportRecord(candidate_set, support)
        length += 1
        if max_length and length > max_length:
            break
        # The loop ends by itself once no candidate of the new length exists.
        candidates = _create_next_candidates(relations, length)
204
205
206
def gen_ordered_statistics(transaction_manager, record):
    """
    Returns a generator of ordered statistics as OrderedStatistic instances.

    One statistic is produced for each way of leaving a single item out of
    record.items: the remaining items form items_base and the item left out
    forms items_add. For a one-item record, items_base is the empty set.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        record -- A support record as a SupportRecord instance.
    """
    items = record.items
    # Sorting makes the order of the generated statistics deterministic.
    for combination_set in combinations(sorted(items), len(items) - 1):
        items_base = frozenset(combination_set)
        items_add = frozenset(items.difference(items_base))
        # confidence = support(items) / support(items_base)
        confidence = (
            record.support / transaction_manager.calc_support(items_base))
        # lift = confidence / support(items_add)
        lift = confidence / transaction_manager.calc_support(items_add)
        yield OrderedStatistic(items_base, items_add, confidence, lift)
223
224
225
def filter_ordered_statistics(ordered_statistics, **kwargs):
    """
    Filter OrderedStatistic objects.

    Returns a generator of the given statistics whose confidence and lift
    are both at least the requested minimums, in their original order.

    Arguments:
        ordered_statistics -- A OrderedStatistic iterable object.

    Keyword arguments:
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
    """
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)

    for ordered_statistic in ordered_statistics:
        # Skip a statistic as soon as either threshold is missed.
        if (ordered_statistic.confidence < min_confidence
                or ordered_statistic.lift < min_lift):
            continue
        yield ordered_statistic
245
246
247
################################################################################
248
# API function.
249
################################################################################
250
def apriori(transactions, **kwargs):
    """
    Executes Apriori algorithm and returns a RelationRecord generator.

    Because this is a generator, nothing (including the argument check)
    runs until the result is iterated.

    Arguments:
        transactions -- A transaction iterable object
                        (eg. [['A', 'B'], ['B', 'C']]).

    Keyword arguments:
        min_support -- The minimum support of relations (float).
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
        max_length -- The maximum length of the relation (integer).

    Raises:
        ValueError -- If min_support is not greater than 0 (raised on the
                      first iteration).
    """
    # Parse the arguments.
    min_support = kwargs.get('min_support', 0.1)
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)
    max_length = kwargs.get('max_length', None)

    # Check arguments.
    if min_support <= 0:
        raise ValueError('minimum support must be > 0')

    # For testing.
    _gen_support_records = kwargs.get(
        '_gen_support_records', gen_support_records)
    _gen_ordered_statistics = kwargs.get(
        '_gen_ordered_statistics', gen_ordered_statistics)
    _filter_ordered_statistics = kwargs.get(
        '_filter_ordered_statistics', filter_ordered_statistics)

    # Calculate supports.
    transaction_manager = TransactionManager.create(transactions)
    support_records = _gen_support_records(
        transaction_manager, min_support, max_length=max_length)

    # Calculate ordered stats.
    for support_record in support_records:
        ordered_statistics = list(
            _filter_ordered_statistics(
                _gen_ordered_statistics(transaction_manager, support_record),
                min_confidence=min_confidence,
                min_lift=min_lift,
            )
        )
        # Item sets none of whose statistics pass the filters are not emitted.
        if not ordered_statistics:
            continue
        yield RelationRecord(
            support_record.items, support_record.support, ordered_statistics)
300
301
302
################################################################################
303
# Application functions.
304
################################################################################
305
def parse_args(argv):
    """
    Parse commandline arguments.

    Arguments:
        argv -- An argument list without the program name.

    Returns:
        The parsed argparse namespace, extended with an output_func
        attribute holding the dump function chosen by --out-format.
    """
    # Maps each --out-format choice to the function that writes a record.
    output_funcs = {
        'json': dump_as_json,
        'tsv': dump_as_two_item_tsv,
    }
    default_output_func_key = 'json'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v', '--version', action='version',
        version='%(prog)s {0}'.format(__version__))
    parser.add_argument(
        'input', metavar='inpath', nargs='*',
        help='Input transaction file (default: stdin).',
        type=argparse.FileType('r'), default=[sys.stdin])
    parser.add_argument(
        '-o', '--output', metavar='outpath',
        help='Output file (default: stdout).',
        type=argparse.FileType('w'), default=sys.stdout)
    parser.add_argument(
        '-l', '--max-length', metavar='int',
        help='Max length of relations (default: infinite).',
        type=int, default=None)
    parser.add_argument(
        '-s', '--min-support', metavar='float',
        help='Minimum support ratio (must be > 0, default: 0.1).',
        type=float, default=0.1)
    parser.add_argument(
        '-c', '--min-confidence', metavar='float',
        help='Minimum confidence (default: 0.5).',
        type=float, default=0.5)
    parser.add_argument(
        '-t', '--min-lift', metavar='float',
        help='Minimum lift (default: 0.0).',
        type=float, default=0.0)
    parser.add_argument(
        '-d', '--delimiter', metavar='str',
        help='Delimiter for items of transactions (default: tab).',
        type=str, default='\t')
    parser.add_argument(
        '-f', '--out-format', metavar='str',
        help='Output format ({0}; default: {1}).'.format(
            ', '.join(output_funcs.keys()), default_output_func_key),
        type=str, choices=output_funcs.keys(), default=default_output_func_key)
    args = parser.parse_args(argv)

    # Resolve the chosen format name to its dump function for the caller.
    args.output_func = output_funcs[args.out_format]
    return args
359
360
361
def load_transactions(input_file, **kwargs):
    """
    Load transactions and returns a generator for transactions.

    Each line is parsed with csv.reader and yielded as a list of items.

    Arguments:
        input_file -- An input file (any iterable of lines accepted by
                      csv.reader).

    Keyword arguments:
        delimiter -- The delimiter of the transaction.
    """
    delimiter = kwargs.get('delimiter', '\t')
    for transaction in csv.reader(input_file, delimiter=delimiter):
        # csv.reader gives [] for a blank line; it is replaced by a
        # transaction holding one empty-string item.
        yield transaction if transaction else ['']
374
375
376
def dump_as_json(record, output_file):
    """
    Dump a relation record as a json value.

    Writes one JSON object followed by a newline.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output (a text-mode stream).
    """
    def default_func(value):
        """
        Default conversion for JSON value.
        """
        if isinstance(value, frozenset):
            return sorted(value)
        raise TypeError(repr(value) + " is not JSON serializable")

    # The nested statistics are converted to dicts explicitly so that they
    # are written as JSON objects with their field names.
    converted_record = record._replace(
        ordered_statistics=[x._asdict() for x in record.ordered_statistics])
    json.dump(
        converted_record._asdict(), output_file,
        default=default_func, ensure_ascii=False)
    # A text-mode stream translates '\n' to the platform line ending by
    # itself; writing os.linesep here would yield '\r\r\n' on Windows.
    output_file.write('\n')
398
399
400
def dump_as_two_item_tsv(record, output_file):
    """
    Dump a relation record as TSV only for 2 item relations.

    Writes one line per statistic whose base and added item sets each hold
    exactly one item: base, added item, support, confidence and lift,
    separated by tabs. Other statistics are skipped.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output (a text-mode stream).
    """
    for ordered_stats in record.ordered_statistics:
        if len(ordered_stats.items_base) != 1:
            continue
        if len(ordered_stats.items_add) != 1:
            continue
        # '\n' rather than os.linesep: a text-mode stream already translates
        # it to the platform line ending (os.linesep gives '\r\r\n' on
        # Windows).
        output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}{5}'.format(
            list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
            record.support, ordered_stats.confidence, ordered_stats.lift,
            '\n'))
417
418
419
def main(**kwargs):
    """
    Executes Apriori algorithm and print its result.

    Reads the command line from sys.argv, loads the transactions from the
    input files, runs apriori and writes every resulting record with the
    selected output function.

    Keyword arguments (for tests):
        _parse_args -- Replacement for parse_args.
        _load_transactions -- Replacement for load_transactions.
        _apriori -- Replacement for apriori.
    """
    # For tests.
    _parse_args = kwargs.get('_parse_args', parse_args)
    _load_transactions = kwargs.get('_load_transactions', load_transactions)
    _apriori = kwargs.get('_apriori', apriori)

    args = _parse_args(sys.argv[1:])
    # All input files are read as one continuous stream of lines.
    transactions = _load_transactions(
        chain(*args.input), delimiter=args.delimiter)
    result = _apriori(
        transactions,
        max_length=args.max_length,
        min_support=args.min_support,
        min_confidence=args.min_confidence,
        # The -t/--min-lift option used to be parsed but never forwarded.
        # getattr keeps injected parsers that lack the attribute working;
        # 0.0 is the default of both the option and apriori itself.
        min_lift=getattr(args, 'min_lift', 0.0))
    for record in result:
        args.output_func(record, args.output)
438
439
440
# Run the command line interface only when executed as a script.
if __name__ == '__main__':
    main()
442
443