github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/run_inference_side_inputs.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """ 19 Used for internal testing. No backwards compatibility. 20 """ 21 22 import argparse 23 import logging 24 import time 25 from typing import Iterable 26 from typing import Optional 27 from typing import Sequence 28 29 import apache_beam as beam 30 from apache_beam.ml.inference import base 31 from apache_beam.options.pipeline_options import PipelineOptions 32 from apache_beam.options.pipeline_options import SetupOptions 33 from apache_beam.transforms import trigger 34 from apache_beam.transforms import window 35 from apache_beam.transforms.periodicsequence import PeriodicImpulse 36 from apache_beam.transforms.userstate import CombiningValueStateSpec 37 38 39 # create some fake models which returns different inference results. 
class FakeModelDefault:
  """Fake model whose prediction is the example itself."""
  def predict(self, example: int) -> int:
    return example


class FakeModelAdd(FakeModelDefault):
  """Fake model that predicts ``example + 1``."""
  def predict(self, example: int) -> int:
    return example + 1


class FakeModelSub(FakeModelDefault):
  """Fake model that predicts ``example - 1``."""
  def predict(self, example: int) -> int:
    return example - 1


class FakeModelHandlerReturnsPredictionResult(
    base.ModelHandler[int, base.PredictionResult, FakeModelDefault]):
  """ModelHandler that selects a fake model from its current ``model_id``.

  ``load_model`` maps ``'model_add.pkl'`` to FakeModelAdd, ``'model_sub.pkl'``
  to FakeModelSub, and any other id to FakeModelDefault. ``model_id`` is
  replaced via ``update_model_path`` (presumably invoked by RunInference when
  the model-metadata side input delivers a new id -- confirm in
  apache_beam.ml.inference.base).
  """
  def __init__(self, clock=None, model_id='model_default'):
    """
    Args:
      clock: optional fake clock with a mutable ``current_time_ns``
        attribute; when given, every ``load_model`` call advances it by
        500ms to simulate load latency.
      model_id: initial model identifier.
    """
    self.model_id = model_id
    self._fake_clock = clock

  def load_model(self):
    """Returns a new fake model instance chosen by ``self.model_id``."""
    if self._fake_clock:
      # Simulate a model load that takes 500ms.
      self._fake_clock.current_time_ns += 500_000_000
    if self.model_id == 'model_add.pkl':
      return FakeModelAdd()
    elif self.model_id == 'model_sub.pkl':
      return FakeModelSub()
    return FakeModelDefault()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModelDefault,
      inference_args=None) -> Iterable[base.PredictionResult]:
    """Yields one PredictionResult per example, tagged with ``model_id``.

    ``inference_args`` is accepted for interface compatibility and ignored.
    """
    for example in batch:
      yield base.PredictionResult(
          model_id=self.model_id,
          example=example,
          inference=model.predict(example))

  def update_model_path(self, model_path: Optional[str] = None):
    """Switches to ``model_path`` if it is truthy; otherwise keeps the id."""
    self.model_id = model_path if model_path else self.model_id


def run(argv=None, save_main_session=True):
  """Builds and runs the side-input model-update test pipeline.

  A slow PeriodicImpulse (every 60s) produces ModelMetadata for a side input:
  'model_sub.pkl' while wall-clock time is before the midpoint of the run,
  'model_add.pkl' afterwards. A faster PeriodicImpulse (every 20s) feeds the
  constant example ``1`` into RunInference; each result is validated against
  the model id that produced it and logged at INFO level.

  Args:
    argv: command-line arguments. The parser defines no flags, so every
      argument is forwarded to PipelineOptions.
    save_main_session: value assigned to ``SetupOptions.save_main_session``.
  """
  parser = argparse.ArgumentParser()
  first_ts = time.time()
  side_input_interval = 60
  main_input_interval = 20
  # Use a long (1200s) window to give Dataflow time to start.
  last_ts = first_ts + 1200
  mid_ts = (first_ts + last_ts) / 2

  _, pipeline_args = parser.parse_known_args(argv)
  options = PipelineOptions(pipeline_args)
  options.view_as(SetupOptions).save_main_session = save_main_session

  class GetModel(beam.DoFn):
    """Emits model_sub metadata before ``mid_ts`` and model_add after it.

    The switch is based on wall-clock time at processing, not on the
    element's timestamp.
    """
    def process(self, element) -> Iterable[base.ModelMetadata]:
      if time.time() > mid_ts:
        yield base.ModelMetadata(
            model_id='model_add.pkl', model_name='model_add')
      else:
        yield base.ModelMetadata(
            model_id='model_sub.pkl', model_name='model_sub')

  class _EmitSingletonSideInput(beam.DoFn):
    """Emits the value for each key only the first time the key is seen.

    A per-key summing counter records whether the key was already emitted.
    """
    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      _, path = element
      counter = count_state.read()
      if counter == 0:
        count_state.add(1)
        yield path

  def validate_prediction_result(x: base.PredictionResult) -> None:
    """Asserts the inference for example 1 matches the producing model.

    Expected results: model_sub -> 0, model_add -> 2, model_default -> 1.
    Results carrying any other model id are not checked.
    """
    model_id = x.model_id
    if model_id == 'model_sub.pkl':
      assert (x.example == 1 and x.inference == 0)

    if model_id == 'model_add.pkl':
      assert (x.example == 1 and x.inference == 2)

    if model_id == 'model_default':
      assert (x.example == 1 and x.inference == 1)

  with beam.Pipeline(options=options) as pipeline:
    side_input = (
        pipeline
        | "SideInputPColl" >> PeriodicImpulse(
            first_ts, last_ts, fire_interval=side_input_interval)
        | "GetModelId" >> beam.ParDo(GetModel())
        | "AttachKey" >> beam.Map(lambda x: (x, x))
        # PeriodicImpulse's start timestamp precedes the moment Dataflow
        # begins processing data. It can therefore fire several times at
        # once, which makes the side input an Iterable instead of a
        # singleton. _EmitSingletonSideInput makes sure each unique model
        # path is emitted only once.
        | "GetSingleton" >> beam.ParDo(_EmitSingletonSideInput())
        # Fire the global-window side input about 1s (processing time)
        # after new data arrives, discarding earlier panes so only the
        # newest model path is emitted.
        | "ApplySideInputWindow" >> beam.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))

    model_handler = FakeModelHandlerReturnsPredictionResult()
    inference_pcoll = (
        pipeline
        | "MainInputPColl" >> PeriodicImpulse(
            first_ts,
            last_ts,
            fire_interval=main_input_interval,
            apply_windowing=True)
        | beam.Map(lambda x: 1)
        | base.RunInference(
            model_handler=model_handler, model_metadata_pcoll=side_input))

    _ = inference_pcoll | "AssertPredictionResult" >> beam.Map(
        validate_prediction_result)

    _ = inference_pcoll | "Logging" >> beam.Map(logging.info)


if __name__ == '__main__':
  # The "Logging" step writes through logging.info; without raising the root
  # logger above its default WARNING level those records are discarded.
  logging.getLogger().setLevel(logging.INFO)
  run()