github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A streaming pipeline that uses the RunInference API and windowing to
classify the quality of milk as good, bad or medium based on pH, temperature,
taste, odor, fat, turbidity and color. New measurements arrive every second,
and a sliding window aggregates the number of good, bad and medium samples.

This example uses the milk quality prediction dataset from Kaggle:
https://www.kaggle.com/datasets/cpluzshrijayan/milkquality

To set this example up, you need to do three things:
1. Download the data in CSV format from Kaggle and host it.
2. Split the dataset into a training set and a test set (preprocess_data
   function).
3. Train the classifier (train_model function).
"""

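# Example invocation (all file names below are placeholders; substitute your
# own paths for the dataset and the generated artifacts):
#
#   python milk_quality_prediction_windowing.py \
#     --dataset milknew.csv \
#     --pipeline_input_data pipeline_input.csv \
#     --training_set training_set.csv \
#     --labels labels.csv \
#     --model_state model_state.json
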
import argparse
import logging
import time
from typing import Iterable
from typing import NamedTuple

import pandas
from sklearn.model_selection import train_test_split

import apache_beam as beam
import xgboost
from apache_beam import window
from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from apache_beam.testing.test_stream import TestStream


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--dataset',
      dest='dataset',
      required=True,
      help='Path to the csv containing the Kaggle Milk Quality dataset.')
  parser.add_argument(
      '--pipeline_input_data',
      dest='pipeline_input_data',
      required=True,
      help='Path to store the csv containing input data for the pipeline. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--training_set',
      dest='training_set',
      required=True,
      help='Path to store the csv containing the training set. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--labels',
      dest='labels',
      required=True,
      help='Path to store the csv containing the labels used in training. '
      'This will be generated as part of preprocessing the data.')
  parser.add_argument(
      '--model_state',
      dest='model_state',
      required=True,
      help='Path to the state of the XGBoost model loaded for inference.')
  return parser.parse_known_args(argv)


def preprocess_data(
    dataset_path: str,
    training_set_path: str,
    labels_path: str,
    test_set_path: str):
  """
    Helper function to split the dataset into a training set
    and its labels and a test set. The training set and
    its labels are used to train a lightweight model.
    The test set is used to create a test streaming pipeline.
    Args:
        dataset_path: path to csv file containing the Kaggle
         milk quality dataset
        training_set_path: path to output the training samples
        labels_path: path to output the labels for the training set
        test_set_path: path to output the test samples
  """
  df = pandas.read_csv(dataset_path)
  df['Grade'].replace(['low', 'medium', 'high'], [0, 1, 2], inplace=True)
  x = df.drop(columns=['Grade'])
  y = df['Grade']
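  # With test_size=0.60, 40% of the samples are used for training and the
  # remaining 60% become the test set that feeds the streaming pipeline.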
  x_train, x_test, y_train, _ = \
      train_test_split(x, y, test_size=0.60, random_state=99)
  x_train.to_csv(training_set_path, index=False)
  y_train.to_csv(labels_path, index=False)
  x_test.to_csv(test_set_path, index=False)


def train_model(
    samples_path: str, labels_path: str, model_state_output_path: str):
  """Function to train the XGBoost model.
    Args:
      samples_path: path to csv file containing the training data
      labels_path: path to csv file containing the labels for the training data
      model_state_output_path: path to store the trained model
  """
  samples = pandas.read_csv(samples_path)
  labels = pandas.read_csv(labels_path)
  xgb = xgboost.XGBClassifier(max_depth=3)
  xgb.fit(samples, labels)
  xgb.save_model(model_state_output_path)
  return xgb
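
# A quick standalone sanity check of a saved model state could look like this
# (the file names here are placeholders, not part of the pipeline):
#
#   model = xgboost.XGBClassifier()
#   model.load_model('model_state.json')
#   print(model.predict(pandas.read_csv('pipeline_input.csv').head(1)))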


class MilkQualityAggregation(NamedTuple):
  bad_quality_measurements: int
  medium_quality_measurements: int
  high_quality_measurements: int


class AggregateMilkQualityResults(beam.CombineFn):
  """Simple aggregation to keep track of the number
   of samples with good, bad and medium quality milk."""
  def create_accumulator(self):
    return MilkQualityAggregation(0, 0, 0)

  def add_input(
      self, accumulator: MilkQualityAggregation, element: PredictionResult):
    quality = element.inference[0]
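    # Class labels follow the encoding from preprocess_data:
    # 0 = low (bad) quality, 1 = medium quality, 2 = high quality.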
    if quality == 0:
      return MilkQualityAggregation(
          accumulator.bad_quality_measurements + 1,
          accumulator.medium_quality_measurements,
          accumulator.high_quality_measurements)
    elif quality == 1:
      return MilkQualityAggregation(
          accumulator.bad_quality_measurements,
          accumulator.medium_quality_measurements + 1,
          accumulator.high_quality_measurements)
    else:
      return MilkQualityAggregation(
          accumulator.bad_quality_measurements,
          accumulator.medium_quality_measurements,
          accumulator.high_quality_measurements + 1)

  def merge_accumulators(self, accumulators: Iterable[MilkQualityAggregation]):
    return MilkQualityAggregation(
        sum(
            aggregation.bad_quality_measurements
            for aggregation in accumulators),
        sum(
            aggregation.medium_quality_measurements
            for aggregation in accumulators),
        sum(
            aggregation.high_quality_measurements
            for aggregation in accumulators),
    )

  def extract_output(self, accumulator: MilkQualityAggregation):
    return accumulator

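# For example, combining predictions with classes [0, 0, 2] yields
# MilkQualityAggregation(bad_quality_measurements=2,
#                        medium_quality_measurements=0,
#                        high_quality_measurements=1).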

def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
    Args:
      argv: Command line arguments defined for this example.
      save_main_session: Used for internal testing.
      test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  milk_quality_data = pandas.read_csv(known_args.pipeline_input_data)

  start = time.mktime(time.strptime('2023/06/29 10:00:00', '%Y/%m/%d %H:%M:%S'))

  # Create a test stream
  test_stream = TestStream()

  # Watermark is set to 10:00:00
  test_stream.advance_watermark_to(start)

  # Split the dataframe into individual samples
  samples = [
      milk_quality_data.iloc[i:i + 1] for i in range(len(milk_quality_data))
  ]

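  # Emit one sample per second of event time, advancing the watermark before
  # each element so that sliding windows close as the stream progresses.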
  for watermark_offset, sample in enumerate(samples, 1):
    test_stream.advance_watermark_to(start + watermark_offset)
    test_stream.add_elements([sample])

  test_stream.advance_watermark_to_infinity()

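  # The model handler reloads the trained XGBClassifier from the saved model
  # state and runs inference on pandas DataFrames.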
  model_handler = XGBoostModelHandlerPandas(
      model_class=xgboost.XGBClassifier, model_state=known_args.model_state)

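  # Sliding windows of 30 seconds that start every 5 seconds, so every
  # measurement is counted in six overlapping windows.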
  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  _ = (
      pipeline | test_stream
      | 'window' >> beam.WindowInto(window.SlidingWindows(30, 5))
      | 'RunInference' >> RunInference(model_handler)
      | 'Count number of elements in window' >> beam.CombineGlobally(
          AggregateMilkQualityResults()).without_defaults()
      | 'Print' >> beam.Map(print))

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)

  known_args, _ = parse_known_args(None)

  preprocess_data(
      known_args.dataset,
      training_set_path=known_args.training_set,
      labels_path=known_args.labels,
      test_set_path=known_args.pipeline_input_data)
  train_model(
      samples_path=known_args.training_set,
      labels_path=known_args.labels,
      model_state_output_path=known_args.model_state)
  run()