# Origin: github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/examples/local/app.py
"""Train a small fully connected MNIST classifier under StudioML.

Usage::

    python app.py [no_epochs]

``no_epochs`` defaults to 10. Run metadata (learning rate, batch size,
epoch count and a completion flag) is printed to stdout as single-line
JSON objects of the form ``{"experiment": {...}}``. Checkpoints are written
to ``fs_tracker.get_model_directory()`` and TensorBoard logs to
``fs_tracker.get_tensorboard_dir()``.
"""
import json
import os
import sys

from keras.layers import Dense, Flatten

from keras.models import Sequential
from keras.datasets import mnist
from keras.utils import to_categorical

from keras.callbacks import ModelCheckpoint, TensorBoard
from keras import optimizers

from studio import fs_tracker


def _report(key, value):
    """Print one ``{"experiment": {key: "<value>"}}`` line to stdout.

    The value is always emitted as a JSON string, e.g.
    ``{"experiment": {"learning_rate": "0.01"}}``.
    """
    # NOTE(review): these lines appear intended to be collected from stdout
    # by the studio-go-runner as experiment metadata — confirm the merge
    # semantics against the runner documentation.
    print(json.dumps({'experiment': {key: str(value)}}))


(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing single-channel axis and scale pixel values into [0, 1].
# -1 lets the sample count follow the loaded data instead of hard-coding
# the 60000/10000 split sizes.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255

# convert class vectors to binary class matrices
num_classes = 10
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)


model = Sequential()

model.add(Flatten(input_shape=(28, 28, 1)))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))

model.add(Dense(num_classes, activation='softmax'))
model.summary()


batch_size = 128
# Optional first CLI argument overrides the number of training epochs.
no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10
lr = 0.01

# Training has not run yet, so the run must not be reported as completed
# here; the flag is flipped to "true" only after model.fit() returns.
_report('completed', 'false')
_report('learning_rate', lr)
_report('batch_size', batch_size)
_report('no_epochs', no_epochs)

# NOTE(review): newer Keras releases deprecate `lr` in favour of
# `learning_rate`; `lr` is kept for compatibility with older Keras —
# revisit once the pinned Keras version is known.
model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
              metrics=['accuracy'])

model_dir = fs_tracker.get_model_directory()
print("Saving checkpoints to {}".format(model_dir))
# {epoch} and {val_loss} are placeholders filled in by ModelCheckpoint at
# save time; val_loss exists because validation_data is passed to fit().
checkpointer = ModelCheckpoint(
    os.path.join(model_dir, 'checkpoint.{epoch:02d}-{val_loss:.2f}.hdf'))


tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
                         histogram_freq=0,
                         write_graph=True,
                         write_images=True)

model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=no_epochs,
    callbacks=[checkpointer, tbcallback],
    batch_size=batch_size)

_report('completed', 'true')