github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessor applying tf.transform to the chicago_taxi data."""
# pytype: skip-file

import argparse
import os

import tensorflow as tf
import tensorflow_transform as transform
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.coders import example_proto_coder
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

import apache_beam as beam
from apache_beam.io.gcp.bigquery import ReadFromBigQuery
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.load_tests.load_test_metrics_utils import (
    MeasureTime, MetricsReader)
from trainer import taxi


def _fill_in_missing(x):
  """Replaces missing values in a SparseTensor.

  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.

  Args:
    x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
      in the second dimension.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)


def transform_data(
    input_handle,
    outfile_prefix,
    working_dir,
    schema_file,
    transform_dir=None,
    max_rows=None,
    pipeline_args=None,
    publish_to_bq=False,
    project=None,
    metrics_table=None,
    metrics_dataset=None):
  """The main tf.transform method which analyzes and transforms data.

  Args:
    input_handle: BigQuery table name to process, specified as DATASET.TABLE,
      or path to a csv file with input data.
    outfile_prefix: Filename prefix for emitted transformed examples.
    working_dir: Directory in which transformed examples and the transform
      function will be emitted.
    schema_file: A file path that contains a text-serialized TensorFlow
      metadata schema of the input data.
    transform_dir: Directory in which the transform output is located. If
      provided, this will load the transform_fn from disk instead of computing
      it over the data. Hint: this is useful for transforming eval data.
    max_rows: Number of rows to query from BigQuery.
    pipeline_args: additional DataflowRunner or DirectRunner args passed to
      the beam pipeline.
    publish_to_bq: Whether to publish pipeline metrics to BigQuery.
    project: GCP project used for the BigQuery read and for metric reporting.
    metrics_table: BigQuery table receiving the published metrics.
    metrics_dataset: BigQuery dataset holding the metrics table.
  """
  def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.

    Args:
      inputs: map from feature keys to raw not-yet-transformed features.

    Returns:
      Map from string feature key to transformed feature operations.
    """
    outputs = {}
    for key in taxi.DENSE_FLOAT_FEATURE_KEYS:
      # Preserve this feature as a dense float, setting nan's to the mean.
      outputs[taxi.transformed_name(key)] = transform.scale_to_z_score(
          _fill_in_missing(inputs[key]))

    for key in taxi.VOCAB_FEATURE_KEYS:
      # Build a vocabulary for this feature.
      outputs[taxi.transformed_name(
          key)] = transform.compute_and_apply_vocabulary(
              _fill_in_missing(inputs[key]),
              top_k=taxi.VOCAB_SIZE,
              num_oov_buckets=taxi.OOV_SIZE)

    for key in taxi.BUCKET_FEATURE_KEYS:
      # Bucketize this feature into quantile-based buckets.
      outputs[taxi.transformed_name(key)] = transform.bucketize(
          _fill_in_missing(inputs[key]), taxi.FEATURE_BUCKET_COUNT)

    for key in taxi.CATEGORICAL_FEATURE_KEYS:
      # Categorical features pass through unchanged, aside from filling in
      # missing values.
      outputs[taxi.transformed_name(key)] = _fill_in_missing(inputs[key])

    # Was this passenger a big tipper?
    taxi_fare = _fill_in_missing(inputs[taxi.FARE_KEY])
    tips = _fill_in_missing(inputs[taxi.LABEL_KEY])
    outputs[taxi.transformed_name(taxi.LABEL_KEY)] = tf.where(
        tf.math.is_nan(taxi_fare),
        tf.cast(tf.zeros_like(taxi_fare), tf.int64),
        # Test if the tip was > 20% of the fare.
        tf.cast(
            tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
            tf.int64))

    return outputs

  namespace = metrics_table
  metrics_monitor = None
  if publish_to_bq:
    metrics_monitor = MetricsReader(
        publish_to_bq=publish_to_bq,
        project_name=project,
        bq_table=metrics_table,
        bq_dataset=metrics_dataset,
        namespace=namespace,
        filters=MetricsFilter().with_namespace(namespace))

  schema = taxi.read_schema(schema_file)
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  raw_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
  raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

  pipeline = beam.Pipeline(argv=pipeline_args)
  with tft_beam.Context(temp_dir=working_dir):
    query = taxi.make_sql(input_handle, max_rows, for_eval=False)
    raw_data = (
        pipeline
        | 'ReadBigQuery' >> ReadFromBigQuery(
            query=query, project=project, use_standard_sql=True)
        | 'Measure time: start' >> beam.ParDo(MeasureTime(namespace)))
    decode_transform = beam.Map(
        taxi.clean_raw_data_dict, raw_feature_spec=raw_feature_spec)

    if transform_dir is None:
      decoded_data = raw_data | 'DecodeForAnalyze' >> decode_transform
      transform_fn = ((decoded_data, raw_data_metadata) |
                      ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

      _ = (
          transform_fn |
          ('WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir)))
    else:
      transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)
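
    # At this point `transform_fn` holds the tf.Transform analysis output,
    # either computed by AnalyzeDataset above or loaded from `transform_dir`.
    # Persisting it with WriteTransformFn is what lets a later eval run pass
    # `transform_dir` and reuse the training statistics via ReadTransformFn
    # instead of re-analyzing the eval data (see the transform_dir hint in the
    # docstring).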

    # Shuffling the data before materialization will improve Training
    # effectiveness downstream. Here we shuffle the raw_data (as opposed to
    # decoded data) since it has a compact representation.
    shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()

    decoded_data = shuffled_data | 'DecodeForTransform' >> decode_transform
    (transformed_data,
     transformed_metadata) = (((decoded_data, raw_data_metadata), transform_fn)
                              | 'Transform' >> tft_beam.TransformDataset())

    coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
    _ = (
        transformed_data
        | 'SerializeExamples' >> beam.Map(coder.encode)
        | 'Measure time: end' >> beam.ParDo(MeasureTime(namespace))
        | 'WriteExamples' >> beam.io.WriteToTFRecord(
            os.path.join(working_dir, outfile_prefix), file_name_suffix='.gz'))

  result = pipeline.run()
  result.wait_until_finish()
  if metrics_monitor:
    metrics_monitor.publish_metrics(result)


def main():
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      help=('Input BigQuery table to process specified as: '
            'DATASET.TABLE'))

  parser.add_argument(
      '--schema_file', help='File holding the schema for the input data')

  parser.add_argument(
      '--output_dir',
      help=(
          'Directory in which transformed examples and function '
          'will be emitted.'))

  parser.add_argument(
      '--outfile_prefix',
      help='Filename prefix for emitted transformed examples')

  parser.add_argument(
      '--transform_dir',
      required=False,
      default=None,
      help='Directory in which the transform output is located')

  parser.add_argument(
      '--max_rows',
      help='Number of rows to query from BigQuery',
      default=None,
      type=int)

  parser.add_argument(
      '--publish_to_big_query',
      help='Whether to publish to BQ',
      default=None,
      type=bool)

  parser.add_argument(
      '--metrics_dataset', help='BQ dataset', default=None, type=str)

  parser.add_argument(
      '--metrics_table', help='BQ table', default=None, type=str)

  parser.add_argument(
      '--metric_reporting_project',
      help='BQ table project',
      default=None,
      type=str)

  known_args, pipeline_args = parser.parse_known_args()
  transform_data(
      input_handle=known_args.input,
      outfile_prefix=known_args.outfile_prefix,
      working_dir=known_args.output_dir,
      schema_file=known_args.schema_file,
      transform_dir=known_args.transform_dir,
      max_rows=known_args.max_rows,
      pipeline_args=pipeline_args,
      publish_to_bq=known_args.publish_to_big_query,
      metrics_dataset=known_args.metrics_dataset,
      metrics_table=known_args.metrics_table,
      project=known_args.metric_reporting_project)


if __name__ == '__main__':
  main()
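
# Example invocation (illustrative sketch only: the table, schema path, and
# output locations below are placeholders, not values shipped with this
# benchmark; extra flags such as --runner are forwarded to the Beam pipeline
# via parse_known_args):
#
#   python preprocess.py \
#     --input <DATASET.TABLE> \
#     --schema_file <path/to/schema.pbtxt> \
#     --output_dir <path/to/working_dir> \
#     --outfile_prefix train_transformed \
#     --max_rows 100000 \
#     --runner=DirectRunner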