"""Train a simple TF program to verify we can execute ops.

The program does a simple matrix multiplication.

Only the master assigns ops to devices/workers.

The master will assign ops to every task in the cluster. This way we can verify
that distributed training is working by executing ops on all devices.
"""
import argparse
import json
import logging
import os

import retrying
import tensorflow as tf


def parse_args():
    """Parse the command line arguments.

    Returns:
        argparse.Namespace with a single ``sleep_secs`` int attribute.
    """
    # NOTE(review): parse_args() is never invoked by main(), so --sleep_secs
    # currently has no effect; confirm whether the sleep was meant to be wired
    # into main() before removing this flag.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--sleep_secs",
        default=0,
        type=int,
        help=("Amount of time to sleep at the end"))

    # TODO(jlewi): We ignore unknown arguments because the backend is currently
    # setting some flags to empty values like metadata path.
    args, _ = parser.parse_known_args()
    return args


# Add retries to deal with things like gRPC errors that result in
# UnavailableError.
@retrying.retry(wait_exponential_multiplier=1000, wait_exponential_max=10000,
                stop_max_delay=60 * 3 * 1000)
def run(server, cluster_spec):  # pylint: disable=too-many-statements, too-many-locals
    """Build the graph and run the example.

    Args:
        server: The TensorFlow server to use. If None, a direct (in-process)
            session is created instead of connecting to a cluster.
        cluster_spec: Dict mapping job names to lists of task addresses. One
            matrix-multiply op is placed on every task so that distributed
            execution is exercised on all devices.
    """
    # Construct the graph: one elementwise matrix multiply per task.
    with tf.Graph().as_default():  # pylint: disable=not-context-manager
        # The initial value should be such that type is correctly inferred as
        # float.
        width = 10
        height = 10
        results = []

        # The master assigns ops to every TFProcess in the cluster.
        for job_name in cluster_spec.keys():
            for i in range(len(cluster_spec[job_name])):
                d = "/job:{0}/task:{1}".format(job_name, i)
                with tf.device(d):
                    a = tf.constant(range(width * height), shape=[height, width])
                    b = tf.constant(range(width * height), shape=[height, width])
                    c = tf.multiply(a, b)
                    results.append(c)

        init_op = tf.global_variables_initializer()

        if server:
            target = server.target
        else:
            # Create a direct session.
            target = ""

        logging.info("Server target: %s", target)
        # log_device_placement makes the placement of each op visible in the
        # logs, which is how the smoke test is verified externally.
        with tf.Session(
                target, config=tf.ConfigProto(log_device_placement=True)) as sess:
            sess.run(init_op)
            for r in results:
                result = sess.run(r)
                logging.info("Result: %s", result)


def main():
    """Run training.

    Reads the cluster/task description from the TF_CONFIG environment
    variable, builds a tf.train.Server when a cluster is configured, and
    dispatches on the task type (ps/worker block on the server; the
    master/chief runs the graph).

    Raises:
        ValueError: If the arguments are invalid.
    """
    logging.info("Tensorflow version: %s", tf.__version__)
    logging.info("Tensorflow git version: %s", tf.__git_version__)

    tf_config_json = os.environ.get("TF_CONFIG", "{}")
    tf_config = json.loads(tf_config_json)
    logging.info("tf_config: %s", tf_config)

    task = tf_config.get("task", {})
    logging.info("task: %s", task)

    cluster_spec = tf_config.get("cluster", {})
    logging.info("cluster_spec: %s", cluster_spec)

    server = None
    device_func = None
    if cluster_spec:
        cluster_spec_object = tf.train.ClusterSpec(cluster_spec)
        server_def = tf.train.ServerDef(
            cluster=cluster_spec_object.as_cluster_def(),
            protocol="grpc",
            job_name=task["type"],
            task_index=task["index"])

        logging.info("server_def: %s", server_def)

        logging.info("Building server.")
        # Create and start a server for the local task.
        server = tf.train.Server(server_def)
        logging.info("Finished building server.")

        # Assigns ops to the local worker by default.
        device_func = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % server_def.task_index,
            cluster=server_def.cluster)
    else:
        # This should return a null op device setter since we are using
        # all the defaults.
        logging.error("Using default device function.")
        device_func = tf.train.replica_device_setter()

    job_type = task.get("type", "").lower()
    if job_type in ("ps", "worker") and server is None:
        # Without a cluster spec no server was built; fail with a clear error
        # instead of an AttributeError on server.join() below.
        raise ValueError(
            "job_type %r requires a cluster spec in TF_CONFIG" % job_type)
    if job_type == "ps":
        logging.info("Running PS code.")
        server.join()
    elif job_type == "worker":
        logging.info("Running Worker code.")
        # The worker just blocks because we let the master assign all ops.
        server.join()
    elif job_type in ["master", "chief"] or not job_type:
        logging.info("Running master/chief.")
        with tf.device(device_func):
            run(server=server, cluster_spec=cluster_spec)
    else:
        raise ValueError("invalid job_type %s" % (job_type,))


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    main()