github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/model.py (about)

     1  # Copyright 2019 Google LLC. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     https://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Defines the model used to predict who will tip in the Chicago Taxi demo."""
    16  # pytype: skip-file
    17  
    18  import tensorflow as tf
    19  import tensorflow_model_analysis as tfma
    20  from tensorflow import estimator as tf_estimator
    21  from trainer import taxi
    22  
    23  
    def build_estimator(tf_transform_output, config, hidden_units=None):
      """Build an estimator for predicting the tipping behavior of taxi riders.

      Numeric (dense float) features feed the deep (DNN) part of the model;
      vocabulary, bucketized and categorical id features feed the wide (linear)
      part.

      Args:
        tf_transform_output: A TFTransformOutput. Not read here; the parameter
          is retained so existing callers keep working.
        config: A `tf.estimator.RunConfig` defining the runtime environment for
          the estimator (including model_dir); passed through unchanged.
        hidden_units: [int], the layer sizes of the DNN (input layer first).
          When None (or empty), [100, 70, 50, 25] is used.

      Returns:
        Resulting DNNLinearCombinedClassifier.
      """
      real_valued_columns = [
          tf.feature_column.numeric_column(key, shape=())
          for key in taxi.transformed_names(taxi.DENSE_FLOAT_FEATURE_KEYS)
      ]

      # Identity columns: ids outside [0, num_buckets) are replaced by
      # default_value=0.
      categorical_columns = [
          tf.feature_column.categorical_column_with_identity(
              key, num_buckets=taxi.VOCAB_SIZE + taxi.OOV_SIZE, default_value=0)
          for key in taxi.transformed_names(taxi.VOCAB_FEATURE_KEYS)
      ]
      categorical_columns += [
          tf.feature_column.categorical_column_with_identity(
              key, num_buckets=taxi.FEATURE_BUCKET_COUNT, default_value=0)
          for key in taxi.transformed_names(taxi.BUCKET_FEATURE_KEYS)
      ]
      # Each categorical feature has its own bucket count, paired positionally.
      categorical_columns += [
          tf.feature_column.categorical_column_with_identity(
              key, num_buckets=num_buckets, default_value=0)
          for key, num_buckets in zip(
              taxi.transformed_names(taxi.CATEGORICAL_FEATURE_KEYS),
              taxi.MAX_CATEGORICAL_FEATURE_VALUES)
      ]

      return tf_estimator.DNNLinearCombinedClassifier(
          config=config,
          linear_feature_columns=categorical_columns,
          dnn_feature_columns=real_valued_columns,
          dnn_hidden_units=hidden_units or [100, 70, 50, 25])
    67  
    68  
    def example_serving_receiver_fn(tf_transform_output, schema):
      """Build the serving input receiver for the exported model.

      Serialized raw examples are parsed using the raw feature spec (with the
      label removed), and the parsed features are then passed through the
      tf.Transform graph.

      Args:
        tf_transform_output: A TFTransformOutput.
        schema: the schema of the input data.

      Returns:
        A ServingInputReceiver whose features are the transformed features and
        whose receiver tensors are those of the raw-example parsing receiver.
      """
      # Serving requests carry no label, so it is excluded from the parse spec.
      feature_spec = taxi.get_raw_feature_spec(schema)
      del feature_spec[taxi.LABEL_KEY]

      parsing_receiver = (
          tf_estimator.export.build_parsing_serving_input_receiver_fn(
              feature_spec, default_batch_size=None)())

      return tf_estimator.export.ServingInputReceiver(
          tf_transform_output.transform_raw_features(parsing_receiver.features),
          parsing_receiver.receiver_tensors)
    91  
    92  
    def eval_input_receiver_fn(tf_transform_output, schema):
      """Build the EvalInputReceiver that tf-model-analysis uses to run the model.

      Args:
        tf_transform_output: A TFTransformOutput.
        schema: the schema of the input data.

      Returns:
        A tfma.export.EvalInputReceiver bundling:
          - the graph that parses serialized raw examples and applies the
            tf-transform preprocessing to them,
          - the raw features together with the transformed features,
          - the transformed label, against which predictions are compared.
      """
      # The eval graph starts from raw, untransformed examples; the label is not
      # removed from this spec.
      raw_spec = taxi.get_raw_feature_spec(schema)

      example_placeholder = tf.placeholder(
          dtype=tf.string, shape=[None], name='input_example_tensor')

      # Parse the serialized raw examples, then run them through the transform
      # function computed during the preprocessing step.
      raw_features = tf.parse_example(example_placeholder, raw_spec)
      transformed = tf_transform_output.transform_raw_features(raw_features)

      # The model is driven by transformed features (training works on the
      # materialized output of TFT), but slicing happens on raw features, so
      # both are exposed. On a name clash the transformed entry wins.
      all_features = dict(raw_features)
      all_features.update(transformed)

      return tfma.export.EvalInputReceiver(
          features=all_features,
          # The key name MUST be 'examples'.
          receiver_tensors={'examples': example_placeholder},
          labels=transformed[taxi.transformed_name(taxi.LABEL_KEY)])
   132  
   133  
   def _gzip_reader_fn():
     """Return a TFRecordReader configured to read GZIP-compressed records."""
     gzip_options = tf.python_io.TFRecordOptions(
         compression_type=tf.python_io.TFRecordCompressionType.GZIP)
     return tf.TFRecordReader(options=gzip_options)
   139  
   140  
   def input_fn(filenames, tf_transform_output, batch_size=200):
     """Generate batched features and labels for training or evaluation.

     Args:
       filenames: [str] list of GZIP-compressed TFRecord files (read through
         `_gzip_reader_fn`), parsed against the transformed feature spec.
       tf_transform_output: A TFTransformOutput.
       batch_size: int, first dimension size of the Tensors returned.

     Returns:
       A (features, label) tuple where features is a dictionary of transformed
       feature Tensors with the label removed, and label is the Tensor for the
       transformed label key.
     """
     feature_spec = tf_transform_output.transformed_feature_spec().copy()
     batched = tf.contrib.learn.io.read_batch_features(
         filenames, batch_size, feature_spec, reader=_gzip_reader_fn)

     # Take the label out of the feature dict so the model is not trained on it.
     label = batched.pop(taxi.transformed_name(taxi.LABEL_KEY))
     return batched, label