github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/run_inference_side_inputs.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """
    19  Used for internal testing. No backwards compatibility.
    20  """
    21  
    22  import argparse
    23  import logging
    24  import time
    25  from typing import Iterable
    26  from typing import Optional
    27  from typing import Sequence
    28  
    29  import apache_beam as beam
    30  from apache_beam.ml.inference import base
    31  from apache_beam.options.pipeline_options import PipelineOptions
    32  from apache_beam.options.pipeline_options import SetupOptions
    33  from apache_beam.transforms import trigger
    34  from apache_beam.transforms import window
    35  from apache_beam.transforms.periodicsequence import PeriodicImpulse
    36  from apache_beam.transforms.userstate import CombiningValueStateSpec
    37  
    38  
    39  # create some fake models which returns different inference results.
    40  class FakeModelDefault:
    41    def predict(self, example: int) -> int:
    42      return example
    43  
    44  
    45  class FakeModelAdd(FakeModelDefault):
    46    def predict(self, example: int) -> int:
    47      return example + 1
    48  
    49  
    50  class FakeModelSub(FakeModelDefault):
    51    def predict(self, example: int) -> int:
    52      return example - 1
    53  
    54  
    55  class FakeModelHandlerReturnsPredictionResult(
    56      base.ModelHandler[int, base.PredictionResult, FakeModelDefault]):
    57    def __init__(self, clock=None, model_id='model_default'):
    58      self.model_id = model_id
    59      self._fake_clock = clock
    60  
    61    def load_model(self):
    62      if self._fake_clock:
    63        self._fake_clock.current_time_ns += 500_000_000  # 500ms
    64      if self.model_id == 'model_add.pkl':
    65        return FakeModelAdd()
    66      elif self.model_id == 'model_sub.pkl':
    67        return FakeModelSub()
    68      return FakeModelDefault()
    69  
    70    def run_inference(
    71        self,
    72        batch: Sequence[int],
    73        model: FakeModelDefault,
    74        inference_args=None) -> Iterable[base.PredictionResult]:
    75      for example in batch:
    76        yield base.PredictionResult(
    77            model_id=self.model_id,
    78            example=example,
    79            inference=model.predict(example))
    80  
    81    def update_model_path(self, model_path: Optional[str] = None):
    82      self.model_id = model_path if model_path else self.model_id
    83  
    84  
    85  def run(argv=None, save_main_session=True):
    86    parser = argparse.ArgumentParser()
    87    first_ts = time.time()
    88    side_input_interval = 60
    89    main_input_interval = 20
    90    # give some time for dataflow to start.
    91    last_ts = first_ts + 1200
    92    mid_ts = (first_ts + last_ts) / 2
    93  
    94    _, pipeline_args = parser.parse_known_args(argv)
    95    options = PipelineOptions(pipeline_args)
    96    options.view_as(SetupOptions).save_main_session = save_main_session
    97  
    98    class GetModel(beam.DoFn):
    99      def process(self, element) -> Iterable[base.ModelMetadata]:
   100        if time.time() > mid_ts:
   101          yield base.ModelMetadata(
   102              model_id='model_add.pkl', model_name='model_add')
   103        else:
   104          yield base.ModelMetadata(
   105              model_id='model_sub.pkl', model_name='model_sub')
   106  
   107    class _EmitSingletonSideInput(beam.DoFn):
   108      COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
   109  
   110      def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
   111        _, path = element
   112        counter = count_state.read()
   113        if counter == 0:
   114          count_state.add(1)
   115          yield path
   116  
   117    def validate_prediction_result(x: base.PredictionResult):
   118      model_id = x.model_id
   119      if model_id == 'model_sub.pkl':
   120        assert (x.example == 1 and x.inference == 0)
   121  
   122      if model_id == 'model_add.pkl':
   123        assert (x.example == 1 and x.inference == 2)
   124  
   125      if model_id == 'model_default':
   126        assert (x.example == 1 and x.inference == 1)
   127  
   128    with beam.Pipeline(options=options) as pipeline:
   129      side_input = (
   130          pipeline
   131          | "SideInputPColl" >> PeriodicImpulse(
   132              first_ts, last_ts, fire_interval=side_input_interval)
   133          | "GetModelId" >> beam.ParDo(GetModel())
   134          | "AttachKey" >> beam.Map(lambda x: (x, x))
   135          # due to periodic impulse, which has a start timestamp before
   136          # Dataflow pipeline process data, it can trigger in multiple
   137          # firings, causing an Iterable instead of singleton. So, using
   138          # the _EmitSingletonSideInput DoFn will ensure unique path will be
   139          # fired only once.
   140          | "GetSingleton" >> beam.ParDo(_EmitSingletonSideInput())
   141          | "ApplySideInputWindow" >> beam.WindowInto(
   142              window.GlobalWindows(),
   143              trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
   144              accumulation_mode=trigger.AccumulationMode.DISCARDING))
   145  
   146      model_handler = FakeModelHandlerReturnsPredictionResult()
   147      inference_pcoll = (
   148          pipeline
   149          | "MainInputPColl" >> PeriodicImpulse(
   150              first_ts,
   151              last_ts,
   152              fire_interval=main_input_interval,
   153              apply_windowing=True)
   154          | beam.Map(lambda x: 1)
   155          | base.RunInference(
   156              model_handler=model_handler, model_metadata_pcoll=side_input))
   157  
   158      _ = inference_pcoll | "AssertPredictionResult" >> beam.Map(
   159          validate_prediction_result)
   160  
   161      _ = inference_pcoll | "Logging" >> beam.Map(logging.info)
   162  
   163  
   164  if __name__ == '__main__':
   165    run()