github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py (about)

     1  # Copyright 2019 Google LLC. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     https://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Utility and schema methods for the chicago_taxi sample."""
    16  # pytype: skip-file
    17  
    18  from tensorflow_transform import coders as tft_coders
    19  from tensorflow_transform.tf_metadata import schema_utils
    20  
    21  from google.protobuf import text_format  # type: ignore  # typeshed out of date
    22  from tensorflow.python.lib.io import file_io
    23  from tensorflow_metadata.proto.v0 import schema_pb2
    24  
    25  # Categorical features are assumed to each have a maximum value in the dataset.
    26  MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]
    27  
    28  CATEGORICAL_FEATURE_KEYS = [
    29      'trip_start_hour',
    30      'trip_start_day',
    31      'trip_start_month',
    32      'pickup_census_tract',
    33      'dropoff_census_tract',
    34      'pickup_community_area',
    35      'dropoff_community_area'
    36  ]
    37  
    38  DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']
    39  
    40  # Number of buckets used by tf.transform for encoding each feature.
    41  FEATURE_BUCKET_COUNT = 10
    42  
    43  BUCKET_FEATURE_KEYS = [
    44      'pickup_latitude',
    45      'pickup_longitude',
    46      'dropoff_latitude',
    47      'dropoff_longitude'
    48  ]
    49  
    50  # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform
    51  VOCAB_SIZE = 1000
    52  
    53  # Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are hashed.
    54  OOV_SIZE = 10
    55  
    56  VOCAB_FEATURE_KEYS = [
    57      'payment_type',
    58      'company',
    59  ]
    60  
    61  LABEL_KEY = 'tips'
    62  FARE_KEY = 'fare'
    63  
    64  CSV_COLUMN_NAMES = [
    65      'pickup_community_area',
    66      'fare',
    67      'trip_start_month',
    68      'trip_start_hour',
    69      'trip_start_day',
    70      'trip_start_timestamp',
    71      'pickup_latitude',
    72      'pickup_longitude',
    73      'dropoff_latitude',
    74      'dropoff_longitude',
    75      'trip_miles',
    76      'pickup_census_tract',
    77      'dropoff_census_tract',
    78      'payment_type',
    79      'company',
    80      'trip_seconds',
    81      'dropoff_community_area',
    82      'tips',
    83  ]
    84  
    85  
    86  def transformed_name(key):
    87    return key + '_xf'
    88  
    89  
    90  def transformed_names(keys):
    91    return [transformed_name(key) for key in keys]
    92  
    93  
    94  # Tf.Transform considers these features as "raw"
    95  def get_raw_feature_spec(schema):
    96    return schema_utils.schema_as_feature_spec(schema).feature_spec
    97  
    98  
    99  def make_proto_coder(schema):
   100    raw_feature_spec = get_raw_feature_spec(schema)
   101    raw_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
   102    return tft_coders.ExampleProtoCoder(raw_schema)
   103  
   104  
   105  def make_csv_coder(schema):
   106    """Return a coder for tf.transform to read csv files."""
   107    raw_feature_spec = get_raw_feature_spec(schema)
   108    parsing_schema = schema_utils.schema_from_feature_spec(raw_feature_spec)
   109    return tft_coders.CsvCoder(CSV_COLUMN_NAMES, parsing_schema)
   110  
   111  
   112  def clean_raw_data_dict(input_dict, raw_feature_spec):
   113    """Clean raw data dict."""
   114    output_dict = {}
   115  
   116    for key in raw_feature_spec:
   117      if key not in input_dict or not input_dict[key]:
   118        output_dict[key] = []
   119      else:
   120        output_dict[key] = [input_dict[key]]
   121    return output_dict
   122  
   123  
   124  def make_sql(table_name, max_rows=None, for_eval=False):
   125    """Creates the sql command for pulling data from BigQuery.
   126  
   127    Args:
   128      table_name: BigQuery table name
   129      max_rows: if set, limits the number of rows pulled from BigQuery
   130      for_eval: True if this is for evaluation, false otherwise
   131  
   132    Returns:
   133      sql command as string
   134    """
   135    if for_eval:
   136      # 1/3 of the dataset used for eval
   137      where_clause = 'WHERE MOD(FARM_FINGERPRINT(unique_key), 3) = 0 ' \
   138                     'AND pickup_latitude is not null AND pickup_longitude ' \
   139                     'is not null AND dropoff_latitude is not null ' \
   140                     'AND dropoff_longitude is not null'
   141    else:
   142      # 2/3 of the dataset used for training
   143      where_clause = 'WHERE MOD(FARM_FINGERPRINT(unique_key), 3) > 0 ' \
   144                     'AND pickup_latitude is not null AND pickup_longitude ' \
   145                     'is not null AND dropoff_latitude is not null ' \
   146                     'AND dropoff_longitude is not null'
   147  
   148    limit_clause = ''
   149    if max_rows:
   150      limit_clause = 'LIMIT {max_rows}'.format(max_rows=max_rows)
   151    return """
   152    SELECT
   153        CAST(pickup_community_area AS string) AS pickup_community_area,
   154        CAST(dropoff_community_area AS string) AS dropoff_community_area,
   155        CAST(pickup_census_tract AS string) AS pickup_census_tract,
   156        CAST(dropoff_census_tract AS string) AS dropoff_census_tract,
   157        fare,
   158        EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,
   159        EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,
   160        EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,
   161        UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,
   162        pickup_latitude,
   163        pickup_longitude,
   164        dropoff_latitude,
   165        dropoff_longitude,
   166        trip_miles,
   167        payment_type,
   168        company,
   169        trip_seconds,
   170        tips
   171    FROM `{table_name}`
   172    {where_clause}
   173    {limit_clause}
   174  """.format(
   175        table_name=table_name,
   176        where_clause=where_clause,
   177        limit_clause=limit_clause)
   178  
   179  
   180  def read_schema(path):
   181    """Reads a schema from the provided location.
   182  
   183    Args:
   184      path: The location of the file holding a serialized Schema proto.
   185  
   186    Returns:
   187      An instance of Schema or None if the input argument is None
   188    """
   189    result = schema_pb2.Schema()
   190    contents = file_io.read_file_to_string(path)
   191    text_format.Parse(contents, result)
   192    return result