github.com/kubeflow/training-operator@v1.7.0/examples/tensorflow/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py (about) 1 # Copyright 2020 The Kubeflow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """An example of multi-worker training with Keras model using Strategy API.""" 16 17 from __future__ import absolute_import, division, print_function 18 19 import argparse 20 import json 21 import os 22 23 import tensorflow_datasets as tfds 24 import tensorflow as tf 25 from tensorflow.keras import layers, models 26 27 28 def make_datasets_unbatched(): 29 BUFFER_SIZE = 10000 30 31 # Scaling MNIST data from (0, 255] to (0., 1.] 32 def scale(image, label): 33 image = tf.cast(image, tf.float32) 34 image /= 255 35 return image, label 36 37 datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True) 38 39 return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE) 40 41 42 def build_and_compile_cnn_model(): 43 model = models.Sequential() 44 model.add( 45 layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))) 46 model.add(layers.MaxPooling2D((2, 2))) 47 model.add(layers.Conv2D(64, (3, 3), activation='relu')) 48 model.add(layers.MaxPooling2D((2, 2))) 49 model.add(layers.Conv2D(64, (3, 3), activation='relu')) 50 model.add(layers.Flatten()) 51 model.add(layers.Dense(64, activation='relu')) 52 model.add(layers.Dense(10, activation='softmax')) 53 54 model.summary() 55 56 model.compile(optimizer='adam', 57 loss='sparse_categorical_crossentropy', 58 metrics=['accuracy']) 59 60 return model 61 62 63 def decay(epoch): 64 if epoch < 3: #pylint: disable=no-else-return 65 return 1e-3 66 if 3 <= epoch < 7: 67 return 1e-4 68 return 1e-5 69 70 71 def main(args): 72 73 # MultiWorkerMirroredStrategy creates copies of all variables in the model's 74 # layers on each device across all workers 75 # if your GPUs don't support NCCL, replace "communication" with another 76 strategy = tf.distribute.MultiWorkerMirroredStrategy( 77 communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CollectiveCommunication.AUTO)) 78 79 BATCH_SIZE_PER_REPLICA = 64 80 BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync 81 82 with strategy.scope(): 83 ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat() 84 options = tf.data.Options() 85 options.experimental_distribute.auto_shard_policy = \ 86 tf.data.experimental.AutoShardPolicy.DATA 87 ds_train = ds_train.with_options(options) 88 # Model building/compiling need to be within `strategy.scope()`. 89 multi_worker_model = build_and_compile_cnn_model() 90 91 # Define the checkpoint directory to store the checkpoints 92 checkpoint_dir = args.checkpoint_dir 93 94 # Name of the checkpoint files 95 checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") 96 97 # Function for decaying the learning rate. 98 # You can define any decay function you need. 99 # Callback for printing the LR at the end of each epoch. 100 class PrintLR(tf.keras.callbacks.Callback): 101 102 def on_epoch_end(self, epoch, logs=None): #pylint: disable=no-self-use 103 print('\nLearning rate for epoch {} is {}'.format( 104 epoch + 1, multi_worker_model.optimizer.lr.numpy())) 105 106 callbacks = [ 107 tf.keras.callbacks.TensorBoard(log_dir='./logs'), 108 tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, 109 save_weights_only=True), 110 tf.keras.callbacks.LearningRateScheduler(decay), 111 PrintLR() 112 ] 113 114 # Keras' `model.fit()` trains the model with specified number of epochs and 115 # number of steps per epoch. Note that the numbers here are for demonstration 116 # purposes only and may not sufficiently produce a model with good quality. 117 multi_worker_model.fit(ds_train, 118 epochs=10, 119 steps_per_epoch=70, 120 callbacks=callbacks) 121 122 # Saving a model 123 # Let `is_chief` be a utility function that inspects the cluster spec and 124 # current task type and returns True if the worker is the chief and False 125 # otherwise. 126 def is_chief(): 127 return TASK_INDEX == 0 128 129 if is_chief(): 130 model_path = args.saved_model_dir 131 132 else: 133 # Save to a path that is unique across workers. 134 model_path = args.saved_model_dir + '/worker_tmp_' + str(TASK_INDEX) 135 136 multi_worker_model.save(model_path) 137 138 139 if __name__ == '__main__': 140 os.environ['NCCL_DEBUG'] = 'INFO' 141 142 tfds.disable_progress_bar() 143 144 # to decide if a worker is chief, get TASK_INDEX in Cluster info 145 tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}') 146 TASK_INDEX = tf_config['task']['index'] 147 148 parser = argparse.ArgumentParser() 149 parser.add_argument('--saved_model_dir', 150 type=str, 151 required=True, 152 help='Tensorflow export directory.') 153 154 parser.add_argument('--checkpoint_dir', 155 type=str, 156 required=True, 157 help='Tensorflow checkpoint directory.') 158 159 parsed_args = parser.parse_args() 160 main(parsed_args)