github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/per_entity_training.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """A pipeline to demonstrate per-entity training. 19 20 This pipeline reads data from a CSV file, that contains information 21 about 15 different attributes like salary >=50k, education level, 22 native country, age, occupation and others. The pipeline does some filtering 23 by selecting certain education level, discarding missing values and empty rows. 24 The pipeline then groups the rows based on education level and 25 trains Decision Trees for each group and finally saves them. 26 """ 27 28 import argparse 29 import logging 30 import pickle 31 32 import pandas as pd 33 from sklearn.compose import ColumnTransformer 34 from sklearn.pipeline import Pipeline 35 from sklearn.preprocessing import LabelEncoder 36 from sklearn.preprocessing import MinMaxScaler 37 from sklearn.preprocessing import OneHotEncoder 38 from sklearn.tree import DecisionTreeClassifier 39 40 import apache_beam as beam 41 from apache_beam.io import fileio 42 from apache_beam.options.pipeline_options import PipelineOptions 43 from apache_beam.options.pipeline_options import SetupOptions 44 45 46 class CreateKey(beam.DoFn): 47 def process(self, element, *args, **kwargs): 48 # 3rd column of the dataset is Education 49 idx = 3 50 key = element.pop(idx) 51 yield (key, element) 52 53 54 def custom_filter(element): 55 """Discard data point if contains ?, 56 doesn't have all features, or 57 doesn't have Bachelors, Masters or a Doctorate Degree""" 58 return len(element) == 15 and '?' not in element \ 59 and ' Bachelors' in element or ' Masters' in element \ 60 or ' Doctorate' in element 61 62 63 class PrepareDataforTraining(beam.DoFn): 64 """Preprocess data in a format suitable for training.""" 65 def process(self, element, *args, **kwargs): 66 key, values = element 67 #Convert to dataframe 68 df = pd.DataFrame(values) 69 last_ix = len(df.columns) - 1 70 X, y = df.drop(last_ix, axis=1), df[last_ix] 71 # select categorical and numerical features 72 cat_ix = X.select_dtypes(include=['object', 'bool']).columns 73 num_ix = X.select_dtypes(include=['int64', 'float64']).columns 74 # label encode the target variable to have the classes 0 and 1 75 y = LabelEncoder().fit_transform(y) 76 yield (X, y, cat_ix, num_ix, key) 77 78 79 class TrainModel(beam.DoFn): 80 """Takes preprocessed data as input, 81 transforms categorical columns using OneHotEncoder, 82 normalizes numerical columns and then 83 fits a decision tree classifier. 84 """ 85 def process(self, element, *args, **kwargs): 86 X, y, cat_ix, num_ix, key = element 87 steps = [('c', OneHotEncoder(handle_unknown='ignore'), cat_ix), 88 ('n', MinMaxScaler(), num_ix)] 89 # one hot encode categorical, normalize numerical 90 ct = ColumnTransformer(steps) 91 # wrap the model in a pipeline 92 pipeline = Pipeline(steps=[('t', ct), ('m', DecisionTreeClassifier())]) 93 pipeline.fit(X, y) 94 yield (key, pipeline) 95 96 97 class ModelSink(fileio.FileSink): 98 def open(self, fh): 99 self._fh = fh 100 101 def write(self, record): 102 _, trained_model = record 103 pickled_model = pickle.dumps(trained_model) 104 self._fh.write(pickled_model) 105 106 def flush(self): 107 self._fh.flush() 108 109 110 def parse_known_args(argv): 111 """Parses args for the workflow.""" 112 parser = argparse.ArgumentParser() 113 parser.add_argument( 114 '--input', 115 dest='input', 116 help='Path to the text file containing sentences.') 117 parser.add_argument( 118 '--output-dir', 119 dest='output', 120 required=True, 121 help='Path of directory for saving trained models.') 122 return parser.parse_known_args(argv) 123 124 125 def run( 126 argv=None, 127 save_main_session=True, 128 ): 129 """ 130 Args: 131 argv: Command line arguments defined for this example. 132 save_main_session: Used for internal testing. 133 """ 134 known_args, pipeline_args = parse_known_args(argv) 135 pipeline_options = PipelineOptions(pipeline_args) 136 pipeline_options.view_as(SetupOptions).save_main_session = save_main_session 137 with beam.Pipeline(options=pipeline_options) as pipeline: 138 _ = ( 139 pipeline | "Read Data" >> beam.io.ReadFromText(known_args.input) 140 | "Split data to make List" >> beam.Map(lambda x: x.split(',')) 141 | "Filter rows" >> beam.Filter(custom_filter) 142 | "Create Key" >> beam.ParDo(CreateKey()) 143 | "Group by education" >> beam.GroupByKey() 144 | "Prepare Data" >> beam.ParDo(PrepareDataforTraining()) 145 | "Train Model" >> beam.ParDo(TrainModel()) 146 | 147 "Save" >> fileio.WriteToFiles(path=known_args.output, sink=ModelSink())) 148 149 150 if __name__ == "__main__": 151 logging.getLogger().setLevel(logging.INFO) 152 run()