Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
greyhatguy007
GitHub Repository: greyhatguy007/Machine-Learning-Specialization-Coursera
Path: blob/main/C2 - Advanced Learning Algorithms/week4/C2W4A1/public_tests.py
3585 views
1
import numpy as np
2
3
def compute_entropy_test(target):
    """Check a binary-entropy implementation against known cases.

    Args:
        target: Callable that takes a 1-D NumPy array of 0/1 labels and
            returns the entropy of that array as a number.

    Raises:
        AssertionError: If ``target`` returns an unexpected entropy.
    """
    # Pure nodes: the expected value is exactly zero, so an exact
    # comparison is kept (it also turns a ``None`` result into a plain
    # assertion failure rather than a TypeError).
    result = target(np.array([1] * 10))
    assert result == 0, "Entropy must be 0 with array of ones"

    result = target(np.array([0] * 10))
    assert result == 0, "Entropy must be 0 with array of zeros"

    # Balanced node: the value is produced by logarithms, so compare with
    # the same tolerance as the other computed checks below.
    result = target(np.array([0] * 12 + [1] * 12))
    assert np.isclose(result, 1, atol=1e-6), "Entropy must be 1 with same amount of ones and zeros"

    y = np.array([1, 0, 1, 0, 1, 1, 1, 0, 1])
    assert np.isclose(target(y), 0.918295, atol=1e-6), "Wrong value. Something between 0 and 1"
    # Swapping the two labels must leave the entropy unchanged.
    assert np.isclose(target(-y + 1), target(y), atol=1e-6), "Wrong value"

    print("\033[92m All tests passed.")
24
25
def split_dataset_test(target):
    """Check a dataset-splitting function on small binary feature matrices.

    Args:
        target: Callable ``target(X, node_indices, feature)`` returning a
            ``(left, right)`` pair of index lists. The expected values
            below place the rows whose ``feature`` column equals 1 in
            ``left`` and the remaining rows in ``right``.

    Raises:
        AssertionError: If the returned types, lengths or values are wrong.
    """
    # Case 1: the feature under test is the last column (index 2).
    X = np.array([[1, 0],
                  [1, 0],
                  [1, 1],
                  [0, 0],
                  [0, 1]])
    X_t = np.array([[0, 1, 0, 1, 0]])
    X = np.concatenate((X, X_t.T), axis=1)

    left, right = target(X, list(range(5)), 2)
    expected = {'left': np.array([1, 3]),
                'right': np.array([0, 2, 4])}

    assert isinstance(left, list), f"Wrong type for left. Expected: list got: {type(left)}"
    assert isinstance(right, list), f"Wrong type for right. Expected: list got: {type(right)}"

    # Lengths are verified before indexing so that an empty result is
    # reported as a failed assertion instead of an IndexError.
    assert len(left) == 2, f"left must have 2 elements but got: {len(left)}"
    assert len(right) == 3, f"right must have 3 elements but got: {len(right)}"

    assert isinstance(left[0], int), f"Wrong type for elements in the left list. Expected: int got: {type(left[0])}"
    assert isinstance(right[0], int), f"Wrong type for elements in the right list. Expected: int got: {type(right[0])}"

    assert np.array_equal(right, expected['right']), f"Wrong value for right. Expected: { expected['right']} \ngot: {right}"
    assert np.array_equal(left, expected['left']), f"Wrong value for left. Expected: { expected['left']} \ngot: {left}"

    # Case 2: the feature under test is the first column (index 0).
    X = np.array([[0, 1],
                  [1, 1],
                  [1, 1],
                  [0, 0],
                  [1, 0]])
    X_t = np.array([[0, 1, 0, 1, 0]])
    X = np.concatenate((X_t.T, X), axis=1)

    left, right = target(X, list(range(5)), 0)
    expected = {'left': np.array([1, 3]),
                'right': np.array([0, 2, 4])}

    # np.array_equal returns False on a length mismatch, whereas
    # np.allclose would raise a broadcasting ValueError.
    assert np.array_equal(right, expected['right']) and np.array_equal(left, expected['left']), "Wrong value when target is at index 0."

    # Case 3: only a subset of the rows belongs to the node. The first
    # three columns are random filler; only column 3 is split on.
    X = (np.random.rand(11, 3) > 0.5) * 1  # Just random binary numbers
    X_t = np.array([[0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0]])
    X = np.concatenate((X, X_t.T), axis=1)

    left, right = target(X, [1, 2, 3, 6, 7, 9, 10], 3)
    expected = {'left': np.array([1, 3, 6]),
                'right': np.array([2, 7, 9, 10])}

    # Literal braces in an f-string are written as {{ and }}.
    assert np.array_equal(right, expected['right']) and np.array_equal(left, expected['left']), f"Wrong value when the node holds a subset of the rows (feature 3). \nExpected: {expected} \ngot: {{'left': {left}, 'right': {right}}}"

    print("\033[92m All tests passed.")
76
77
def compute_information_gain_test(target):
    """Check an information-gain implementation on a 5-row binary dataset.

    Args:
        target: Callable ``target(X, y, node_indices, feature)`` returning
            the information gain of splitting the node on ``feature``.

    Raises:
        AssertionError: If a returned gain differs from the expected value.
    """
    X = np.array([[1, 0],
                  [1, 0],
                  [1, 0],
                  [0, 0],
                  [0, 1]])

    # Pure target: no split can gain anything, whichever feature is used.
    y = np.array([[0, 0, 0, 0, 0]]).T
    node_indexes = list(range(5))

    result1 = target(X, y, node_indexes, 0)
    # Feature 1 is exercised here; calling feature 0 twice would leave it
    # untested on a pure target.
    result2 = target(X, y, node_indexes, 1)

    assert result1 == 0 and result2 == 0, f"Information gain must be 0 when target variable is pure. Got {result1} and {result2}"

    # Mixed target over the full node.
    y = np.array([[0, 1, 0, 1, 0]]).T
    node_indexes = list(range(5))

    result = target(X, y, node_indexes, 0)
    assert np.isclose(result, 0.019973, atol=1e-6), f"Wrong information gain. Expected {0.019973} got: {result}"

    result = target(X, y, node_indexes, 1)
    assert np.isclose(result, 0.170951, atol=1e-6), f"Wrong information gain. Expected {0.170951} got: {result}"

    # Same target restricted to the first four rows.
    node_indexes = list(range(4))
    result = target(X, y, node_indexes, 0)
    assert np.isclose(result, 0.311278, atol=1e-6), f"Wrong information gain. Expected {0.311278} got: {result}"

    # Feature 1 is constant on these rows, so splitting on it gains nothing.
    result = target(X, y, node_indexes, 1)
    assert np.isclose(result, 0, atol=1e-6), f"Wrong information gain. Expected {0.0} got: {result}"

    print("\033[92m All tests passed.")
109
110
def get_best_split_test(target):
    """Check a best-split selector on small and synthetic binary datasets.

    Args:
        target: Callable ``target(X, y, node_indices)`` returning the index
            of the feature to split on, or -1 when no split is useful.

    Raises:
        AssertionError: If the returned feature index is not the expected one.
    """
    X = np.array([[1, 0],
                  [1, 0],
                  [1, 0],
                  [0, 0],
                  [0, 1]])

    # Pure target: nothing to gain from any split.
    y = np.array([[0, 0, 0, 0, 0]]).T
    node_indexes = list(range(5))

    result = target(X, y, node_indexes)
    assert result == -1, f"When the target variable is pure, there is no best split to do. Expected -1, got {result}"

    # Target identical to (or the complement of) one feature column.
    y = X[:, 0]
    result = target(X, y, node_indexes)
    assert result == 0, f"If the target is fully correlated with other feature, that feature must be the best split. Expected 0, got {result}"

    y = X[:, 1]
    result = target(X, y, node_indexes)
    assert result == 1, f"If the target is fully correlated with other feature, that feature must be the best split. Expected 1, got {result}"

    y = 1 - X[:, 0]
    result = target(X, y, node_indexes)
    assert result == 0, f"If the target is fully correlated with other feature, that feature must be the best split. Expected 0, got {result}"

    y = np.array([[0, 1, 0, 1, 0]]).T
    result = target(X, y, node_indexes)
    assert result == 1, f"Wrong result. Expected 1, got {result}"

    # Same target, restricted to a subset of the rows.
    y = np.array([[0, 1, 0, 1, 0]]).T
    node_indexes = [2, 3, 4]
    result = target(X, y, node_indexes)
    assert result == 0, f"Wrong result. Expected 0, got {result}"

    # Larger synthetic dataset: two constant columns, one random column
    # and one half-ones/half-zeros column.
    n_samples = 100
    # A locally seeded generator keeps the run reproducible and leaves the
    # caller's global NumPy random state untouched.
    rng = np.random.RandomState(0)
    X0 = np.array([[1] * n_samples])
    X1 = np.array([[0] * n_samples])
    X2 = (rng.rand(1, n_samples) > 0.5) * 1
    X3 = np.array([[1] * (n_samples // 2) + [0] * (n_samples // 2)])

    # The target is the random column itself, so feature 2 must win.
    y = X2.T
    node_indexes = list(range(20, 80))
    # Stack the four (1, n_samples) rows and transpose to (n_samples, 4).
    X = np.concatenate((X0, X1, X2, X3)).T
    result = target(X, y, node_indexes)

    assert result == 2, f"Wrong result. Expected 2, got {result}"

    # Constant target on the same features: again no useful split.
    y = X0.T
    result = target(X, y, node_indexes)
    assert result == -1, f"When the target variable is pure, there is no best split to do. Expected -1, got {result}"
    print("\033[92m All tests passed.")
161