github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/process_tfma.py (about)

     1  # Copyright 2019 Google LLC. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     https://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Runs a batch job for performing Tensorflow Model Analysis."""
    16  
    17  # pytype: skip-file
    18  
    19  import argparse
    20  
    21  import tensorflow as tf
    22  import tensorflow_model_analysis as tfma
    23  from tensorflow_model_analysis.evaluators import evaluator
    24  
    25  import apache_beam as beam
    26  from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    27  from apache_beam.metrics.metric import MetricsFilter
    28  from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime
    29  from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader
    30  from trainer import taxi
    31  
    32  
    33  def process_tfma(
    34      schema_file,
    35      big_query_table=None,
    36      eval_model_dir=None,
    37      max_eval_rows=None,
    38      pipeline_args=None,
    39      publish_to_bq=False,
    40      project=None,
    41      metrics_table=None,
    42      metrics_dataset=None):
    43    """Runs a batch job to evaluate the eval_model against the given input.
    44  
    45    Args:
    46    schema_file: A file containing a text-serialized Schema that describes the
    47        eval data.
    48    big_query_table: A BigQuery table name specified as DATASET.TABLE which
    49        should be the input for evaluation. This can only be set if input_csv is
    50        None.
    51    eval_model_dir: A directory where the eval model is located.
    52    max_eval_rows: Number of rows to query from BigQuery.
    53    pipeline_args: additional DataflowRunner or DirectRunner args passed to
    54    the beam pipeline.
    55    publish_to_bq:
    56    project:
    57    metrics_dataset:
    58    metrics_table:
    59  
    60    Raises:
    61    ValueError: if input_csv and big_query_table are not specified correctly.
    62    """
    63  
    64    if big_query_table is None:
    65      raise ValueError('--big_query_table should be provided.')
    66  
    67    slice_spec = [
    68        tfma.slicer.SingleSliceSpec(),
    69        tfma.slicer.SingleSliceSpec(columns=['trip_start_hour'])
    70    ]
    71    metrics_namespace = metrics_table
    72  
    73    schema = taxi.read_schema(schema_file)
    74  
    75    eval_shared_model = tfma.default_eval_shared_model(
    76        eval_saved_model_path=eval_model_dir,
    77        add_metrics_callbacks=[
    78            tfma.post_export_metrics.calibration_plot_and_prediction_histogram(),
    79            tfma.post_export_metrics.auc_plots()
    80        ])
    81  
    82    metrics_monitor = None
    83    if publish_to_bq:
    84      metrics_monitor = MetricsReader(
    85          publish_to_bq=publish_to_bq,
    86          project_name=project,
    87          bq_table=metrics_table,
    88          bq_dataset=metrics_dataset,
    89          namespace=metrics_namespace,
    90          filters=MetricsFilter().with_namespace(metrics_namespace))
    91  
    92    pipeline = beam.Pipeline(argv=pipeline_args)
    93  
    94    query = taxi.make_sql(big_query_table, max_eval_rows, for_eval=True)
    95    raw_feature_spec = taxi.get_raw_feature_spec(schema)
    96    raw_data = (
    97        pipeline
    98        | 'ReadBigQuery' >> ReadFromBigQuery(
    99            query=query, project=project, use_standard_sql=True)
   100        | 'Measure time: Start' >> beam.ParDo(MeasureTime(metrics_namespace))
   101        | 'CleanData' >>
   102        beam.Map(lambda x: (taxi.clean_raw_data_dict(x, raw_feature_spec))))
   103  
   104    # Examples must be in clean tf-example format.
   105    coder = taxi.make_proto_coder(schema)
   106    # Prepare arguments for Extract, Evaluate and Write steps
   107    extractors = tfma.default_extractors(
   108        eval_shared_model=eval_shared_model,
   109        slice_spec=slice_spec,
   110        desired_batch_size=None,
   111        materialize=False)
   112  
   113    evaluators = tfma.default_evaluators(
   114        eval_shared_model=eval_shared_model,
   115        desired_batch_size=None,
   116        num_bootstrap_samples=1)
   117    _ = (
   118        raw_data
   119        | 'ToSerializedTFExample' >> beam.Map(coder.encode)
   120        | 'Extract Results' >> tfma.InputsToExtracts()
   121        | 'Extract and evaluate' >> tfma.ExtractAndEvaluate(
   122            extractors=extractors, evaluators=evaluators)
   123        | 'Map Evaluations to PCollection' >> MapEvalToPCollection()
   124        | 'Measure time: End' >> beam.ParDo(MeasureTime(metrics_namespace)))
   125    result = pipeline.run()
   126    result.wait_until_finish()
   127    if metrics_monitor:
   128      metrics_monitor.publish_metrics(result)
   129  
   130  
   131  @beam.ptransform_fn
   132  @beam.typehints.with_input_types(evaluator.Evaluation)
   133  @beam.typehints.with_output_types(beam.typehints.Any)
   134  def MapEvalToPCollection(  # pylint: disable=invalid-name
   135      evaluation):
   136    return evaluation['metrics']
   137  
   138  
   139  def main():
   140    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   141  
   142    parser = argparse.ArgumentParser()
   143  
   144    parser.add_argument(
   145        '--eval_model_dir',
   146        help='Input path to the model which will be evaluated.')
   147    parser.add_argument(
   148        '--big_query_table',
   149        help='BigQuery path to input examples which will be evaluated.')
   150    parser.add_argument(
   151        '--max_eval_rows',
   152        help='Maximum number of rows to evaluate on.',
   153        default=None,
   154        type=int)
   155    parser.add_argument(
   156        '--schema_file', help='File holding the schema for the input data')
   157    parser.add_argument(
   158        '--publish_to_big_query',
   159        help='Whether to publish to BQ',
   160        default=None,
   161        type=bool)
   162    parser.add_argument(
   163        '--metrics_dataset', help='BQ dataset', default=None, type=str)
   164    parser.add_argument(
   165        '--metrics_table',
   166        help='BQ table for storing metrics',
   167        default=None,
   168        type=str)
   169    parser.add_argument(
   170        '--metric_reporting_project',
   171        help='BQ table project',
   172        default=None,
   173        type=str)
   174  
   175    known_args, pipeline_args = parser.parse_known_args()
   176  
   177    process_tfma(
   178        big_query_table=known_args.big_query_table,
   179        eval_model_dir=known_args.eval_model_dir,
   180        max_eval_rows=known_args.max_eval_rows,
   181        schema_file=known_args.schema_file,
   182        pipeline_args=pipeline_args,
   183        publish_to_bq=known_args.publish_to_big_query,
   184        metrics_table=known_args.metrics_table,
   185        metrics_dataset=known_args.metrics_dataset,
   186        project=known_args.metric_reporting_project)
   187  
   188  
   189  if __name__ == '__main__':
   190    main()