github.com/kubeflow/training-operator@v1.7.0/examples/tensorflow/tf_sample/tf_smoke.py (about)

     1  """Train a simple TF program to verify we can execute ops.
     2  
     3  The program does a simple element-wise multiplication of two matrices.
     4  
     5  Only the master assigns ops to devices/workers.
     6  
     7  The master will assign ops to every task in the cluster. This way we can verify
     8  that distributed training is working by executing ops on all devices.
     9  """
    10  import argparse
    11  import json
    12  import logging
    13  import os
    14  import retrying
    15  
    16  import tensorflow as tf
    17  
    18  
    19  def parse_args():
    20    """Parse the command line arguments."""
    21    parser = argparse.ArgumentParser()
    22  
    23    parser.add_argument(
    24        "--sleep_secs",
    25        default=0,
    26        type=int,
    27        help=("Amount of time to sleep at the end"))
    28  
    29    # TODO(jlewi): We ignore unknown arguments because the backend is currently
    30    # setting some flags to empty values like metadata path.
    31    args, _ = parser.parse_known_args()
    32    return args
    33  
    34  # Add retries to deal with things like gRPC errors that result in
    35  # UnavailableError.
    36  @retrying.retry(wait_exponential_multiplier=1000, wait_exponential_max=10000,
    37                  stop_max_delay=60*3*1000)
    38  def run(server, cluster_spec):  # pylint: disable=too-many-statements, too-many-locals
    39    """Build the graph and run the example.
    40  
    41    Args:
    42      server: The TensorFlow server to use.
    43  
    44    Raises:
    45      RuntimeError: If the expected log entries aren't found.
    46    """
    47  
    48    # construct the graph and create a saver object
    49    with tf.Graph().as_default():  # pylint: disable=not-context-manager
    50      # The initial value should be such that type is correctly inferred as
    51      # float.
    52      width = 10
    53      height = 10
    54      results = []
    55  
    56      # The master assigns ops to every TFProcess in the cluster.
    57      for job_name in cluster_spec.keys():
    58        for i in range(len(cluster_spec[job_name])):
    59          d = "/job:{0}/task:{1}".format(job_name, i)
    60          with tf.device(d):
    61            a = tf.constant(range(width * height), shape=[height, width])
    62            b = tf.constant(range(width * height), shape=[height, width])
    63            c = tf.multiply(a, b)
    64            results.append(c)
    65  
    66      init_op = tf.global_variables_initializer()
    67  
    68      if server:
    69        target = server.target
    70      else:
    71        # Create a direct session.
    72        target = ""
    73  
    74      logging.info("Server target: %s", target)
    75      with tf.Session(
    76              target, config=tf.ConfigProto(log_device_placement=True)) as sess:
    77        sess.run(init_op)
    78        for r in results:
    79          result = sess.run(r)
    80          logging.info("Result: %s", result)
    81  
    82  
    83  def main():
    84    """Run training.
    85  
    86    Raises:
    87      ValueError: If the arguments are invalid.
    88    """
    89    logging.info("Tensorflow version: %s", tf.__version__)
    90    logging.info("Tensorflow git version: %s", tf.__git_version__)
    91  
    92    tf_config_json = os.environ.get("TF_CONFIG", "{}")
    93    tf_config = json.loads(tf_config_json)
    94    logging.info("tf_config: %s", tf_config)
    95  
    96    task = tf_config.get("task", {})
    97    logging.info("task: %s", task)
    98  
    99    cluster_spec = tf_config.get("cluster", {})
   100    logging.info("cluster_spec: %s", cluster_spec)
   101  
   102    server = None
   103    device_func = None
   104    if cluster_spec:
   105      cluster_spec_object = tf.train.ClusterSpec(cluster_spec)
   106      server_def = tf.train.ServerDef(
   107          cluster=cluster_spec_object.as_cluster_def(),
   108          protocol="grpc",
   109          job_name=task["type"],
   110          task_index=task["index"])
   111  
   112      logging.info("server_def: %s", server_def)
   113  
   114      logging.info("Building server.")
   115      # Create and start a server for the local task.
   116      server = tf.train.Server(server_def)
   117      logging.info("Finished building server.")
   118  
   119      # Assigns ops to the local worker by default.
   120      device_func = tf.train.replica_device_setter(
   121          worker_device="/job:worker/task:%d" % server_def.task_index,
   122          cluster=server_def.cluster)
   123    else:
   124      # This should return a null op device setter since we are using
   125      # all the defaults.
   126      logging.error("Using default device function.")
   127      device_func = tf.train.replica_device_setter()
   128  
   129    job_type = task.get("type", "").lower()
   130    if job_type == "ps":
   131      logging.info("Running PS code.")
   132      server.join()
   133    elif job_type == "worker":
   134      logging.info("Running Worker code.")
   135      # The worker just blocks because we let the master assign all ops.
   136      server.join()
   137    elif job_type in ["master", "chief"] or not job_type:
   138      logging.info("Running master/chief.")
   139      with tf.device(device_func):
   140        run(server=server, cluster_spec=cluster_spec)
   141    else:
   142      raise ValueError("invalid job_type %s" % (job_type,))
   143  
   144  
   145  if __name__ == "__main__":
   146    logging.getLogger().setLevel(logging.INFO)
   147    main()