# Origin: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/task.py
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer for the chicago_taxi demo."""
# pytype: skip-file

import argparse
import os

import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow import estimator as tf_estimator
from trainer import model
from trainer import taxi

# Sub-directories of --output-dir that receive the two exported models.
SERVING_MODEL_DIR = 'serving_model_dir'
EVAL_MODEL_DIR = 'eval_model_dir'

TRAIN_BATCH_SIZE = 40
EVAL_BATCH_SIZE = 40

# Number of nodes in the first layer of the DNN
FIRST_DNN_LAYER_SIZE = 100
NUM_DNN_LAYERS = 4
DNN_DECAY_FACTOR = 0.7


def train_and_maybe_evaluate(hparams):
  """Run the training and evaluate using the high level API.

  Builds the train/eval specs and the estimator from the given
  hyperparameters, then hands them to `tf_estimator.train_and_evaluate`.
  Checkpoints and the serving export are written under
  `<hparams.output_dir>/serving_model_dir`.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
      The attributes read here are `schema_file`, `tf_transform_dir`,
      `train_files`, `eval_files`, `train_steps`, `eval_steps` and
      `output_dir`.

  Returns:
    The estimator that was used for training (and maybe eval)
  """
  schema = taxi.read_schema(hparams.schema_file)
  tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)

  train_input = lambda: model.input_fn(
      hparams.train_files, tf_transform_output, batch_size=TRAIN_BATCH_SIZE)

  eval_input = lambda: model.input_fn(
      hparams.eval_files, tf_transform_output, batch_size=EVAL_BATCH_SIZE)

  train_spec = tf_estimator.TrainSpec(
      train_input, max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: model.example_serving_receiver_fn(
      tf_transform_output, schema)

  exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf_estimator.EvalSpec(
      eval_input,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf_estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  serving_model_dir = os.path.join(hparams.output_dir, SERVING_MODEL_DIR)
  run_config = run_config.replace(model_dir=serving_model_dir)

  estimator = model.build_estimator(
      tf_transform_output,

      # Construct layer sizes with exponential decay, never going below 2
      # units per layer.
      hidden_units=[
          max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i))
          for i in range(NUM_DNN_LAYERS)
      ],
      config=run_config)

  tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  return estimator


def run_experiment(hparams):
  """Train the model then export it for tf.model_analysis evaluation.

  The eval saved model is written under `<hparams.output_dir>/eval_model_dir`.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
      In addition to what `train_and_maybe_evaluate` reads, this function
      reads `schema_file`, `tf_transform_dir` and `output_dir`.
  """
  estimator = train_and_maybe_evaluate(hparams)

  schema = taxi.read_schema(hparams.schema_file)
  tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)

  # Save a model for tfma eval
  eval_model_dir = os.path.join(hparams.output_dir, EVAL_MODEL_DIR)

  receiver_fn = lambda: model.eval_input_receiver_fn(
      tf_transform_output, schema)

  tfma.export.export_eval_savedmodel(
      estimator=estimator,
      export_dir_base=eval_model_dir,
      eval_input_receiver_fn=receiver_fn)


def main(argv=None):
  """Parse the command-line flags and run the training experiment.

  Args:
    argv: Optional list of argument strings to parse. When None (the
      default), argparse reads the process arguments, which matches the
      behavior of calling `main()` with no arguments.
  """
  parser = argparse.ArgumentParser()
  # Input Arguments
  parser.add_argument(
      '--train-files',
      help='GCS or local paths to training data',
      nargs='+',
      required=True)

  parser.add_argument(
      '--tf-transform-dir',
      help='Tf-transform directory with model from preprocessing step',
      required=True)

  parser.add_argument(
      '--output-dir',
      help="""\
      Directory under which which the serving model (under /serving_model_dir)\
      and the tf-mode-analysis model (under /eval_model_dir) will be written\
      """,
      required=True)

  parser.add_argument(
      '--eval-files',
      help='GCS or local paths to evaluation data',
      nargs='+',
      required=True)
  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True)

  # Argument to turn on all logging
  parser.add_argument(
      '--verbosity',
      choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
      default='INFO',
  )
  # Experiment arguments
  parser.add_argument(
      '--train-steps',
      help='Count of steps to run the training job for',
      required=True,
      type=int)
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evalution for at each checkpoint',
      default=100,
      type=int)
  # NOTE(review): this flag is optional, yet its value is passed to
  # taxi.read_schema unconditionally -- confirm whether None is acceptable.
  parser.add_argument(
      '--schema-file', help='File holding the schema for the input data')

  args = parser.parse_args(argv)

  # Set python level verbosity
  tf.compat.v1.logging.set_verbosity(args.verbosity)
  # Set C++ Graph Execution level verbosity. Floor division keeps the value an
  # integer string (e.g. '2'); true division would yield '2.0' on Python 3.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(
      getattr(tf.compat.v1.logging, args.verbosity) // 10)

  # Run the training job. The parsed Namespace already exposes every flag as
  # an attribute, which is all run_experiment needs; tf.contrib (and its
  # HParams class) is not available in TensorFlow 2.x.
  hparams = args
  run_experiment(hparams)


if __name__ == '__main__':
  main()