github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility and schema methods for the chicago_taxi sample."""
# pytype: skip-file

from tensorflow_transform import coders as tft_coders
from tensorflow_transform.tf_metadata import schema_utils

from google.protobuf import text_format  # type: ignore  # typeshed out of date
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0 import schema_pb2

# Categorical features are assumed to each have a maximum value in the dataset.
MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]

CATEGORICAL_FEATURE_KEYS = [
    'trip_start_hour',
    'trip_start_day',
    'trip_start_month',
    'pickup_census_tract',
    'dropoff_census_tract',
    'pickup_community_area',
    'dropoff_community_area'
]

DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']

# Number of buckets used by tf.transform for encoding each feature.
FEATURE_BUCKET_COUNT = 10

BUCKET_FEATURE_KEYS = [
    'pickup_latitude',
    'pickup_longitude',
    'dropoff_latitude',
    'dropoff_longitude'
]

# Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform.
VOCAB_SIZE = 1000

# Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are
# hashed.
OOV_SIZE = 10

VOCAB_FEATURE_KEYS = [
    'payment_type',
    'company',
]

LABEL_KEY = 'tips'
FARE_KEY = 'fare'

CSV_COLUMN_NAMES = [
    'pickup_community_area',
    'fare',
    'trip_start_month',
    'trip_start_hour',
    'trip_start_day',
    'trip_start_timestamp',
    'pickup_latitude',
    'pickup_longitude',
    'dropoff_latitude',
    'dropoff_longitude',
    'trip_miles',
    'pickup_census_tract',
    'dropoff_census_tract',
    'payment_type',
    'company',
    'trip_seconds',
    'dropoff_community_area',
    'tips',
]


def transformed_name(key):
  return key + '_xf'


def transformed_names(keys):
  return [transformed_name(key) for key in keys]
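

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of how the feature-key constants above are typically
# consumed by a tf.Transform preprocessing_fn. The `tft` calls below are real
# tensorflow_transform APIs, but this particular function is an assumption
# added for illustration; the benchmark's actual preprocessing lives in a
# separate module.
def _example_preprocessing_fn(inputs):
  # Local import so the sketch stays self-contained.
  import tensorflow_transform as tft

  outputs = {}
  for key in DENSE_FLOAT_FEATURE_KEYS:
    # Scale dense numeric features to z-scores.
    outputs[transformed_name(key)] = tft.scale_to_z_score(inputs[key])
  for key in VOCAB_FEATURE_KEYS:
    # Map string features to integer ids, hashing unknowns into OOV buckets.
    outputs[transformed_name(key)] = tft.compute_and_apply_vocabulary(
        inputs[key], top_k=VOCAB_SIZE, num_oov_buckets=OOV_SIZE)
  for key in BUCKET_FEATURE_KEYS:
    # Quantile-bucketize the latitude/longitude features.
    outputs[transformed_name(key)] = tft.bucketize(
        inputs[key], FEATURE_BUCKET_COUNT)
  return outputs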


# tf.Transform considers these features as "raw".
def get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def make_proto_coder(schema):
  raw_feature_spec = get_raw_feature_spec(schema)
  raw_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
  return tft_coders.ExampleProtoCoder(raw_schema)


def make_csv_coder(schema):
  """Returns a coder for tf.Transform to read CSV files."""
  raw_feature_spec = get_raw_feature_spec(schema)
  parsing_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
  return tft_coders.CsvCoder(CSV_COLUMN_NAMES, parsing_schema)


def clean_raw_data_dict(input_dict, raw_feature_spec):
  """Cleans a raw data dict into per-feature value lists.

  Each key of raw_feature_spec is emitted: present values are wrapped in a
  single-element list, while missing or falsy values become empty lists.
  """
  output_dict = {}

  for key in raw_feature_spec:
    if key not in input_dict or not input_dict[key]:
      output_dict[key] = []
    else:
      output_dict[key] = [input_dict[key]]
  return output_dict
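

# --- Illustrative sketch (not part of the original module) ------------------
# Shows the intended behavior of clean_raw_data_dict: every key of the raw
# feature spec appears in the output, with present values wrapped in
# single-element lists and missing ones mapped to []. The row and spec below
# are made-up fragments for illustration only.
def _example_clean_row():
  # Stand-in feature spec; only its keys matter to clean_raw_data_dict.
  raw_feature_spec = {'fare': None, 'tips': None, 'company': None}
  row = {'fare': 12.5, 'company': 'Flash Cab'}  # 'tips' is absent
  cleaned = clean_raw_data_dict(row, raw_feature_spec)
  # cleaned == {'fare': [12.5], 'tips': [], 'company': ['Flash Cab']}
  return cleaned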


def make_sql(table_name, max_rows=None, for_eval=False):
  """Creates the SQL command for pulling data from BigQuery.

  Args:
    table_name: BigQuery table name.
    max_rows: if set, limits the number of rows pulled from BigQuery.
    for_eval: True if this is for evaluation, False otherwise.

  Returns:
    The SQL command as a string.
  """
  if for_eval:
    # 1/3 of the dataset is used for eval.
    where_clause = (
        'WHERE MOD(FARM_FINGERPRINT(unique_key), 3) = 0 '
        'AND pickup_latitude is not null AND pickup_longitude '
        'is not null AND dropoff_latitude is not null '
        'AND dropoff_longitude is not null')
  else:
    # 2/3 of the dataset is used for training.
    where_clause = (
        'WHERE MOD(FARM_FINGERPRINT(unique_key), 3) > 0 '
        'AND pickup_latitude is not null AND pickup_longitude '
        'is not null AND dropoff_latitude is not null '
        'AND dropoff_longitude is not null')

  limit_clause = ''
  if max_rows:
    limit_clause = 'LIMIT {max_rows}'.format(max_rows=max_rows)
  return """
  SELECT
      CAST(pickup_community_area AS string) AS pickup_community_area,
      CAST(dropoff_community_area AS string) AS dropoff_community_area,
      CAST(pickup_census_tract AS string) AS pickup_census_tract,
      CAST(dropoff_census_tract AS string) AS dropoff_census_tract,
      fare,
      EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,
      EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,
      EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,
      UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,
      pickup_latitude,
      pickup_longitude,
      dropoff_latitude,
      dropoff_longitude,
      trip_miles,
      payment_type,
      company,
      trip_seconds,
      tips
    FROM `{table_name}`
    {where_clause}
    {limit_clause}
""".format(
      table_name=table_name,
      where_clause=where_clause,
      limit_clause=limit_clause)


def read_schema(path):
  """Reads a schema from the provided location.

  Args:
    path: The location of the file holding a serialized Schema proto.

  Returns:
    A schema_pb2.Schema parsed from the text-format contents of the file.
  """
  result = schema_pb2.Schema()
  contents = file_io.read_file_to_string(path)
  text_format.Parse(contents, result)
  return result
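

# --- Illustrative usage (not part of the original module) -------------------
# make_sql only formats a string, so it is safe to exercise directly. The
# table name below is a hypothetical placeholder, not a real dataset path;
# read_schema is omitted here because it requires an actual schema file.
if __name__ == '__main__':
  print(make_sql('my-project.my_dataset.taxi_trips', max_rows=100,
                 for_eval=True))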