github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/task.py (about)

     1  # Copyright 2019 Google LLC. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     https://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Trainer for the chicago_taxi demo."""
    16  # pytype: skip-file
    17  
    18  import argparse
    19  import os
    20  
    21  import tensorflow as tf
    22  import tensorflow_model_analysis as tfma
    23  import tensorflow_transform as tft
    24  from tensorflow import estimator as tf_estimator
    25  from trainer import model
    26  from trainer import taxi
    27  
# Sub-directory names joined onto the --output-dir flag: the first is used as
# the estimator's model_dir, the second as the export base for the tfma eval
# saved model.
SERVING_MODEL_DIR = 'serving_model_dir'
EVAL_MODEL_DIR = 'eval_model_dir'

# Batch sizes passed to model.input_fn for the training and evaluation inputs.
TRAIN_BATCH_SIZE = 40
EVAL_BATCH_SIZE = 40

# Number of nodes in the first layer of the DNN
FIRST_DNN_LAYER_SIZE = 100
# Number of hidden layers; layer i is sized as
# max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i)).
NUM_DNN_LAYERS = 4
DNN_DECAY_FACTOR = 0.7
    38  
    39  
    40  def train_and_maybe_evaluate(hparams):
    41    """Run the training and evaluate using the high level API.
    42  
    43    Args:
    44      hparams: Holds hyperparameters used to train the model as name/value pairs.
    45  
    46    Returns:
    47      The estimator that was used for training (and maybe eval)
    48    """
    49    schema = taxi.read_schema(hparams.schema_file)
    50    tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)
    51  
    52    train_input = lambda: model.input_fn(
    53        hparams.train_files, tf_transform_output, batch_size=TRAIN_BATCH_SIZE)
    54  
    55    eval_input = lambda: model.input_fn(
    56        hparams.eval_files, tf_transform_output, batch_size=EVAL_BATCH_SIZE)
    57  
    58    train_spec = tf_estimator.TrainSpec(
    59        train_input, max_steps=hparams.train_steps)
    60  
    61    serving_receiver_fn = lambda: model.example_serving_receiver_fn(
    62        tf_transform_output, schema)
    63  
    64    exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
    65    eval_spec = tf_estimator.EvalSpec(
    66        eval_input,
    67        steps=hparams.eval_steps,
    68        exporters=[exporter],
    69        name='chicago-taxi-eval')
    70  
    71    run_config = tf_estimator.RunConfig(
    72        save_checkpoints_steps=999, keep_checkpoint_max=1)
    73  
    74    serving_model_dir = os.path.join(hparams.output_dir, SERVING_MODEL_DIR)
    75    run_config = run_config.replace(model_dir=serving_model_dir)
    76  
    77    estimator = model.build_estimator(
    78        tf_transform_output,
    79  
    80        # Construct layers sizes with exponetial decay
    81        hidden_units=[
    82            max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i))
    83            for i in range(NUM_DNN_LAYERS)
    84        ],
    85        config=run_config)
    86  
    87    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    88  
    89    return estimator
    90  
    91  
    92  def run_experiment(hparams):
    93    """Train the model then export it for tf.model_analysis evaluation.
    94  
    95    Args:
    96      hparams: Holds hyperparameters used to train the model as name/value pairs.
    97    """
    98    estimator = train_and_maybe_evaluate(hparams)
    99  
   100    schema = taxi.read_schema(hparams.schema_file)
   101    tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)
   102  
   103    # Save a model for tfma eval
   104    eval_model_dir = os.path.join(hparams.output_dir, EVAL_MODEL_DIR)
   105  
   106    receiver_fn = lambda: model.eval_input_receiver_fn(
   107        tf_transform_output, schema)
   108  
   109    tfma.export.export_eval_savedmodel(
   110        estimator=estimator,
   111        export_dir_base=eval_model_dir,
   112        eval_input_receiver_fn=receiver_fn)
   113  
   114  
   115  def main():
   116    parser = argparse.ArgumentParser()
   117    # Input Arguments
   118    parser.add_argument(
   119        '--train-files',
   120        help='GCS or local paths to training data',
   121        nargs='+',
   122        required=True)
   123  
   124    parser.add_argument(
   125        '--tf-transform-dir',
   126        help='Tf-transform directory with model from preprocessing step',
   127        required=True)
   128  
   129    parser.add_argument(
   130        '--output-dir',
   131        help="""\
   132            Directory under which which the serving model (under /serving_model_dir)\
   133            and the tf-mode-analysis model (under /eval_model_dir) will be written\
   134            """,
   135        required=True)
   136  
   137    parser.add_argument(
   138        '--eval-files',
   139        help='GCS or local paths to evaluation data',
   140        nargs='+',
   141        required=True)
   142    # Training arguments
   143    parser.add_argument(
   144        '--job-dir',
   145        help='GCS location to write checkpoints and export models',
   146        required=True)
   147  
   148    # Argument to turn on all logging
   149    parser.add_argument(
   150        '--verbosity',
   151        choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
   152        default='INFO',
   153    )
   154    # Experiment arguments
   155    parser.add_argument(
   156        '--train-steps',
   157        help='Count of steps to run the training job for',
   158        required=True,
   159        type=int)
   160    parser.add_argument(
   161        '--eval-steps',
   162        help='Number of steps to run evalution for at each checkpoint',
   163        default=100,
   164        type=int)
   165    parser.add_argument(
   166        '--schema-file', help='File holding the schema for the input data')
   167  
   168    args = parser.parse_args()
   169  
   170    # Set python level verbosity
   171    tf.compat.v1.logging.set_verbosity(args.verbosity)
   172    # Set C++ Graph Execution level verbosity
   173    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(
   174        getattr(tf.compat.v1.logging, args.verbosity) / 10)
   175  
   176    # Run the training job
   177    hparams = tf.contrib.training.HParams(**args.__dict__)
   178    run_experiment(hparams)
   179  
   180  
# Run the trainer only when this file is executed as a script, not on import.
if __name__ == '__main__':
  main()