#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A streaming pipeline that uses the RunInference API and windowing to
classify the quality of milk as bad (low), medium or high based on pH,
temperature, taste, odor, fat, turbidity and color. A new measurement arrives
every second of event time, and a 30-second sliding window, started every
5 seconds, counts the number of bad, medium and high quality samples.

This example uses the milk quality prediction dataset from Kaggle.
https://www.kaggle.com/datasets/cpluzshrijayan/milkquality


In order to set this example up, you will need to do three things.
1. Download the data in csv format from Kaggle and host it.
2. Split the dataset into a training set and a test set (preprocess_data
   function).
3. Train the classifier (train_model function).

When this file is executed as a script, steps 2 and 3 are performed
automatically before the pipeline is run.
"""

import argparse
import logging
import time
from typing import Iterable
from typing import NamedTuple

import pandas
from sklearn.model_selection import train_test_split

import apache_beam as beam
import xgboost
from apache_beam import window
from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from apache_beam.testing.test_stream import TestStream


def parse_known_args(argv):
  """Parses args for the workflow.

  Returns:
    A ``(known_args, remaining_args)`` tuple as produced by
    ``argparse.ArgumentParser.parse_known_args``; the remaining args are
    meant to be forwarded to Beam as pipeline options.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--dataset',
      dest='dataset',
      required=True,
      help='Path to the csv containing Kaggle Milk Quality dataset.')
  parser.add_argument(
      '--pipeline_input_data',
      dest='pipeline_input_data',
      required=True,
      help='Path to store the csv containing input data for the pipeline. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--training_set',
      dest='training_set',
      required=True,
      help='Path to store the csv containing the training set. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--labels',
      dest='labels',
      required=True,
      help='Path to store the csv containing the labels used in training. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--model_state',
      dest='model_state',
      required=True,
      help='Path to the state of the XGBoost model loaded for Inference.')
  return parser.parse_known_args(argv)


def preprocess_data(
    dataset_path: str,
    training_set_path: str,
    labels_path: str,
    test_set_path: str):
  """
  Helper function to split the dataset into a training set
  and its labels and a test set. The training set and
  its labels are used to train a lightweight model.
  The test set is used to create a test streaming pipeline.

  The ``Grade`` column is encoded as low -> 0, medium -> 1, high -> 2, and
  60% of the rows are held out as the test set.

  Args:
    dataset_path: path to csv file containing the Kaggle
      milk quality dataset
    training_set_path: path to output the training samples
    labels_path: path to output the labels for the training set
    test_set_path: path to output the test samples

  Raises:
    ValueError: if the ``Grade`` column contains a value other than
      'low', 'medium' or 'high'.
  """
  grade_to_label = {'low': 0, 'medium': 1, 'high': 2}
  df = pandas.read_csv(dataset_path)
  unknown_grades = set(df['Grade'].unique()) - set(grade_to_label)
  if unknown_grades:
    raise ValueError(
        'Unexpected values in the Grade column: %s' %
        sorted(unknown_grades, key=str))
  # Assign the encoded column back instead of calling
  # df['Grade'].replace(..., inplace=True): that chained in-place update is
  # deprecated in pandas 2.x and does not modify df under copy-on-write.
  df['Grade'] = df['Grade'].map(grade_to_label)
  x = df.drop(columns=['Grade'])
  y = df['Grade']
  x_train, x_test, y_train, _ = train_test_split(
      x, y, test_size=0.60, random_state=99)
  x_train.to_csv(training_set_path, index=False)
  y_train.to_csv(labels_path, index=False)
  x_test.to_csv(test_set_path, index=False)


def train_model(
    samples_path: str, labels_path: str, model_state_output_path: str):
  """Function to train the XGBoost model.

  Args:
    samples_path: path to csv file containing the training data
    labels_path: path to csv file containing the labels for the training data
    model_state_output_path: Path to store the trained model

  Returns:
    The fitted ``xgboost.XGBClassifier`` (its state is also saved to
    ``model_state_output_path``).
  """
  samples = pandas.read_csv(samples_path)
  labels = pandas.read_csv(labels_path)
  xgb = xgboost.XGBClassifier(max_depth=3)
  xgb.fit(samples, labels)
  xgb.save_model(model_state_output_path)
  return xgb


class MilkQualityAggregation(NamedTuple):
  """Per-window counts of predicted milk quality grades."""
  # Number of samples predicted as label 0 ('low').
  bad_quality_measurements: int
  # Number of samples predicted as label 1 ('medium').
  medium_quality_measurements: int
  # Number of samples predicted as any other label (2, 'high').
  high_quality_measurements: int


class AggregateMilkQualityResults(beam.CombineFn):
  """Simple aggregation to keep track of the number
  of samples with bad, medium and high quality milk."""
  def create_accumulator(self):
    return MilkQualityAggregation(0, 0, 0)

  def add_input(
      self, accumulator: MilkQualityAggregation, element: PredictionResult):
    # Every pipeline element is a single-row DataFrame (see `run`), so the
    # prediction is expected to hold exactly one label.
    quality = element.inference[0]
    if quality == 0:
      return accumulator._replace(
          bad_quality_measurements=accumulator.bad_quality_measurements + 1)
    elif quality == 1:
      return accumulator._replace(
          medium_quality_measurements=accumulator.medium_quality_measurements +
          1)
    else:
      return accumulator._replace(
          high_quality_measurements=accumulator.high_quality_measurements + 1)

  def merge_accumulators(
      self, accumulators: Iterable[MilkQualityAggregation]):
    # `accumulators` is only guaranteed to be an iterable (it may be a
    # one-shot iterator), so it is consumed in a single pass instead of once
    # per field.
    bad = medium = high = 0
    for aggregation in accumulators:
      bad += aggregation.bad_quality_measurements
      medium += aggregation.medium_quality_measurements
      high += aggregation.high_quality_measurements
    return MilkQualityAggregation(bad, medium, high)

  def extract_output(self, accumulator: MilkQualityAggregation):
    return accumulator


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """Builds and runs the windowed milk quality inference pipeline.

  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing. When given, the transforms are
      applied to this pipeline instead of a newly constructed one.

  Returns:
    The result of the pipeline run, after it has finished.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  milk_quality_data = pandas.read_csv(known_args.pipeline_input_data)

  # Event time starts at 2023/06/29 10:00:00 (parsed in local time).
  start = time.mktime(time.strptime('2023/06/29 10:00:00', '%Y/%m/%d %H:%M:%S'))

  # Create a test stream
  test_stream = TestStream()

  # Watermark is set to 10:00:00
  test_stream.advance_watermark_to(start)

  # Split the dataframe in individual single-row samples
  samples = [
      milk_quality_data.iloc[i:i + 1] for i in range(len(milk_quality_data))
  ]

  # Emit one sample per second of event time; raw elements added to a
  # TestStream are timestamped with the current watermark.
  for watermark_offset, sample in enumerate(samples, 1):
    test_stream.advance_watermark_to(start + watermark_offset)
    test_stream.add_elements([sample])

  test_stream.advance_watermark_to_infinity()

  model_handler = XGBoostModelHandlerPandas(
      model_class=xgboost.XGBClassifier, model_state=known_args.model_state)

  # Honour the parsed pipeline options (previously discarded) and the
  # injected test pipeline.
  pipeline = test_pipeline
  if pipeline is None:
    pipeline = beam.Pipeline(options=pipeline_options)

  _ = (
      pipeline
      | test_stream
      # 30-second windows, a new one starting every 5 seconds.
      | 'window' >> beam.WindowInto(window.SlidingWindows(30, 5))
      | "RunInference" >> RunInference(model_handler)
      # without_defaults() is required for CombineGlobally on non-global
      # windows; it emits nothing for windows without elements.
      | 'Count number of elements in window' >> beam.CombineGlobally(
          AggregateMilkQualityResults()).without_defaults()
      | 'Print' >> beam.Map(print))

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)

  known_args, _ = parse_known_args(None)

  preprocess_data(
      known_args.dataset,
      training_set_path=known_args.training_set,
      labels_path=known_args.labels,
      test_set_path=known_args.pipeline_input_data)
  train_model(
      samples_path=known_args.training_set,
      labels_path=known_args.labels,
      model_state_output_path=known_args.model_state)
  run()