Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aimacode
GitHub Repository: aimacode/aima-python
Path: blob/master/knowledge.py
615 views
1
"""Knowledge in learning (Chapter 19)"""
2
3
from collections import defaultdict
4
from functools import partial
5
from itertools import combinations, product
6
from random import shuffle
7
8
import numpy as np
9
10
from logic import (FolKB, constant_symbols, predicate_symbols, standardize_variables,
11
variables, is_definite_clause, subst, expr, Expr)
12
from utils import power_set
13
14
15
def current_best_learning(examples, h, examples_so_far=None):
    """
    [Figure 19.2]
    Refine hypothesis h one example at a time, backtracking over candidates.

    The hypothesis is a list of dictionaries, with each dictionary representing
    a disjunction. Examples are dictionaries that carry a 'GOAL' entry.

    Returns a hypothesis that survived every example, or the string 'FAIL'
    when no candidate revision leads to one.
    """
    seen = [] if examples_so_far is None else examples_so_far
    if not examples:
        return h

    current = examples[0]
    seen_with_current = seen + [current]

    # Nothing to repair: move on to the next example with the same hypothesis.
    if is_consistent(current, h):
        return current_best_learning(examples[1:], h, seen_with_current)

    # Pick the repair strategy that matches the kind of error.
    if false_positive(current, h):
        candidates = specializations(seen_with_current, h)
    elif false_negative(current, h):
        candidates = generalizations(seen_with_current, h)
    else:
        candidates = []

    # Try each revised hypothesis on the remaining examples; first success wins.
    for candidate in candidates:
        outcome = current_best_learning(examples[1:], candidate, seen_with_current)
        if outcome != 'FAIL':
            return outcome

    return 'FAIL'
41
42
43
def specializations(examples_so_far, h):
    """Specialize the hypothesis by adding AND operations to the disjunctions.

    For every disjunction in h and every attribute of every example seen so
    far that the disjunction does not already constrain, a copy of h is built
    in which that disjunction additionally requires the attribute to differ
    from the example's value (encoded as '!' + value). Copies consistent with
    all of examples_so_far are returned in random order.
    """
    candidates = []

    for position, disjunction in enumerate(h):
        for example in examples_so_far:
            for attribute, value in example.items():
                # Skip the label and attributes this disjunction already tests.
                if attribute == 'GOAL' or attribute in disjunction:
                    continue

                tightened = dict(disjunction)
                tightened[attribute] = '!' + value

                revised = list(h)
                revised[position] = tightened
                if check_all_consistency(examples_so_far, revised):
                    candidates.append(revised)

    shuffle(candidates)
    return candidates
62
63
64
def generalizations(examples_so_far, h):
    """Generalize the hypothesis. First delete operations
    (including disjunctions) from the hypothesis. Then, add OR operations.

    h is a list of dictionaries (one per disjunction). The return value is a
    list of candidate hypotheses, each itself a list of dictionaries, in
    random order.
    """
    hypotheses = []

    # Delete disjunctions
    for disjs in power_set(range(len(h))):
        h2 = h.copy()
        # Delete from the highest index down so earlier deletions do not
        # shift the positions still to be removed.
        for d in reversed(list(disjs)):
            del h2[d]

        if check_all_consistency(examples_so_far, h2):
            # append, not +=: += would splice h2's dictionaries into the
            # candidate list instead of adding h2 as one hypothesis.
            hypotheses.append(h2)

    # Delete AND operations in disjunctions
    for i, disj in enumerate(h):
        for attrs in power_set(disj.keys()):
            h2 = disj.copy()
            for a in attrs:
                del h2[a]

            # The relaxed disjunction is checked on its own, as before.
            if check_all_consistency(examples_so_far, [h2]):
                h3 = h.copy()
                h3[i] = h2
                hypotheses.append(h3)

    # Add OR operations
    # When deletion produced nothing, or only the unconstrained hypothesis
    # [{}], fall back to the OR-extended candidates alone. The original
    # compared against [{}], which only matched because of the flattening
    # caused by +=; [[{}]] is the equivalent once hypotheses are appended.
    if not hypotheses or hypotheses == [[{}]]:
        hypotheses = add_or(examples_so_far, h)
    else:
        hypotheses.extend(add_or(examples_so_far, h))

    shuffle(hypotheses)
    return hypotheses
100
101
102
def add_or(examples_so_far, h):
    """Add an OR operation to the hypothesis. The AND operations in the new
    disjunction are generated by the last example (which is the problematic one).

    Each attribute subset produced by power_set over the last example's
    non-'GOAL' attributes becomes a candidate disjunction; it is kept when it
    passes check_negative_consistency against examples_so_far. Returns the
    list of hypotheses obtained by appending each kept disjunction to a copy
    of h.
    """
    latest = examples_so_far[-1]
    attrs = {name: value for name, value in latest.items() if name != 'GOAL'}

    extended = []
    for subset in power_set(attrs.keys()):
        new_disjunction = {name: attrs[name] for name in subset}
        if check_negative_consistency(examples_so_far, new_disjunction):
            extended.append(h + [new_disjunction])

    return extended
122
123
124
# ______________________________________________________________________________
125
126
127
def version_space_learning(examples):
    """
    [Figure 19.3]
    Filter the full hypothesis space down to the hypotheses that agree with
    every example.

    The version space is a list of hypotheses, which in turn are a list
    of dictionaries/disjunctions.
    """
    version_space = all_hypotheses(examples)
    for example in examples:
        # An empty version space cannot shrink further; skip the update.
        if not version_space:
            continue
        version_space = version_space_update(version_space, example)

    return version_space
139
140
141
def version_space_update(V, e):
    """Return the hypotheses of V that are consistent with example e."""
    survivors = []
    for hypothesis in V:
        if is_consistent(e, hypothesis):
            survivors.append(hypothesis)
    return survivors
143
144
145
def all_hypotheses(examples):
    """Build a list of all the possible hypotheses.

    Single-disjunction hypotheses are generated for each attribute subset
    returned by power_set over the attributes in the examples; the
    combinations of those hypotheses produced by build_h_combinations are
    then appended after them.
    """
    attr_values = values_table(examples)

    single = [
        hypothesis
        for subset in power_set(attr_values.keys())
        for hypothesis in build_attr_combinations(subset, attr_values)
    ]

    return single + build_h_combinations(single)
156
157
158
def values_table(examples):
    """Build a table with all the possible values for each attribute.

    Returns a plain dictionary mapping each attribute name (every key except
    'GOAL') to a list of the distinct values observed for it, in first-seen
    order. A value taken from an example whose 'GOAL' is falsy is stored with
    a leading '!'.
    """
    table = {}
    for example in examples:
        for attribute, value in example.items():
            if attribute == 'GOAL':
                continue

            prefix = '' if example['GOAL'] else '!'
            literal = prefix + value

            known = table.setdefault(attribute, [])
            if literal not in known:
                known.append(literal)

    return table
177
178
179
def build_attr_combinations(s, values):
    """Given a set of attributes, builds all the combinations of values.
    If the set holds more than one attribute, recursively builds the
    combinations.

    s is an indexable, sliceable sequence of attribute names and values maps
    each attribute to its list of possible values. The result is a list of
    hypotheses, each a one-element list holding a single dictionary.

    Note that for every position i the values of s[i] are merged with the
    combinations of the suffix s[i + 1:], so for three or more attributes the
    result also contains combinations over the shorter suffixes of s. An
    empty s yields an empty list.
    """
    if len(s) == 1:
        # s holds just one attribute, return one hypothesis per value
        attr = s[0]
        return [[{attr: v}] for v in values[attr]]

    h = []
    for i, a in enumerate(s):
        rest = build_attr_combinations(s[i + 1:], values)
        for v in values[a]:
            for r in rest:
                # Merge {a: v} with the single dictionary held by r.
                t = {a: v}
                for d in r:
                    t.update(d)
                h.append([t])

    return h
201
202
203
def build_h_combinations(hypotheses):
    """Given a set of hypotheses, builds and returns all the combinations of the
    hypotheses.

    For every index subset returned by power_set, the disjunctions of the
    selected hypotheses are concatenated, in subset order, into one new
    hypothesis.
    """
    return [
        [disjunction for index in subset for disjunction in hypotheses[index]]
        for subset in power_set(range(len(hypotheses)))
    ]
216
217
218
# ______________________________________________________________________________
219
220
221
def minimal_consistent_det(E, A):
    """Return a minimal set of attributes which give consistent determination.

    Attribute subsets of A are tried in order of increasing size; the first
    one for which consistent_det holds on the examples E is returned as a
    set. Returns None when no subset qualifies.
    """
    matches = (
        set(subset)
        for size in range(len(A) + 1)
        for subset in combinations(A, size)
        if consistent_det(subset, E)
    )
    return next(matches, None)
229
230
231
def consistent_det(A, E):
    """Check if the attributes(A) is consistent with the examples(E).

    Returns False as soon as two examples share the same values on every
    attribute in A but differ in 'GOAL'; returns True otherwise.
    """
    goal_by_values = {}

    for example in E:
        key = tuple(example[attr] for attr in A)
        goal = example['GOAL']

        conflict = key in goal_by_values and goal_by_values[key] != goal
        if conflict:
            return False
        goal_by_values[key] = goal

    return True
242
243
244
# ______________________________________________________________________________
245
246
247
class FOILContainer(FolKB):
    """Hold the kb and other necessary elements required by FOIL."""

    def __init__(self, clauses=None):
        # The symbol sets are created before the base initializer runs;
        # presumably FolKB.__init__ passes the initial clauses through tell(),
        # which updates them — confirm against FolKB.
        self.const_syms = set()
        self.pred_syms = set()
        super().__init__(clauses)

    def tell(self, sentence):
        """Add a definite clause to the kb and record its constant and
        predicate symbols. Raises Exception for anything that is not a
        definite clause."""
        if is_definite_clause(sentence):
            self.clauses.append(sentence)
            self.const_syms.update(constant_symbols(sentence))
            self.pred_syms.update(predicate_symbols(sentence))
        else:
            raise Exception('Not a definite clause: {}'.format(sentence))

    def foil(self, examples, target):
        """Learn a list of first-order horn clauses
        'examples' is a tuple: (positive_examples, negative_examples).
        positive_examples and negative_examples are both lists which contain substitutions.

        Returns the learned clauses, each in the [consequent, antecedents]
        form produced by new_clause."""
        clauses = []

        pos_examples = examples[0]
        neg_examples = examples[1]

        while pos_examples:
            clause, extended_pos_examples = self.new_clause((pos_examples, neg_examples), target)
            # remove positive examples covered by clause
            pos_examples = self.update_examples(target, pos_examples, extended_pos_examples)
            clauses.append(clause)

        return clauses

    def new_clause(self, examples, target):
        """Find a horn clause which satisfies part of the positive
        examples but none of the negative examples.
        The horn clause is specified as [consequent, list of antecedents]
        Return value is the tuple (horn_clause, extended_positive_examples)."""
        clause = [target, []]
        extended_examples = examples
        # Keep adding literals until no negative example is covered any more.
        while extended_examples[1]:
            literal = self.choose_literal(self.new_literals(clause), extended_examples)
            clause[1].append(literal)
            # Re-extend both the positive (index 0) and negative (index 1) examples.
            extended_examples = [
                sum([list(self.extend_example(example, literal))
                     for example in extended_examples[i]], [])
                for i in range(2)
            ]

        return clause, extended_examples[0]

    def extend_example(self, example, literal):
        """Generate extended examples which satisfy the literal."""
        # find all substitutions that satisfy literal
        for s in self.ask_generator(subst(example, literal)):
            s.update(example)
            yield s

    def new_literals(self, clause):
        """Generate new literals based on known predicate symbols.
        Generated literal must share at least one variable with clause.

        Literals already present in the body of clause are not generated."""
        share_vars = variables(clause[0])
        for l in clause[1]:
            share_vars.update(variables(l))
        # pred_syms is iterated as (predicate, arity) pairs.
        for pred, arity in self.pred_syms:
            new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
            for args in product(share_vars.union(new_vars), repeat=arity):
                if any(var in share_vars for var in args):
                    # make sure we don't return an existing rule; the candidate
                    # must be built with unpacked args, exactly as it is
                    # yielded, or the membership test can never match.
                    literal = Expr(pred, *args)
                    if literal not in clause[1]:
                        yield literal

    def choose_literal(self, literals, examples):
        """Choose the best literal based on the information gain."""
        return max(literals, key=partial(self.gain, examples=examples))

    def gain(self, l, examples):
        """
        Find the utility of each literal when added to the body of the clause.
        Utility function is:
            gain(R, l) = T * (log_2 (post_pos / (post_pos + post_neg)) - log_2 (pre_pos / (pre_pos + pre_neg)))

        where:

            pre_pos = number of positive bindings of rule R (=current set of rules)
            pre_neg = number of negative bindings of rule R
            post_pos = number of positive bindings of rule R' (= R U {l} )
            post_neg = number of negative bindings of rule R'
            T = number of positive bindings of rule R that are still covered
                after adding literal l

        Returns -1 when there are no bindings before, or none after, adding l.
        """
        pre_pos = len(examples[0])
        pre_neg = len(examples[1])
        post_pos = sum([list(self.extend_example(example, l)) for example in examples[0]], [])
        post_neg = sum([list(self.extend_example(example, l)) for example in examples[1]], [])
        if pre_pos + pre_neg == 0 or len(post_pos) + len(post_neg) == 0:
            return -1
        # number of positive example that are represented in extended_examples
        T = 0
        for example in examples[0]:
            # An extended binding represents the example when it agrees with
            # it on every variable the example binds.
            if any(all(ext[x] == example[x] for x in example) for ext in post_pos):
                T += 1
        # The 1e-12 keeps log2 finite when no positive binding survives.
        value = T * (np.log2(len(post_pos) / (len(post_pos) + len(post_neg)) + 1e-12) -
                     np.log2(pre_pos / (pre_pos + pre_neg)))
        return value

    def update_examples(self, target, examples, extended_examples):
        """Add to the kb those examples what are represented in extended_examples
        List of omitted examples is returned."""
        uncovered = []
        for example in examples:
            covered = any(all(ext[x] == example[x] for x in example)
                          for ext in extended_examples)
            if covered:
                self.tell(subst(example, target))
            else:
                uncovered.append(example)

        return uncovered
364
365
366
# ______________________________________________________________________________
367
368
369
def check_all_consistency(examples, h):
    """Return True when every example is consistent under hypothesis h."""
    return all(is_consistent(example, h) for example in examples)
376
377
378
def check_negative_consistency(examples, h):
    """Check if the negative examples are consistent under h.

    h is a single disjunction (dictionary); it is wrapped in a one-element
    hypothesis for the check. Examples with a truthy 'GOAL' are ignored.
    """
    negatives = (example for example in examples if not example['GOAL'])
    return all(is_consistent(example, [h]) for example in negatives)
388
389
390
def disjunction_value(e, d):
    """The value of example e under disjunction d.

    d maps attribute names to required values; a value starting with '!'
    means the example's value must differ from the rest of the string.
    Every entry of d must hold for the result to be True, so an empty d is
    True for any example.
    """
    for k, v in d.items():
        # startswith, unlike v[0], also copes with an empty-string value.
        if v.startswith('!'):
            # v is a NOT expression
            # e[k], thus, should not be equal to v
            if e[k] == v[1:]:
                return False
        elif e[k] != v:
            return False

    return True
402
403
404
def guess_value(e, h):
    """Guess value of example e under hypothesis h.

    True when at least one disjunction of h evaluates to true for e; False
    otherwise, including when h is empty.
    """
    return any(disjunction_value(e, disjunction) for disjunction in h)
411
412
413
def is_consistent(e, h):
    """Return whether the 'GOAL' of example e equals the value h guesses for it."""
    actual = e['GOAL']
    predicted = guess_value(e, h)
    return actual == predicted
415
416
417
def false_positive(e, h):
    """Truthy when h guesses true for example e although its 'GOAL' is falsy."""
    predicted = guess_value(e, h)
    return predicted and not e['GOAL']
419
420
421
def false_negative(e, h):
    """Truthy when example e has a truthy 'GOAL' but h guesses false for it."""
    actual = e['GOAL']
    return actual and not guess_value(e, h)
423
424