github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/assets/tf_minimal/train_mnist_keras_mutligpu.py (about)

     1  # Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved.
     2  # Issued under the Apache 2.0 License.
     3  
     4  import sys
     5  from keras.layers import Input, Dense
     6  from keras.models import Model
     7  from keras.datasets import mnist
     8  from keras.utils import to_categorical
     9  from keras.callbacks import ModelCheckpoint, TensorBoard
    10  
    11  from studio import fs_tracker
    12  from studio.multi_gpu import make_parallel
    13  
    14  # this placeholder will contain our input digits, as flat vectors
    15  img = Input((784,))
    16  # fully-connected layer with 128 units and ReLU activation
    17  x = Dense(128, activation='relu')(img)
    18  x = Dense(128, activation='relu')(x)
    19  # output layer with 10 units and a softmax activation
    20  preds = Dense(10, activation='softmax')(x)
    21  
    22  
    23  no_gpus = 2
    24  batch_size = 128
    25  
    26  model = Model(img, preds)
    27  model = make_parallel(model, no_gpus)
    28  model.compile(loss='categorical_crossentropy', optimizer='adam')
    29  
    30  (x_train, y_train), (x_test, y_test) = mnist.load_data()
    31  
    32  x_train = x_train.reshape(60000, 784)
    33  x_test = x_test.reshape(10000, 784)
    34  x_train = x_train.astype('float32')
    35  x_test = x_test.astype('float32')
    36  x_train /= 255
    37  x_test /= 255
    38  
    39  # convert class vectors to binary class matrices
    40  y_train = to_categorical(y_train, 10)
    41  y_test = to_categorical(y_test, 10)
    42  
    43  
    44  checkpointer = ModelCheckpoint(
    45      fs_tracker.get_model_directory() +
    46      '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
    47  
    48  
    49  tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
    50                           histogram_freq=0,
    51                           write_graph=True,
    52                           write_images=False)
    53  
    54  
    55  model.fit(
    56      x_train,
    57      y_train,
    58      validation_data=(x_test, y_test),
    59      epochs=int(sys.argv[1]),
    60      batch_size=batch_size * no_gpus,
    61      callbacks=[checkpointer, tbcallback])