github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/examples/local/app.py (about)

     1  import sys
     2  
     3  from keras.layers import Dense, Flatten
     4  
     5  from keras.models import Sequential
     6  from keras.datasets import mnist
     7  from keras.utils import to_categorical
     8  
     9  from keras.callbacks import ModelCheckpoint, TensorBoard
    10  from keras import optimizers
    11  
    12  from studio import fs_tracker
    13  
    14  
    15  (x_train, y_train), (x_test, y_test) = mnist.load_data()
    16  
    17  x_train = x_train.reshape(60000, 28, 28, 1)
    18  x_test = x_test.reshape(10000, 28, 28, 1)
    19  x_train = x_train.astype('float32')
    20  x_test = x_test.astype('float32')
    21  x_train /= 255
    22  x_test /= 255
    23  
    24  # convert class vectors to binary class matrices
    25  y_train = to_categorical(y_train, 10)
    26  y_test = to_categorical(y_test, 10)
    27  
    28  
    29  model = Sequential()
    30  
    31  model.add(Flatten(input_shape=(28, 28, 1)))
    32  model.add(Dense(128, activation='relu'))
    33  model.add(Dense(128, activation='relu'))
    34  
    35  model.add(Dense(10, activation='softmax'))
    36  model.summary()
    37  
    38  
    39  batch_size = 128
    40  no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10
    41  lr = 0.01
    42  
    43  print('{"experiment": {"completed": "true"}}')
    44  print('{}{}{}'.format('{"experiment": {"learning_rate": "', '{}'.format(lr), '"}}'))
    45  print('{}{}{}'.format('{"experiment": {"batch_size": "', '{}'.format(batch_size), '"}}'))
    46  print('{}{}{}'.format('{"experiment": {"no_epochs": "', '{}'.format(no_epochs), '"}}'))
    47  
    48  model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
    49                metrics=['accuracy'])
    50  
    51  print("Saving checkpoints to {}".format(fs_tracker.get_model_directory()))
    52  checkpointer = ModelCheckpoint(
    53      fs_tracker.get_model_directory() +
    54      '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
    55  
    56  
    57  tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
    58                           histogram_freq=0,
    59                           write_graph=True,
    60                           write_images=True)
    61  
    62  model.fit(
    63      x_train, y_train, validation_data=(
    64          x_test,
    65          y_test),
    66      epochs=no_epochs,
    67      callbacks=[checkpointer, tbcallback],
    68      batch_size=batch_size)
    69  
    70  print('{"experiment": {"completed": "true"}}')
    71