github.com/kubeflow/training-operator@v1.7.0/examples/tensorflow/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py (about)

     1  # Copyright 2020 The Kubeflow Authors. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  # ==============================================================================
    15  """An example of multi-worker training with Keras model using Strategy API."""
    16  
    17  from __future__ import absolute_import, division, print_function
    18  
    19  import argparse
    20  import json
    21  import os
    22  
    23  import tensorflow_datasets as tfds
    24  import tensorflow as tf
    25  from tensorflow.keras import layers, models
    26  
    27  
    28  def make_datasets_unbatched():
    29    BUFFER_SIZE = 10000
    30  
    31    # Scaling MNIST data from (0, 255] to (0., 1.]
    32    def scale(image, label):
    33      image = tf.cast(image, tf.float32)
    34      image /= 255
    35      return image, label
    36  
    37    datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True)
    38  
    39    return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)
    40  
    41  
    42  def build_and_compile_cnn_model():
    43    model = models.Sequential()
    44    model.add(
    45        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    46    model.add(layers.MaxPooling2D((2, 2)))
    47    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    48    model.add(layers.MaxPooling2D((2, 2)))
    49    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    50    model.add(layers.Flatten())
    51    model.add(layers.Dense(64, activation='relu'))
    52    model.add(layers.Dense(10, activation='softmax'))
    53  
    54    model.summary()
    55  
    56    model.compile(optimizer='adam',
    57                  loss='sparse_categorical_crossentropy',
    58                  metrics=['accuracy'])
    59  
    60    return model
    61  
    62  
    63  def decay(epoch):
    64    if epoch < 3: #pylint: disable=no-else-return
    65      return 1e-3
    66    if 3 <= epoch < 7:
    67      return 1e-4
    68    return 1e-5
    69  
    70  
def main(args):
  """Train the MNIST CNN under MultiWorkerMirroredStrategy and save it.

  Args:
    args: Parsed argparse namespace with `checkpoint_dir` (weights
      checkpoints) and `saved_model_dir` (final SavedModel export).
  """

  # MultiWorkerMirroredStrategy creates copies of all variables in the model's
  # layers on each device across all workers
  # if your GPUs don't support NCCL, replace "communication" with another
  # NOTE(review): `CollectiveCommunication` is a deprecated alias of
  # `CommunicationImplementation` in newer TF releases — confirm against the
  # TF version this example targets.
  strategy = tf.distribute.MultiWorkerMirroredStrategy(
      communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CollectiveCommunication.AUTO))

  BATCH_SIZE_PER_REPLICA = 64
  # Scale the global batch size by the replica count so each replica still
  # processes 64 examples per step.
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

  with strategy.scope():
    ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat()
    options = tf.data.Options()
    # Shard by DATA: every worker reads the whole dataset and keeps its own
    # 1/n slice of the elements (MNIST is not split across multiple files,
    # so FILE sharding would not work here).
    options.experimental_distribute.auto_shard_policy = \
        tf.data.experimental.AutoShardPolicy.DATA
    ds_train = ds_train.with_options(options)
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = build_and_compile_cnn_model()

  # Define the checkpoint directory to store the checkpoints
  checkpoint_dir = args.checkpoint_dir

  # Name of the checkpoint files
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

  # Function for decaying the learning rate.
  # You can define any decay function you need.
  # Callback for printing the LR at the end of each epoch.
  class PrintLR(tf.keras.callbacks.Callback):
    """Prints the learning rate in effect at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None): #pylint: disable=no-self-use
      # Reads the optimizer LR set by the LearningRateScheduler callback
      # (closes over `multi_worker_model` defined above).
      print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, multi_worker_model.optimizer.lr.numpy()))

  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir='./logs'),
      tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                         save_weights_only=True),
      tf.keras.callbacks.LearningRateScheduler(decay),
      PrintLR()
  ]

  # Keras' `model.fit()` trains the model with specified number of epochs and
  # number of steps per epoch. Note that the numbers here are for demonstration
  # purposes only and may not sufficiently produce a model with good quality.
  multi_worker_model.fit(ds_train,
                         epochs=10,
                         steps_per_epoch=70,
                         callbacks=callbacks)

  # Saving a model
  # Let `is_chief` be a utility function that inspects the cluster spec and
  # current task type and returns True if the worker is the chief and False
  # otherwise.
  def is_chief():
    # TASK_INDEX is a module-level global assigned in the __main__ block
    # from TF_CONFIG; index 0 is the chief by convention.
    return TASK_INDEX == 0

  if is_chief():
    model_path = args.saved_model_dir

  else:
    # Save to a path that is unique across workers.
    model_path = args.saved_model_dir + '/worker_tmp_' + str(TASK_INDEX)

  # Every worker calls save(); only the chief's output lands in the real
  # export directory — non-chief workers write to throwaway per-worker paths.
  multi_worker_model.save(model_path)
   137  
   138  
   139  if __name__ == '__main__':
   140    os.environ['NCCL_DEBUG'] = 'INFO'
   141  
   142    tfds.disable_progress_bar()
   143  
   144    # to decide if a worker is chief, get TASK_INDEX in Cluster info
   145    tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
   146    TASK_INDEX = tf_config['task']['index']
   147  
   148    parser = argparse.ArgumentParser()
   149    parser.add_argument('--saved_model_dir',
   150                        type=str,
   151                        required=True,
   152                        help='Tensorflow export directory.')
   153  
   154    parser.add_argument('--checkpoint_dir',
   155                        type=str,
   156                        required=True,
   157                        help='Tensorflow checkpoint directory.')
   158  
   159    parsed_args = parser.parse_args()
   160    main(parsed_args)