GitHub Repository: hackassin/learnopencv
Path: blob/master/BatchNormalization/cifar10_cnn_bn_100epochs.py
'''
This code was originally written by the Keras team. It has been modified by
Sunita Nayak at BigVision LLC. to include Batch Normalization in the architecture.

Train a simple deep CNN on the CIFAR10 small images dataset using Batch Normalization.
The network reaches a maximum of 87% validation accuracy, and hits 79% after only
7 epochs. For comparison, the Keras team's original model peaked at 79% after 50
epochs. With Batch Normalization, this model exceeds 85% in just 21 epochs and
reaches 87% in 39 epochs.
'''
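
# For intuition, a minimal sketch (ours, not the Keras implementation) of what
# a BatchNormalization layer computes on one mini-batch during training. The
# real layer also learns gamma and beta by backpropagation and tracks moving
# averages of the batch statistics for use at inference time:
#
#   import numpy as np
#   def batch_norm_forward_sketch(x, gamma, beta, eps=1e-3):
#       mu = x.mean(axis=0)                    # per-feature batch mean
#       var = x.var(axis=0)                    # per-feature batch variance
#       x_hat = (x - mu) / np.sqrt(var + eps)  # zero mean, unit variance
#       return gamma * x_hat + beta            # learnable scale and shift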

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization
from keras.layers import Conv2D, MaxPooling2D
import os
import pickle

from numpy.random import seed
seed(7)
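
# Note: seeding NumPy alone does not make training fully reproducible; the
# backend (e.g. TensorFlow) keeps its own RNG, and GPU kernels can be
# nondeterministic, so some run-to-run variation is still possible.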

batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20  # unused here; carried over from the original Keras example
save_dir = os.path.join(os.getcwd(), 'saved_models_bn_100_s7')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices (one-hot encoding).
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
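# e.g. the integer label 3 becomes the 10-dimensional one-hot vector
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], matching the softmax output and
# categorical_crossentropy loss used below.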

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(BatchNormalization())
model.add(Activation('relu'))
# model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(BatchNormalization())
model.add(Activation('softmax'))
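
# Each block above places BatchNormalization between the conv/dense layer and
# its activation; the Dropout layers from the original Keras example are left
# commented out, since Batch Normalization provides the regularization here.
# Optionally uncomment the next line to print layer shapes and parameter counts:
# model.summary()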

# Initiate the RMSprop optimizer
opt = keras.optimizers.RMSprop(lr=0.001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Scale pixel values from the [0, 255] byte range to [0, 1] floats
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        fill_mode='nearest',  # set mode for filling points outside the input boundaries
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # do not flip images vertically
        rescale=None,  # set rescaling factor (applied before any other transformation)
        preprocessing_function=None,  # set function that will be applied on each input
        data_format=None,  # image data format, either "channels_first" or "channels_last"
        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)
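
    # Optional sanity check (not in the original script): draw one augmented
    # batch from the generator to verify shapes before training, e.g.
    #   x_batch, y_batch = next(datagen.flow(x_train, y_train, batch_size=batch_size))
    #   print(x_batch.shape, y_batch.shape)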

    # Fit the model on the batches generated by datagen.flow().
    history = model.fit_generator(datagen.flow(x_train, y_train,
                                               batch_size=batch_size),
                                  epochs=epochs,
                                  validation_data=(x_test, y_test),
                                  workers=4)

# Save the training history so the accuracy/loss curves can be plotted later.
with open('./trainHistoryDictWithBn1', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)
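
# The pickled history can later be reloaded and plotted (usage sketch; assumes
# matplotlib is available and that the metric keys are 'acc'/'val_acc' as in
# Keras 2.x; newer versions use 'accuracy'/'val_accuracy'):
#   import matplotlib.pyplot as plt
#   with open('./trainHistoryDictWithBn1', 'rb') as f:
#       hist = pickle.load(f)
#   plt.plot(hist['acc'], label='train acc')
#   plt.plot(hist['val_acc'], label='val acc')
#   plt.legend()
#   plt.show()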

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)
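# The saved .h5 file stores the architecture, weights, and optimizer state, so
# the model can be restored later with, e.g.:
#   from keras.models import load_model
#   model = load_model(model_path)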

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])