Path: blob/master/CharClassification/train_model.py
3118 views
"""Train the character-classification network built by ``net.Net``.

Images are read with ``ImageDataGenerator.flow_from_directory`` from
``train/`` (training, with light augmentation) and ``test/`` (validation,
rescaling only); that API expects one sub-directory per class. After
training, the validation score is printed, the weights are saved to
``trained_weights.h5`` and loss/accuracy curves are plotted.
"""
# import required modules
import math

from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
import matplotlib.pyplot as plt

# import created model
from net import Net

# Dimensions of our images
img_width, img_height = 32, 32

# 3 channel image
no_of_channels = 3

# train data Directory
train_data_dir = 'train/'
# test data Directory
validation_data_dir = 'test/'

epochs = 80
batch_size = 32

# initialize model
model = Net.build(width=img_width, height=img_height, depth=no_of_channels)
print('building done')
# Compile model
rms = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)
print('optimizing done')

model.compile(loss='categorical_crossentropy',
              optimizer=rms,
              metrics=['accuracy'])

print('compiling')

# Keras' target_size is (height, width). Both are 32 here, but keeping the
# order right means non-square sizes will also work.
target_size = (img_height, img_width)

# this is the augmentation configuration used for training
# horizontal_flip = False, as we need to retain Characters
# NOTE: featurewise_center / featurewise_std_normalization are intentionally
# not set. They only take effect after ImageDataGenerator.fit() has computed
# dataset statistics, which this script never does, so Keras skipped them
# (emitting a warning). Turning them on for real would also require any
# inference code to apply the same statistics.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.1,
    zoom_range=0.1,
    rotation_range=5,
    width_shift_range=0.05,
    height_shift_range=0.05,
    horizontal_flip=False)

# this is the configuration used for validation: only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='categorical')

# With no images, the step counts below would be 0 and training cannot run;
# fail early with a clear message instead.
if train_generator.samples == 0 or validation_generator.samples == 0:
    raise SystemExit(
        'No images found under {!r} and/or {!r}; expected one sub-directory '
        'per class.'.format(train_data_dir, validation_data_dir))

# Step counts must be whole numbers. Rounding up keeps the final, partial
# batch in every epoch (the same number of steps old Keras ran for a float).
steps_per_epoch = math.ceil(train_generator.samples / batch_size)
validation_steps = math.ceil(validation_generator.samples / batch_size)

# fit the model
history = model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps)

# evaluate on validation dataset and report the result
scores = model.evaluate_generator(validation_generator, steps=validation_steps)
print(dict(zip(model.metrics_names, scores)))

# save weights in a file
model.save_weights('trained_weights.h5')

print(history.history)

# Keras < 2.3 logs metrics=['accuracy'] as 'acc'/'val_acc'; Keras >= 2.3
# uses 'accuracy'/'val_accuracy'. Support both.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'

# Loss Curves
plt.figure(figsize=[8, 6])
plt.plot(history.history['loss'], 'r', linewidth=3.0)
plt.plot(history.history['val_loss'], 'b', linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'], fontsize=18)
plt.xlabel('Epochs ', fontsize=16)
plt.ylabel('Loss', fontsize=16)
plt.title('Loss Curves', fontsize=16)

# Accuracy Curves
plt.figure(figsize=[8, 6])
plt.plot(history.history[acc_key], 'r', linewidth=3.0)
plt.plot(history.history['val_' + acc_key], 'b', linewidth=3.0)

plt.legend(['Training Accuracy', 'Validation Accuracy'], fontsize=18)
plt.xlabel('Epochs ', fontsize=16)
plt.ylabel('Accuracy', fontsize=16)
plt.title('Accuracy Curves', fontsize=16)
plt.show()