github.com/kubeflow/training-operator@v1.7.0/examples/tensorflow/dist-mnist/dist_mnist.py (about)

     1  # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  # ==============================================================================
    15  """Distributed MNIST training and validation, with model replicas.
    16  
    17  A simple softmax model with one hidden layer is defined. The parameters
    18  (weights and biases) are located on one parameter server (ps), while the ops
    19  are executed on two worker nodes by default. The TF sessions also run on the
    20  worker node.
    21  Multiple invocations of this script can be done in parallel, with different
values for --task_index. There should be exactly one invocation with
--task_index=0, which will create a master session that carries out variable
    24  initialization. The other, non-master, sessions will wait for the master
    25  session to finish the initialization before proceeding to the training stage.
    26  
    27  The coordination between the multiple worker invocations occurs due to
    28  the definition of the parameters on the same ps devices. The parameter updates
from one worker are visible to all other workers. As such, the workers can
    30  perform forward computation and gradient calculation in parallel, which
    31  should lead to increased training speed for the simple model.
    32  """
    33  
    34  from __future__ import absolute_import
    35  from __future__ import division
    36  from __future__ import print_function
    37  
    38  import json
    39  import math
    40  import os
    41  import sys
    42  import tempfile
    43  import time
    44  
    45  import tensorflow as tf
    46  from tensorflow.examples.tutorials.mnist import input_data
    47  
    48  flags = tf.app.flags
    49  flags.DEFINE_string("data_dir", "/tmp/mnist-data",
    50                      "Directory for storing mnist data")
    51  flags.DEFINE_boolean("download_only", False,
    52                       "Only perform downloading of data; Do not proceed to "
    53                       "session preparation, model definition or training")
    54  flags.DEFINE_integer("task_index", None,
    55                       "Worker task index, should be >= 0. task_index=0 is "
    56                       "the master worker task the performs the variable "
    57                       "initialization ")
    58  flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine."
    59                       "If you don't use GPU, please set it to '0'")
    60  flags.DEFINE_integer("replicas_to_aggregate", None,
    61                       "Number of replicas to aggregate before parameter update"
    62                       "is applied (For sync_replicas mode only; default: "
    63                       "num_workers)")
    64  flags.DEFINE_integer("hidden_units", 100,
    65                       "Number of units in the hidden layer of the NN")
    66  flags.DEFINE_integer("train_steps", 20000,
    67                       "Number of (global) training steps to perform")
    68  flags.DEFINE_integer("batch_size", 100, "Training batch size")
    69  flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
    70  flags.DEFINE_boolean(
    71      "sync_replicas", False,
    72      "Use the sync_replicas (synchronized replicas) mode, "
    73      "wherein the parameter updates from workers are aggregated "
    74      "before applied to avoid stale gradients")
    75  flags.DEFINE_boolean(
    76      "existing_servers", False, "Whether servers already exists. If True, "
    77      "will use the worker hosts via their GRPC URLs (one client process "
    78      "per worker host). Otherwise, will create an in-process TensorFlow "
    79      "server.")
    80  flags.DEFINE_string("ps_hosts", "localhost:2222",
    81                      "Comma-separated list of hostname:port pairs")
    82  flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
    83                      "Comma-separated list of hostname:port pairs")
    84  flags.DEFINE_string("job_name", None, "job name: worker or ps")
    85  
    86  FLAGS = flags.FLAGS
    87  
    88  IMAGE_PIXELS = 28
    89  
    90  # Example:
    91  #   cluster = {'ps': ['host1:2222', 'host2:2222'],
    92  #              'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    93  #   os.environ['TF_CONFIG'] = json.dumps(
    94  #       {'cluster': cluster,
    95  #        'task': {'type': 'worker', 'index': 1}})
    96  
def main(unused_argv):
  """Configure the cluster from TF_CONFIG, build the graph, train, validate.

  Reads the cluster layout and this process's role from the TF_CONFIG
  environment variable (overriding the corresponding flags), downloads MNIST,
  builds a one-hidden-layer softmax model with variables placed on the ps
  job, trains until the global step reaches --train_steps, then prints the
  validation cross entropy.

  Args:
    unused_argv: leftover positional CLI arguments from tf.app.run (ignored).

  Raises:
    ValueError: if no job name or task index could be determined.
  """
  # Parse environment variable TF_CONFIG to get job_name and task_index

  # If not explicitly specified in the constructor and the TF_CONFIG
  # environment variable is present, load cluster_spec from TF_CONFIG.
  # `or '{}'` handles both an unset and an empty TF_CONFIG.
  tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
  task_config = tf_config.get('task', {})
  task_type = task_config.get('type')
  task_index = task_config.get('index')

  # TF_CONFIG values take precedence over the --job_name / --task_index flags.
  FLAGS.job_name = task_type
  FLAGS.task_index = task_index

  # Downloads the dataset into FLAGS.data_dir if it is not already there.
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  if FLAGS.download_only:
    sys.exit(0)

  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")

  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)

  # Host lists come from the TF_CONFIG cluster spec and replace the
  # --ps_hosts / --worker_hosts flag defaults.
  cluster_config = tf_config.get('cluster', {})
  ps_hosts = cluster_config.get('ps')
  worker_hosts = cluster_config.get('worker')

  ps_hosts_str = ','.join(ps_hosts)
  worker_hosts_str = ','.join(worker_hosts)

  FLAGS.ps_hosts = ps_hosts_str
  FLAGS.worker_hosts = worker_hosts_str

  # Construct the cluster and start the server
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")

  # Get the number of workers.
  num_workers = len(worker_spec)

  cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})

  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
      # Parameter servers only serve variables; block here forever.
      server.join()

  # Task 0 is the chief: it initializes variables and prepares the session.
  is_chief = (FLAGS.task_index == 0)
  if FLAGS.num_gpus > 0:
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to worker server
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
  # NOTE(review): a negative --num_gpus matches neither branch, leaving
  # `worker_device` unbound; tf.device() below would raise NameError.
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    # Shared step counter, incremented once per applied parameter update.
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Variables of the hidden layer
    hid_w = tf.Variable(
        tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

    # Variables of the softmax layer
    sm_w = tf.Variable(
        tf.truncated_normal(
            [FLAGS.hidden_units, 10],
            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
        name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

    # Ops: located on the worker specified with FLAGS.task_index
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])

    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)

    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    # Clip predictions away from 0 so tf.log never sees log(0).
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)

    if FLAGS.sync_replicas:
      # Default: wait for one gradient from every worker before updating.
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate

      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="mnist_sync_replicas")

    train_step = opt.minimize(cross_entropy, global_step=global_step)

    if FLAGS.sync_replicas:
      # The chief uses a different local init op than the other workers.
      local_init_op = opt.local_step_init_op
      if is_chief:
        local_init_op = opt.chief_init_op

      ready_for_local_init_op = opt.ready_for_local_init_op

      # Initial token and chief queue runners required by the sync_replicas mode
      chief_queue_runner = opt.get_chief_queue_runner()
      sync_init_op = opt.get_init_tokens_op()

    init_op = tf.global_variables_initializer()
    # Supervisor checkpoints/summaries go to a throwaway temp directory.
    train_dir = tempfile.mkdtemp()

    if FLAGS.sync_replicas:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          recovery_wait_secs=1,
          global_step=global_step)

    # Restrict this session to the ps job and this worker's own task so it
    # does not try to contact (and hang on) other workers' devices.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % FLAGS.task_index])

    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
            FLAGS.task_index)

    if FLAGS.existing_servers:
      # Connect to this worker's already-running server over gRPC.
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      print("Using existing server at: %s" % server_grpc_url)

      sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

    print("Worker %d: Session initialization complete." % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])

    # Perform training
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)

    # local_step counts this worker's own iterations; `step` is the shared
    # global step read back from the cluster.
    local_step = 0
    while True:
      # Training feed
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}

      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1

      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" %
            (now, FLAGS.task_index, local_step, step))

      # Stop once the cluster-wide global step reaches the target.
      if step >= FLAGS.train_steps:
        break

    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)

    # Validation feed
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" %
          (FLAGS.train_steps, val_xent))
   300  
   301  
if __name__ == "__main__":
  # tf.app.run() parses the flags defined above and then invokes main().
  tf.app.run()