GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/softmax.py
import numpy as np


class SoftmaxRegression:
    """
    Softmax regression classifier

    Parameters
    ------------
    eta : float
        learning rate, i.e. the step size (between 0.0 and 1.0)

    epochs : int
        number of passes over the training dataset (iterations).
        Prior to each epoch, the dataset is shuffled
        if `minibatches > 1` to prevent cycles in stochastic gradient descent.

    minibatches : int
        The size of each minibatch, i.e. the number of training samples
        used per gradient update.
        if minibatches == len(y): batch gradient descent
        if minibatches == 1: stochastic gradient descent (SGD), online learning
        if 1 < minibatches < len(y): SGD minibatch learning

    l2 : float, default 0
        l2 regularization parameter
        if 0: no regularization
    """

    def __init__(self, eta, epochs, minibatches, l2 = 0):
        self.eta = eta
        self.epochs = epochs
        self.minibatches = minibatches
        self.l2 = l2
        # the model can't be used for prediction until fit has been called
        self._is_fitted = False

    def fit(self, X, y):
        data_num = X.shape[0]
        feature_num = X.shape[1]
        class_num = np.unique(y).shape[0]

        # initialize the weights and bias
        self.w = np.random.normal(size = (feature_num, class_num))
        self.b = np.zeros(class_num)
        self.costs = []

        # one hot encode the output column and shuffle the data before starting
        y_encode = self._one_hot_encode(y, class_num)
        X, y_encode = self._shuffle(X, y_encode, data_num)

        # `i` keeps track of the starting index of the
        # current batch, so we can do batch training
        i = 0

        # note that epochs refers to the number of passes over the
        # entire dataset, thus when training in batches we multiply it
        # by the number of batches per epoch; the max(..., 1) guard ensures
        # at least one iteration per epoch even if the batch size
        # exceeds the number of training samples
        iterations = self.epochs * max(data_num // self.minibatches, 1)

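        # e.g. with 100 training samples, minibatches = 10 and epochs = 5,
        # each epoch runs 100 // 10 = 10 batches, for 5 * 10 = 50 updates in total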
        for _ in range(iterations):
            batch = slice(i, i + self.minibatches)
            batch_X, batch_y_encode = X[batch], y_encode[batch]

            # forward and store the cross entropy cost
            net = self._net_input(batch_X)
            softm = self._softmax(net)
            error = softm - batch_y_encode
            cost = self._cross_entropy_cost(output = softm, y_target = batch_y_encode)
            self.costs.append(cost)

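            # note: for the summed cross entropy loss
            # L = -sum_i sum_k y_ik * log(p_ik), with p = softmax(X W + b),
            # the gradient with respect to the logits is (p - y), so
            # dL/dW = X^T (p - y) and dL/db = sum_i (p_i - y_i),
            # which is what the update below computes (plus the l2 term on W)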
            # compute the gradient and update the weight and bias
            gradient = np.dot(batch_X.T, error)
            self.w -= self.eta * (gradient + self.l2 * self.w)
            self.b -= self.eta * np.sum(error, axis = 0)

            # update the starting index for the batches;
            # if we've made a complete pass over the data, shuffle again
            # and reset the index that keeps track of the batch
            i += self.minibatches
            if i + self.minibatches > data_num:
                X, y_encode = self._shuffle(X, y_encode, data_num)
                i = 0

        # mark the model as fitted so it can be used for prediction
        self._is_fitted = True
        return self

    def _one_hot_encode(self, y, class_num):
        # e.g. y = [0, 2, 1] with 3 classes becomes
        # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
        y_encode = np.zeros((y.shape[0], class_num))
        for idx, val in enumerate(y):
            y_encode[idx, val] = 1.0

        return y_encode

    def _shuffle(self, X, y_encode, data_num):
        permutation = np.random.permutation(data_num)
        X, y_encode = X[permutation], y_encode[permutation]
        return X, y_encode

    def _net_input(self, X):
        net = X.dot(self.w) + self.b
        return net

    def _softmax(self, z):
        # subtract the row-wise max before exponentiating for numerical
        # stability; this doesn't change the softmax output
        z = z - np.max(z, axis = 1, keepdims = True)
        softm = np.exp(z) / np.sum(np.exp(z), axis = 1, keepdims = True)
        return softm

    def _cross_entropy_cost(self, output, y_target):
        # average negative log-likelihood over the batch, plus the l2 penalty
        cross_entropy = np.mean(-np.sum(np.log(output) * y_target, axis = 1))
        l2_penalty = 0.5 * self.l2 * np.sum(self.w ** 2)
        cost = cross_entropy + l2_penalty
        return cost

    def predict_proba(self, X):
        if not self._is_fitted:
            raise AttributeError('Model is not fitted yet!')

        net = self._net_input(X)
        softm = self._softmax(net)
        return softm

    def predict(self, X):
        softm = self.predict_proba(X)
        class_labels = np.argmax(softm, axis = 1)
        return class_labels
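

# ------------------------------------------------------------------
# Usage sketch (illustrative only): fit the classifier on a small
# synthetic two-class dataset. The blob data and hyperparameter values
# below are made up for demonstration; labels are assumed to be
# integers starting at 0, as expected by `_one_hot_encode`.
# ------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.RandomState(0)

    # two gaussian blobs, one per class
    X = np.vstack((
        rng.normal(loc = 0.0, scale = 1.0, size = (50, 2)),
        rng.normal(loc = 3.0, scale = 1.0, size = (50, 2))
    ))
    y = np.array([0] * 50 + [1] * 50)

    model = SoftmaxRegression(eta = 0.01, epochs = 10, minibatches = 10)
    model.fit(X, y)

    train_accuracy = np.mean(model.predict(X) == y)
    print('training accuracy:', train_accuracy)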