# Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved.
# Issued under the Apache 2.0 License.

"""Train a small fully-connected MNIST classifier with Keras.

Per-epoch checkpoints and TensorBoard logs are written to the directories
supplied by studio's fs_tracker. An optional first CLI argument overrides
the number of training epochs (default 10).
"""

import os
import sys

from keras import optimizers
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.datasets import mnist
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.utils import to_categorical

from studio import fs_tracker


# Load MNIST and reshape to (N, 28, 28, 1). Derive N from the data itself
# rather than hard-coding 60000/10000, so alternate splits still work.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32')
# Scale pixel intensities from [0, 255] to [0, 1].
x_train /= 255
x_test /= 255

# Convert integer class labels to one-hot vectors, as required by the
# categorical_crossentropy loss below.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)


# Simple dense network: flatten -> two 128-unit ReLU layers -> 10-way softmax.
model = Sequential()
model.add(Flatten(input_shape=(28, 28, 1)))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.summary()


batch_size = 128
# First CLI argument, if present, sets the epoch count.
no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10
lr = 0.01

print('learning rate = {}'.format(lr))
print('batch size = {}'.format(batch_size))
print('no_epochs = {}'.format(no_epochs))

model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
              metrics=['accuracy'])

# Checkpoint after every epoch into the studio-managed model directory;
# os.path.join replaces manual '/' string concatenation.
print("Saving checkpoints to {}".format(fs_tracker.get_model_directory()))
checkpointer = ModelCheckpoint(
    os.path.join(fs_tracker.get_model_directory(),
                 'checkpoint.{epoch:02d}-{val_loss:.2f}.hdf'))


tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
                         histogram_freq=0,
                         write_graph=True,
                         write_images=True)


model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=no_epochs,
    callbacks=[checkpointer, tbcallback],
    batch_size=batch_size)