github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/utils.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

"""
Util/helper functions used in apache_beam.ml.inference.
"""
import os
from functools import partial
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.io.fileio import EmptyMatchTreatment
from apache_beam.io.fileio import MatchContinuously
from apache_beam.ml.inference.base import ModelMetadata
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

_START_TIME_STAMP = Timestamp.now()


def _convert_to_result(
    batch: Iterable,
    predictions: Union[Iterable, Dict[Any, Iterable]],
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  if isinstance(predictions, dict):
    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
    # key_type2: Iterable<val_type2>, ...}, where each Iterable is of
    # length batch_size, to a list of dictionaries:
    # [{key_type1: value_type1, key_type2: value_type2}, ...]
    predictions_per_tensor = [
        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
    ]
    return [
        PredictionResult(x, y, model_id)
        for x, y in zip(batch, predictions_per_tensor)
    ]
  return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)]
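
# A minimal sketch (not part of the original module) of how
# _convert_to_result transposes dict-valued predictions: a dict of
# per-tensor iterables is flipped into one dict per input element before
# being zipped with the batch. The tensor names ('label', 'score') and the
# inputs below are illustrative assumptions only.
#
#   batch = ['cat.png', 'dog.png']
#   predictions = {'label': [0, 1], 'score': [0.9, 0.8]}
#   results = _convert_to_result(batch, predictions, model_id='model_v1')
#   # results[0].example == 'cat.png'
#   # results[0].inference == {'label': 0, 'score': 0.9}
#   # results[0].model_id == 'model_v1'
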

class _ConvertIterToSingleton(beam.DoFn):
  """
  Internal only; no backwards compatibility.

  The MatchContinuously transform examines all files present in a given
  directory and returns those that have timestamps older than the
  pipeline's start time, which can produce an Iterable rather than a
  Singleton. This DoFn emits a file path only the first time it is
  encountered, after which the path is cached as part of the side input
  caching mechanism; if the same path is seen again, nothing is emitted.
  This ensures that the output of this transform can be wrapped with
  beam.pvalue.AsSingleton().
  """
  COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

  def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
    counter = count_state.read()
    if counter == 0:
      count_state.add(1)
      yield element[1]


class _GetLatestFileByTimeStamp(beam.DoFn):
  """
  Internal only; no backwards compatibility.

  This DoFn checks the timestamps of files against the time that the
  pipeline began running. It returns the files that were modified after
  the pipeline started; if no such file is found, it falls back to an
  empty default model path.
  """
  TIME_STATE = CombiningValueStateSpec(
      'max', combine_fn=partial(max, default=_START_TIME_STAMP))

  def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)):
    _, file_metadata = element
    new_ts = file_metadata.last_updated_in_seconds
    old_ts = time_state.read()
    if new_ts > old_ts:
      time_state.clear()
      time_state.add(new_ts)
      model_path = file_metadata.path
    else:
      model_path = ''

    model_name = os.path.splitext(os.path.basename(model_path))[0]
    return [
        (model_path, ModelMetadata(model_id=model_path, model_name=model_name))
    ]


class WatchFilePattern(beam.PTransform):
  def __init__(
      self,
      file_pattern,
      interval=360,
      stop_timestamp=MAX_TIMESTAMP,
  ):
    """
    Watches a directory for updates to files matching a given file pattern.

    Args:
      file_pattern: The file path to read from, as a local file path or a
        GCS ``gs://`` path. The path can contain glob characters
        (``*``, ``?``, and ``[...]`` sets).
      interval: Interval, in seconds, at which to check for files matching
        file_pattern.
      stop_timestamp: Timestamp after which no more files will be checked.

    **Note**:

    1. Any previously used filenames cannot be reused. If a file is added
       or updated under a previously used filename, this transform will
       ignore that update. To trigger a model update, always upload a file
       with a unique name.
    2. Initially, before the pipeline's start time, WatchFilePattern
       expects at least one file to be present that matches file_pattern.
    3. This transform is supported in streaming mode, since
       MatchContinuously produces an unbounded source. Running it in batch
       mode can lead to undesired results or cause the pipeline to get
       stuck.
    """
    self.file_pattern = file_pattern
    self.interval = interval
    self.stop_timestamp = stop_timestamp

  def expand(self, pcoll) -> beam.PCollection[ModelMetadata]:
    return (
        pcoll
        | 'MatchContinuously' >> MatchContinuously(
            file_pattern=self.file_pattern,
            interval=self.interval,
            stop_timestamp=self.stop_timestamp,
            empty_match_treatment=EmptyMatchTreatment.DISALLOW)
        | 'AttachKey' >> beam.Map(lambda x: (x.path, x))
        | 'GetLatestFileMetaData' >> beam.ParDo(_GetLatestFileByTimeStamp())
        | 'AcceptNewSideInputOnly' >> beam.ParDo(_ConvertIterToSingleton())
        | 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))
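
# A hedged usage sketch (not part of the original module): the
# ModelMetadata output of WatchFilePattern is typically fed to RunInference
# as its model_metadata_pcoll side input, so a running streaming pipeline
# picks up newly uploaded models. The bucket path, interval, Pub/Sub
# source, and model handler below are illustrative assumptions only.
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#   from apache_beam.ml.inference.utils import WatchFilePattern
#
#   with beam.Pipeline(options=streaming_options) as p:
#     model_updates = (
#         p
#         | 'WatchModels' >> WatchFilePattern(
#             file_pattern='gs://my-bucket/models/*.pt', interval=60))
#     _ = (
#         p
#         | 'ReadExamples' >> beam.io.ReadFromPubSub(topic=examples_topic)
#         | 'Inference' >> RunInference(
#             model_handler=my_model_handler,
#             model_metadata_pcoll=model_updates))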