github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/assets/tf_minimal/train_mnist_keras.py (about)

     1  # Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved.
     2  # Issued under the Apache 2.0 License.
     3  
     4  import sys
     5  
     6  from keras.layers import Dense, Flatten
     7  
     8  from keras.models import Sequential
     9  from keras.datasets import mnist
    10  from keras.utils import to_categorical
    11  
    12  from keras.callbacks import ModelCheckpoint, TensorBoard
    13  from keras import optimizers
    14  
    15  from studio import fs_tracker
    16  
    17  
    18  (x_train, y_train), (x_test, y_test) = mnist.load_data()
    19  
    20  x_train = x_train.reshape(60000, 28, 28, 1)
    21  x_test = x_test.reshape(10000, 28, 28, 1)
    22  x_train = x_train.astype('float32')
    23  x_test = x_test.astype('float32')
    24  x_train /= 255
    25  x_test /= 255
    26  
    27  # convert class vectors to binary class matrices
    28  y_train = to_categorical(y_train, 10)
    29  y_test = to_categorical(y_test, 10)
    30  
    31  
    32  model = Sequential()
    33  
    34  model.add(Flatten(input_shape=(28, 28, 1)))
    35  model.add(Dense(128, activation='relu'))
    36  model.add(Dense(128, activation='relu'))
    37  
    38  model.add(Dense(10, activation='softmax'))
    39  model.summary()
    40  
    41  
    42  batch_size = 128
    43  no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10
    44  lr = 0.01
    45  
    46  print('learning rate = {}'.format(lr))
    47  print('batch size = {}'.format(batch_size))
    48  print('no_epochs = {}'.format(no_epochs))
    49  
    50  model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
    51                metrics=['accuracy'])
    52  
    53  print("Saving checkpoints to {}".format(fs_tracker.get_model_directory()))
    54  checkpointer = ModelCheckpoint(
    55      fs_tracker.get_model_directory() +
    56      '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
    57  
    58  
    59  tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
    60                           histogram_freq=0,
    61                           write_graph=True,
    62                           write_images=True)
    63  
    64  
    65  model.fit(
    66      x_train, y_train, validation_data=(
    67          x_test,
    68          y_test),
    69      epochs=no_epochs,
    70      callbacks=[checkpointer, tbcallback],
    71      batch_size=batch_size)