github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

    15  """Preprocessor applying tf.transform to the chicago_taxi data."""
    16  # pytype: skip-file
    17  
    18  import argparse
    19  import os
    20  
    21  import tensorflow as tf
    22  import tensorflow_transform as transform
    23  import tensorflow_transform.beam as tft_beam
    24  from tensorflow_transform.coders import example_proto_coder
    25  from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
    26  
    27  import apache_beam as beam
    28  from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    29  from apache_beam.metrics.metric import MetricsFilter
    30  from apache_beam.testing.load_tests.load_test_metrics_utils import (
    31      MeasureTime, MetricsReader)
    32  from trainer import taxi
    33  
    34  
def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.

  Args:
    x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
      in the second dimension.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)

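# A hedged sketch (not part of the original file) of _fill_in_missing's
# behavior, run eagerly under TF 2.x: rows absent from the rank-2 SparseTensor
# come back as the default value ('' for strings, 0 otherwise), squeezed to
# rank 1.
#
#   x = tf.SparseTensor(
#       indices=[[0, 0], [2, 0]], values=['a', 'b'], dense_shape=[3, 1])
#   _fill_in_missing(x)  # => [b'a', b'', b'b']; missing row 1 filled with ''.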

def transform_data(
    input_handle,
    outfile_prefix,
    working_dir,
    schema_file,
    transform_dir=None,
    max_rows=None,
    pipeline_args=None,
    publish_to_bq=False,
    project=None,
    metrics_table=None,
    metrics_dataset=None):
  """The main tf.transform method which analyzes and transforms data.

  Args:
    input_handle: BigQuery table name to process specified as DATASET.TABLE or
      path to csv file with input data.
    outfile_prefix: Filename prefix for emitted transformed examples.
    working_dir: Directory in which transformed examples and transform function
      will be emitted.
    schema_file: A file path that contains a text-serialized TensorFlow
      metadata schema of the input data.
    transform_dir: Directory in which the transform output is located. If
      provided, this will load the transform_fn from disk instead of computing
      it over the data. Hint: this is useful for transforming eval data.
    max_rows: Number of rows to query from BigQuery.
    pipeline_args: additional DataflowRunner or DirectRunner args passed to the
      beam pipeline.
    publish_to_bq: Whether to publish pipeline metrics to BigQuery.
    project: GCP project used for reading input and reporting metrics.
    metrics_table: BigQuery table to which metrics are written.
    metrics_dataset: BigQuery dataset holding the metrics table.
  """
  def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.

    Args:
      inputs: map from feature keys to raw not-yet-transformed features.

    Returns:
      Map from string feature key to transformed feature operations.
    """
    outputs = {}
    for key in taxi.DENSE_FLOAT_FEATURE_KEYS:
      # Preserve this feature as a dense float, setting nan's to the mean.
      outputs[taxi.transformed_name(key)] = transform.scale_to_z_score(
          _fill_in_missing(inputs[key]))

    for key in taxi.VOCAB_FEATURE_KEYS:
      # Build a vocabulary for this feature.
      outputs[taxi.transformed_name(
          key)] = transform.compute_and_apply_vocabulary(
              _fill_in_missing(inputs[key]),
              top_k=taxi.VOCAB_SIZE,
              num_oov_buckets=taxi.OOV_SIZE)

    for key in taxi.BUCKET_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = transform.bucketize(
          _fill_in_missing(inputs[key]), taxi.FEATURE_BUCKET_COUNT)

    for key in taxi.CATEGORICAL_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = _fill_in_missing(inputs[key])

    # Was this passenger a big tipper?
    taxi_fare = _fill_in_missing(inputs[taxi.FARE_KEY])
    tips = _fill_in_missing(inputs[taxi.LABEL_KEY])
    outputs[taxi.transformed_name(taxi.LABEL_KEY)] = tf.where(
        tf.math.is_nan(taxi_fare),
        tf.cast(tf.zeros_like(taxi_fare), tf.int64),
        # Test if the tip was > 20% of the fare.
        tf.cast(
            tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
            tf.int64))

    return outputs

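  # A hedged sketch of the shape of preprocessing_fn's output, assuming the
  # feature-key lists defined in trainer/taxi.py (names below are
  # illustrative only):
  #
  #   {'trip_miles_xf': <z-scored float>,
  #    'payment_type_xf': <vocabulary id with OOV buckets>,
  #    'pickup_latitude_xf': <bucket index>,
  #    'trip_start_hour_xf': <categorical value, missing filled>,
  #    'tips_xf': <int64 0/1 label: tip greater than 20% of fare>}
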
  namespace = metrics_table
  metrics_monitor = None
  if publish_to_bq:
    metrics_monitor = MetricsReader(
        publish_to_bq=publish_to_bq,
        project_name=project,
        bq_table=metrics_table,
        bq_dataset=metrics_dataset,
        namespace=namespace,
        filters=MetricsFilter().with_namespace(namespace))
  schema = taxi.read_schema(schema_file)
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  raw_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
  raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

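  # A hedged sketch of the metadata built above (the spec below is
  # illustrative, not the real trainer/taxi.py spec): tf.transform wraps a
  # tf.train.Example feature spec in a Schema proto plus DatasetMetadata.
  #
  #   spec = {
  #       'fare': tf.io.VarLenFeature(tf.float32),
  #       'payment_type': tf.io.VarLenFeature(tf.string),
  #   }
  #   metadata = dataset_metadata.DatasetMetadata(
  #       schema_utils.schema_from_feature_spec(spec))
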
  pipeline = beam.Pipeline(argv=pipeline_args)
  with tft_beam.Context(temp_dir=working_dir):
    query = taxi.make_sql(input_handle, max_rows, for_eval=False)
    raw_data = (
        pipeline
        | 'ReadBigQuery' >> ReadFromBigQuery(
            query=query, project=project, use_standard_sql=True)
        | 'Measure time: start' >> beam.ParDo(MeasureTime(namespace)))
    decode_transform = beam.Map(
        taxi.clean_raw_data_dict, raw_feature_spec=raw_feature_spec)

    if transform_dir is None:
      decoded_data = raw_data | 'DecodeForAnalyze' >> decode_transform
      transform_fn = ((decoded_data, raw_data_metadata) |
                      ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

      _ = (
          transform_fn |
          ('WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir)))
    else:
      transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)

    # Shuffling the data before materialization will improve training
    # effectiveness downstream. Here we shuffle the raw_data (as opposed to
    # decoded data) since it has a compact representation.
    shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()

    decoded_data = shuffled_data | 'DecodeForTransform' >> decode_transform
    (transformed_data,
     transformed_metadata) = (((decoded_data, raw_data_metadata), transform_fn)
                              | 'Transform' >> tft_beam.TransformDataset())

    coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
    _ = (
        transformed_data
        | 'SerializeExamples' >> beam.Map(coder.encode)
        | 'Measure time: end' >> beam.ParDo(MeasureTime(namespace))
        | 'WriteExamples' >> beam.io.WriteToTFRecord(
            os.path.join(working_dir, outfile_prefix), file_name_suffix='.gz'))
  result = pipeline.run()
  result.wait_until_finish()
  if metrics_monitor:
    metrics_monitor.publish_metrics(result)

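# A hedged usage sketch (paths and table names are hypothetical): analyze and
# transform training data with the DirectRunner, then reuse the emitted
# transform_fn for eval data by passing transform_dir=working_dir.
#
#   transform_data(
#       input_handle='my_dataset.taxi_trips',
#       outfile_prefix='train_transformed',
#       working_dir='/tmp/taxi_out',
#       schema_file='/tmp/schema.pbtxt',
#       max_rows=10000,
#       pipeline_args=['--runner=DirectRunner'])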

def main():
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      help=('Input BigQuery table to process specified as: '
            'DATASET.TABLE'))

  parser.add_argument(
      '--schema_file', help='File holding the schema for the input data')

  parser.add_argument(
      '--output_dir',
      help=(
          'Directory in which transformed examples and transform function '
          'will be emitted.'))

  parser.add_argument(
      '--outfile_prefix',
      help='Filename prefix for emitted transformed examples')

  parser.add_argument(
      '--transform_dir',
      required=False,
      default=None,
      help='Directory in which the transform output is located')

  parser.add_argument(
      '--max_rows',
      help='Number of rows to query from BigQuery',
      default=None,
      type=int)
  # argparse's type=bool treats any non-empty string (even 'False') as True,
  # so expose this as a store_true flag instead.
  parser.add_argument(
      '--publish_to_big_query',
      help='Whether to publish pipeline metrics to BigQuery',
      default=False,
      action='store_true')

  parser.add_argument(
      '--metrics_dataset',
      help='BigQuery dataset for published metrics',
      default=None,
      type=str)

  parser.add_argument(
      '--metrics_table',
      help='BigQuery table for published metrics',
      default=None,
      type=str)

  parser.add_argument(
      '--metric_reporting_project',
      help='GCP project owning the metrics table',
      default=None,
      type=str)

  known_args, pipeline_args = parser.parse_known_args()
  transform_data(
      input_handle=known_args.input,
      outfile_prefix=known_args.outfile_prefix,
      working_dir=known_args.output_dir,
      schema_file=known_args.schema_file,
      transform_dir=known_args.transform_dir,
      max_rows=known_args.max_rows,
      pipeline_args=pipeline_args,
      publish_to_bq=known_args.publish_to_big_query,
      metrics_dataset=known_args.metrics_dataset,
      metrics_table=known_args.metrics_table,
      project=known_args.metric_reporting_project)


if __name__ == '__main__':
  main()
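
# A hedged command-line sketch (project, table, and paths are hypothetical);
# unrecognized flags such as --runner are forwarded to the Beam pipeline via
# parse_known_args:
#
#   python preprocess.py \
#     --input my_dataset.taxi_trips \
#     --schema_file /tmp/schema.pbtxt \
#     --output_dir /tmp/taxi_out \
#     --outfile_prefix train_transformed \
#     --max_rows 10000 \
#     --runner=DirectRunner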