# Source path: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/model.py
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the model used to predict who will tip in the Chicago Taxi demo."""
# pytype: skip-file

import tensorflow as tf
import tensorflow_model_analysis as tfma
from tensorflow import estimator as tf_estimator
from trainer import taxi


def build_estimator(tf_transform_output, config, hidden_units=None):
  """Build an estimator for predicting the tipping behavior of taxi riders.

  The result is a wide-and-deep classifier:
  - The linear part consumes identity categorical columns built from the
    vocabulary, bucketized and categorical features.
  - The DNN part consumes the dense float features as scalar numeric columns.

  Args:
    tf_transform_output: A TFTransformOutput. Not read by this function; it is
      kept so the signature stays compatible with existing callers.
    config: A tf.estimator.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    hidden_units: [int], the layer sizes of the DNN (input layer first).
      Falls back to [100, 70, 50, 25] when None or empty.

  Returns:
    Resulting DNNLinearCombinedClassifier.
  """
  real_valued_columns = [
      tf.feature_column.numeric_column(key, shape=())
      for key in taxi.transformed_names(taxi.DENSE_FLOAT_FEATURE_KEYS)
  ]
  # For all identity columns below, default_value=0 maps ids outside
  # [0, num_buckets) to bucket 0 instead of failing.
  categorical_columns = [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=taxi.VOCAB_SIZE + taxi.OOV_SIZE, default_value=0)
      for key in taxi.transformed_names(taxi.VOCAB_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=taxi.FEATURE_BUCKET_COUNT, default_value=0)
      for key in taxi.transformed_names(taxi.BUCKET_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=num_buckets, default_value=0)
      for key, num_buckets in zip(
          taxi.transformed_names(taxi.CATEGORICAL_FEATURE_KEYS),
          taxi.MAX_CATEGORICAL_FEATURE_VALUES)
  ]
  return tf_estimator.DNNLinearCombinedClassifier(
      config=config,
      linear_feature_columns=categorical_columns,
      dnn_feature_columns=real_valued_columns,
      dnn_hidden_units=hidden_units or [100, 70, 50, 25])


def example_serving_receiver_fn(tf_transform_output, schema):
  """Build the serving inputs.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    A ServingInputReceiver whose graph parses serialized examples with the raw
    feature spec and applies the tf-transform preprocessing to them.
  """
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  # Drop the label from the parsing spec so serving requests need not carry it.
  raw_feature_spec.pop(taxi.LABEL_KEY)

  raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  return tf_estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)


def eval_input_receiver_fn(tf_transform_output, schema):
  """Build everything needed for the tf-model-analysis to run the model.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = taxi.get_raw_feature_spec(schema)

  # NOTE(review): tf.placeholder / tf.parse_example are TF 1.x graph-mode APIs
  # (tf.compat.v1 in TF 2.x); this module presumably targets TF 1.x — confirm.
  serialized_tf_example = tf.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')

  # Add a parse_example operator to the tensorflow graph, which will parse
  # raw, untransformed, tf examples.
  features = tf.parse_example(serialized_tf_example, raw_feature_spec)

  # Now that we have our raw examples, process them through the tf-transform
  # function computed during the preprocessing step.
  transformed_features = tf_transform_output.transform_raw_features(features)

  # The key name MUST be 'examples'.
  receiver_tensors = {'examples': serialized_tf_example}

  # NOTE: Model is driven by transformed features (since training works on the
  # materialized output of TFT), but slicing will happen on raw features.
  # Merging both into one dict makes raw and transformed keys available.
  features.update(transformed_features)

  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=receiver_tensors,
      labels=transformed_features[taxi.transformed_name(taxi.LABEL_KEY)])


def _gzip_reader_fn():
  """Small utility returning a record reader that can read gzip'ed files."""
  return tf.TFRecordReader(
      options=tf.python_io.TFRecordOptions(
          compression_type=tf.python_io.TFRecordCompressionType.GZIP))


def input_fn(filenames, tf_transform_output, batch_size=200):
  """Generates features and labels for training or evaluation.

  Args:
    filenames: [str] list of gzip-compressed TFRecord files (the materialized
      tf-transform output) to read data from.
    tf_transform_output: A TFTransformOutput whose transformed feature spec is
      used to parse the records.
    batch_size: int First dimension size of the Tensors returned by input_fn

  Returns:
    A (features, labels) tuple where features is a dictionary of transformed
    feature Tensors (without the label), and labels is a single Tensor holding
    the transformed label.
  """
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())

  # NOTE: tf.contrib was removed in TensorFlow 2.x, so this requires TF 1.x.
  transformed_features = tf.contrib.learn.io.read_batch_features(
      filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

  # We pop the label because we do not want to use it as a feature while we're
  # training.
  return transformed_features, transformed_features.pop(
      taxi.transformed_name(taxi.LABEL_KEY))