# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs a batch job for performing Tensorflow Model Analysis."""

# pytype: skip-file

import argparse

import tensorflow as tf
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.evaluators import evaluator

import apache_beam as beam
from apache_beam.io.gcp.bigquery import ReadFromBigQuery
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime
from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader
from trainer import taxi

# Spellings accepted by _str_to_bool (compared case-insensitively).
_TRUE_STRINGS = frozenset(['true', 't', 'yes', 'y', '1'])
_FALSE_STRINGS = frozenset(['false', 'f', 'no', 'n', '0', ''])


def _str_to_bool(value):
  """Parses a command-line string into a boolean.

  argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
  non-empty value (including 'False') becomes True. This parser interprets
  the text instead.

  Args:
    value: The raw flag value; a bool is returned unchanged.

  Returns:
    True or False according to the spelling of ``value``.

  Raises:
    argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in _TRUE_STRINGS:
    return True
  if normalized in _FALSE_STRINGS:
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (e.g. true/false), got %r.' % value)


def process_tfma(
    schema_file,
    big_query_table=None,
    eval_model_dir=None,
    max_eval_rows=None,
    pipeline_args=None,
    publish_to_bq=False,
    project=None,
    metrics_table=None,
    metrics_dataset=None):
  """Runs a batch job to evaluate the eval_model against the given input.

  Args:
    schema_file: A file containing a text-serialized Schema that describes the
      eval data.
    big_query_table: A BigQuery table name specified as DATASET.TABLE which
      is the input for evaluation. Required.
    eval_model_dir: A directory where the eval model is located.
    max_eval_rows: Number of rows to query from BigQuery.
    pipeline_args: additional DataflowRunner or DirectRunner args passed to
      the beam pipeline.
    publish_to_bq: If truthy, a MetricsReader is created and the pipeline
      metrics are published once the pipeline has finished.
    project: Project passed to the BigQuery read and, when publishing, to the
      MetricsReader as its project name.
    metrics_table: BigQuery table the metrics are published to. It is also
      used as the namespace under which the time metrics are recorded.
    metrics_dataset: BigQuery dataset the metrics are published to.

  Raises:
    ValueError: if big_query_table is not specified.
  """

  if big_query_table is None:
    raise ValueError('--big_query_table should be provided.')

  slice_spec = [
      tfma.slicer.SingleSliceSpec(),
      tfma.slicer.SingleSliceSpec(columns=['trip_start_hour'])
  ]
  # The metrics table name doubles as the Beam metrics namespace.
  metrics_namespace = metrics_table

  schema = taxi.read_schema(schema_file)

  eval_shared_model = tfma.default_eval_shared_model(
      eval_saved_model_path=eval_model_dir,
      add_metrics_callbacks=[
          tfma.post_export_metrics.calibration_plot_and_prediction_histogram(),
          tfma.post_export_metrics.auc_plots()
      ])

  # Only build a metrics publisher when publishing was requested.
  metrics_monitor = None
  if publish_to_bq:
    metrics_monitor = MetricsReader(
        publish_to_bq=publish_to_bq,
        project_name=project,
        bq_table=metrics_table,
        bq_dataset=metrics_dataset,
        namespace=metrics_namespace,
        filters=MetricsFilter().with_namespace(metrics_namespace))

  pipeline = beam.Pipeline(argv=pipeline_args)

  query = taxi.make_sql(big_query_table, max_eval_rows, for_eval=True)
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  raw_data = (
      pipeline
      | 'ReadBigQuery' >> ReadFromBigQuery(
          query=query, project=project, use_standard_sql=True)
      | 'Measure time: Start' >> beam.ParDo(MeasureTime(metrics_namespace))
      | 'CleanData' >>
      beam.Map(lambda x: (taxi.clean_raw_data_dict(x, raw_feature_spec))))

  # Examples must be in clean tf-example format.
  coder = taxi.make_proto_coder(schema)
  # Prepare arguments for Extract, Evaluate and Write steps
  extractors = tfma.default_extractors(
      eval_shared_model=eval_shared_model,
      slice_spec=slice_spec,
      desired_batch_size=None,
      materialize=False)

  evaluators = tfma.default_evaluators(
      eval_shared_model=eval_shared_model,
      desired_batch_size=None,
      num_bootstrap_samples=1)
  _ = (
      raw_data
      | 'ToSerializedTFExample' >> beam.Map(coder.encode)
      | 'Extract Results' >> tfma.InputsToExtracts()
      | 'Extract and evaluate' >> tfma.ExtractAndEvaluate(
          extractors=extractors, evaluators=evaluators)
      | 'Map Evaluations to PCollection' >> MapEvalToPCollection()
      | 'Measure time: End' >> beam.ParDo(MeasureTime(metrics_namespace)))
  result = pipeline.run()
  result.wait_until_finish()
  if metrics_monitor:
    metrics_monitor.publish_metrics(result)


@beam.ptransform_fn
@beam.typehints.with_input_types(evaluator.Evaluation)
@beam.typehints.with_output_types(beam.typehints.Any)
def MapEvalToPCollection(  # pylint: disable=invalid-name
    evaluation):
  """Returns the 'metrics' entry of the evaluation output."""
  return evaluation['metrics']


def main():
  """Parses the command-line flags and runs process_tfma with them.

  Flags not declared here are forwarded to the Beam pipeline as
  pipeline_args.
  """
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

  parser = argparse.ArgumentParser()

  parser.add_argument(
      '--eval_model_dir',
      help='Input path to the model which will be evaluated.')
  parser.add_argument(
      '--big_query_table',
      help='BigQuery path to input examples which will be evaluated.')
  parser.add_argument(
      '--max_eval_rows',
      help='Maximum number of rows to evaluate on.',
      default=None,
      type=int)
  parser.add_argument(
      '--schema_file', help='File holding the schema for the input data')
  # type=bool would treat every non-empty string (even 'false') as True, so
  # the value is parsed explicitly.
  parser.add_argument(
      '--publish_to_big_query',
      help='Whether to publish to BQ',
      default=False,
      type=_str_to_bool)
  parser.add_argument(
      '--metrics_dataset', help='BQ dataset', default=None, type=str)
  parser.add_argument(
      '--metrics_table',
      help='BQ table for storing metrics',
      default=None,
      type=str)
  parser.add_argument(
      '--metric_reporting_project',
      help='BQ table project',
      default=None,
      type=str)

  known_args, pipeline_args = parser.parse_known_args()

  process_tfma(
      big_query_table=known_args.big_query_table,
      eval_model_dir=known_args.eval_model_dir,
      max_eval_rows=known_args.max_eval_rows,
      schema_file=known_args.schema_file,
      pipeline_args=pipeline_args,
      publish_to_bq=known_args.publish_to_big_query,
      metrics_table=known_args.metrics_table,
      metrics_dataset=known_args.metrics_dataset,
      project=known_args.metric_reporting_project)


if __name__ == '__main__':
  main()