github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/cloudml/pipelines/workflow.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import argparse
    19  import logging
    20  import os
    21  
    22  import apache_beam as beam
    23  import tensorflow_transform as tft
    24  import tensorflow_transform.beam as tft_beam
    25  from apache_beam.testing.benchmarks.cloudml.criteo_tft import criteo
    26  from tensorflow_transform import coders
    27  from tensorflow_transform.tf_metadata import dataset_metadata
    28  from tensorflow_transform.tf_metadata import schema_utils
    29  from tfx_bsl.public import tfxio
    30  
    31  # Name of the column for the synthetic version of the benchmark.
    32  _SYNTHETIC_COLUMN = 'x'
    33  
    34  
    35  class _RecordBatchToPyDict(beam.PTransform):
    36    """Converts PCollections of pa.RecordBatch to python dicts."""
    37    def __init__(self, input_feature_spec):
    38      self._input_feature_spec = input_feature_spec
    39  
    40    def expand(self, pcoll):
    41      def format_values(instance):
    42        return {
    43            k: v.squeeze(0).tolist()
    44            if v is not None else self._input_feature_spec[k].default_value
    45            for k,
    46            v in instance.items()
    47        }
    48  
    49      return (
    50          pcoll
    51          | 'RecordBatchToDicts' >>
    52          beam.FlatMap(lambda x: x.to_pandas().to_dict(orient='records'))
    53          | 'FormatPyDictValues' >> beam.Map(format_values))
    54  
    55  
    56  def _synthetic_preprocessing_fn(inputs):
    57    return {
    58        _SYNTHETIC_COLUMN: tft.compute_and_apply_vocabulary(
    59            inputs[_SYNTHETIC_COLUMN],
    60  
    61            # Execute more codepaths but do no frequency filtration.
    62            frequency_threshold=1,
    63  
    64            # Execute more codepaths but do no top filtration.
    65            top_k=2**31 - 1,
    66  
    67            # Execute more codepaths
    68            num_oov_buckets=10)
    69    }
    70  
    71  
    72  class _PredictionHistogramFn(beam.DoFn):
    73    def __init__(self):
    74      # Beam Metrics API for Distributions only works with integers but
    75      # predictions are floating point numbers. We thus store a "quantized"
    76      # distribution of the prediction with sufficient granularity and for ease
    77      # of human interpretation (eg as a percentage for logistic regression).
    78      self._prediction_distribution = beam.metrics.Metrics.distribution(
    79          self.__class__, 'int(scores[0]*100)')
    80  
    81    def process(self, element):
    82      self._prediction_distribution.update(int(element['scores'][0] * 100))
    83  
    84  
    85  def setup_pipeline(p, args):
    86    if args.classifier == 'criteo':
    87      input_feature_spec = criteo.make_input_feature_spec()
    88      input_schema = schema_utils.schema_from_feature_spec(input_feature_spec)
    89      input_tfxio = tfxio.BeamRecordCsvTFXIO(
    90          physical_format='text',
    91          column_names=criteo.make_ordered_column_names(),
    92          schema=input_schema,
    93          delimiter=criteo.DEFAULT_DELIMITER,
    94          telemetry_descriptors=['CriteoCloudMLBenchmark'])
    95      preprocessing_fn = criteo.make_preprocessing_fn(args.frequency_threshold)
    96    else:
    97      assert False, 'Unknown args classifier <{}>'.format(args.classifier)
    98  
    99    input_data = p | 'ReadFromText' >> beam.io.textio.ReadFromText(
   100        args.input, coder=beam.coders.BytesCoder())
   101  
   102    if args.benchmark_type == 'tft':
   103      logging.info('TFT benchmark')
   104  
   105      # Setting TFXIO output format only for Criteo benchmarks to make sure that
   106      # both codepaths are covered.
   107      output_record_batches = args.classifier == 'criteo'
   108  
   109      # pylint: disable=expression-not-assigned
   110      input_metadata = dataset_metadata.DatasetMetadata(schema=input_schema)
   111      (
   112          input_metadata
   113          | 'WriteInputMetadata' >> tft_beam.WriteMetadata(
   114              os.path.join(args.output, 'raw_metadata'), pipeline=p))
   115  
   116      with tft_beam.Context(temp_dir=os.path.join(args.output, 'tmp'),
   117                            use_deep_copy_optimization=True):
   118        decoded_input_data = (
   119            input_data | 'DecodeForAnalyze' >> input_tfxio.BeamSource())
   120        transform_fn = ((decoded_input_data, input_tfxio.TensorAdapterConfig())
   121                        | 'Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn))
   122  
   123      if args.shuffle:
   124        # Shuffle the data before any decoding (more compact representation).
   125        input_data |= 'Shuffle' >> beam.transforms.Reshuffle()  # pylint: disable=no-value-for-parameter
   126  
   127      decoded_input_data = (
   128          input_data | 'DecodeForTransform' >> input_tfxio.BeamSource())
   129      (dataset,
   130       metadata) = ((decoded_input_data, input_tfxio.TensorAdapterConfig()),
   131                    transform_fn) | 'Transform' >> tft_beam.TransformDataset(
   132                        output_record_batches=output_record_batches)
   133  
   134      if output_record_batches:
   135  
   136        def record_batch_to_examples(batch, unary_passthrough_features):
   137          """Encodes transformed data as tf.Examples."""
   138          # Ignore unary pass-through features.
   139          del unary_passthrough_features
   140          # From beam: "imports, functions and other variables defined in the
   141          # global context of your __main__ file of your Dataflow pipeline are, by
   142          # default, not available in the worker execution environment, and such
   143          # references will cause a NameError, unless the --save_main_session
   144          # pipeline option is set to True. Please see
   145          # https://cloud.google.com/dataflow/faq#how-do-i-handle-nameerrors ."
   146          from tfx_bsl.coders.example_coder import RecordBatchToExamples
   147          return RecordBatchToExamples(batch)
   148  
   149        encode_ptransform = beam.FlatMapTuple(record_batch_to_examples)
   150      else:
   151        example_coder = coders.ExampleProtoCoder(metadata.schema)
   152        encode_ptransform = beam.Map(example_coder.encode)
   153  
   154      # TODO: Use WriteDataset instead when it becomes available.
   155      (
   156          dataset
   157          | 'Encode' >> encode_ptransform
   158          | 'Write' >> beam.io.WriteToTFRecord(
   159              os.path.join(args.output, 'features_train'),
   160              file_name_suffix='.tfrecord.gz'))
   161      # transform_fn | beam.Map(print)
   162      transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(args.output)
   163  
   164      # TODO: Remember to eventually also save the statistics.
   165    else:
   166      logging.fatal('Unknown benchmark type: %s', args.benchmark_type)
   167  
   168  
   169  def parse_known_args(argv):
   170    """Parses args for this workflow."""
   171    parser = argparse.ArgumentParser()
   172    parser.add_argument(
   173        '--input',
   174        dest='input',
   175        required=True,
   176        help='Input path for input files.')
   177    parser.add_argument(
   178        '--output',
   179        dest='output',
   180        required=True,
   181        help='Output path for output files.')
   182    parser.add_argument(
   183        '--classifier',
   184        dest='classifier',
   185        required=True,
   186        help='Name of classifier to use.')
   187    parser.add_argument(
   188        '--frequency_threshold',
   189        dest='frequency_threshold',
   190        default=5,  # TODO: Align default with TFT (ie 0).
   191        help='Threshold for minimum number of unique values for a category.')
   192    parser.add_argument(
   193        '--shuffle',
   194        action='store_false',
   195        dest='shuffle',
   196        default=True,
   197        help='Skips shuffling the data.')
   198    parser.add_argument(
   199        '--benchmark_type',
   200        dest='benchmark_type',
   201        required=True,
   202        help='Type of benchmark to run.')
   203  
   204    return parser.parse_known_args(argv)
   205  
   206  
   207  def run(argv=None):
   208    """Main entry point; defines and runs the pipeline."""
   209    known_args, pipeline_args = parse_known_args(argv)
   210    with beam.Pipeline(argv=pipeline_args) as p:
   211      setup_pipeline(p, known_args)
   212  
   213  
   214  if __name__ == '__main__':
   215    run()