github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/per_entity_training.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline to demonstrate per-entity training.

This pipeline reads data from a CSV file whose rows contain 15
attributes, such as income bracket (>=50K or <50K), education level,
native country, age and occupation. The pipeline filters the rows,
discarding records with missing values and keeping only people with a
Bachelors, Masters or Doctorate degree. It then groups the remaining
rows by education level, trains a decision tree classifier per group,
and finally saves the trained models.
"""
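
# For reference, a sketch of the expected input row format. The column
# checks below suggest the UCI Adult census dataset (an assumption, not
# stated in this file); note that each field keeps its leading space after
# a row is split on ',':
#
#   39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical,
#   Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K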

import argparse
import logging
import pickle

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

import apache_beam as beam
from apache_beam.io import fileio
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions


class CreateKey(beam.DoFn):
  """Emits (education, rest-of-row) pairs for the downstream GroupByKey."""
  def process(self, element, *args, **kwargs):
    # Education is the fourth field of the dataset (zero-based index 3).
    idx = 3
    key = element.pop(idx)
    yield (key, element)
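
# For example (illustrative values): CreateKey turns the row
# ['39', ' State-gov', ' 77516', ' Bachelors', ...] into
# (' Bachelors', ['39', ' State-gov', ' 77516', ...]), so the later
# GroupByKey step collects all rows that share an education level.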


def custom_filter(element):
  """Keeps a data point only if it has all 15 features, contains no
  missing values ('?'), and includes a Bachelors, Masters or Doctorate
  degree."""
  return (
      len(element) == 15 and '?' not in element and
      (' Bachelors' in element or ' Masters' in element or
       ' Doctorate' in element))
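
# A minimal sanity check for custom_filter (hypothetical helper, not part
# of the original example):
def _check_custom_filter():
  row = ['x'] * 15
  row[3] = ' Masters'
  assert custom_filter(row)  # complete row with a qualifying degree
  row[5] = '?'
  assert not custom_filter(row)  # contains a missing value marker
  assert not custom_filter([' Doctorate'])  # too few features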


class PrepareDataforTraining(beam.DoFn):
  """Preprocesses data into a format suitable for training."""
  def process(self, element, *args, **kwargs):
    key, values = element
    # Convert the grouped rows to a dataframe; the last column is the label.
    df = pd.DataFrame(values)
    last_ix = len(df.columns) - 1
    X, y = df.drop(last_ix, axis=1), df[last_ix]
    # Select categorical and numerical features.
    cat_ix = X.select_dtypes(include=['object', 'bool']).columns
    num_ix = X.select_dtypes(include=['int64', 'float64']).columns
    # Label encode the target variable to have the classes 0 and 1.
    y = LabelEncoder().fit_transform(y)
    yield (X, y, cat_ix, num_ix, key)


class TrainModel(beam.DoFn):
  """Takes preprocessed data as input, one-hot encodes the categorical
  columns, normalizes the numerical columns, and fits a decision tree
  classifier on the result.
  """
  def process(self, element, *args, **kwargs):
    X, y, cat_ix, num_ix, key = element
    steps = [('c', OneHotEncoder(handle_unknown='ignore'), cat_ix),
             ('n', MinMaxScaler(), num_ix)]
    # One-hot encode categorical features, normalize numerical features.
    ct = ColumnTransformer(steps)
    # Wrap the transformer and the model in a single sklearn pipeline.
    pipeline = Pipeline(steps=[('t', ct), ('m', DecisionTreeClassifier())])
    pipeline.fit(X, y)
    yield (key, pipeline)
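
# A sketch of how a (key, pipeline) pair emitted by TrainModel could be
# used downstream (hypothetical helper, not part of the original example;
# new_rows is assumed to be a DataFrame with the same columns as the
# training features):
def predict_with_model(trained, new_rows):
  key, pipeline = trained
  return key, pipeline.predict(new_rows)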


class ModelSink(fileio.FileSink):
  """Writes each trained model to a file by pickling it."""
  def open(self, fh):
    self._fh = fh

  def write(self, record):
    _, trained_model = record
    pickled_model = pickle.dumps(trained_model)
    self._fh.write(pickled_model)

  def flush(self):
    self._fh.flush()
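
# An illustrative counterpart to ModelSink (an assumption, not part of the
# original example): reads the pickled models back from one output shard.
# If a shard holds several models, pickle.Unpickler yields them one at a
# time until the file is exhausted.
def load_models(path):
  models = []
  with open(path, 'rb') as f:
    unpickler = pickle.Unpickler(f)
    while True:
      try:
        models.append(unpickler.load())
      except EOFError:
        return models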


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the CSV file containing the input data.')
  parser.add_argument(
      '--output-dir',
      dest='output',
      required=True,
      help='Path of directory for saving trained models.')
  return parser.parse_known_args(argv)


def run(
    argv=None,
    save_main_session=True,
):
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
  with beam.Pipeline(options=pipeline_options) as pipeline:
    _ = (
        pipeline
        | "Read Data" >> beam.io.ReadFromText(known_args.input)
        | "Split data to make List" >> beam.Map(lambda x: x.split(','))
        | "Filter rows" >> beam.Filter(custom_filter)
        | "Create Key" >> beam.ParDo(CreateKey())
        | "Group by education" >> beam.GroupByKey()
        | "Prepare Data" >> beam.ParDo(PrepareDataforTraining())
        | "Train Model" >> beam.ParDo(TrainModel())
        | "Save" >> fileio.WriteToFiles(
            path=known_args.output, sink=ModelSink()))


if __name__ == "__main__":
  logging.getLogger().setLevel(logging.INFO)
  run()
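
# Example invocation (illustrative; adult.data is an assumed local copy of
# the input CSV, and models are written under /tmp/models):
#
#   python per_entity_training.py \
#       --input adult.data \
#       --output-dir /tmp/models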