github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/utils.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

"""
Util/helper functions used in apache_beam.ml.inference.
"""
import os
from functools import partial
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.io.fileio import EmptyMatchTreatment
from apache_beam.io.fileio import MatchContinuously
from apache_beam.ml.inference.base import ModelMetadata
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

_START_TIME_STAMP = Timestamp.now()


def _convert_to_result(
    batch: Iterable,
    predictions: Union[Iterable, Dict[Any, Iterable]],
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  if isinstance(predictions, dict):
    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
    # length batch_size, to a list of dictionaries:
    # [{key_type1: value_type1, key_type2: value_type2}]
    predictions_per_tensor = [
        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
    ]
    return [
        PredictionResult(x, y, model_id) for x,
        y in zip(batch, predictions_per_tensor)
    ]
  return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)]
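# Illustrative example of _convert_to_result's dict branch (values are made
# up, not part of the module API):
#   _convert_to_result(['a', 'b'], {'scores': [0.1, 0.2], 'labels': [3, 4]})
# returns
#   [PredictionResult('a', {'scores': 0.1, 'labels': 3}, None),
#    PredictionResult('b', {'scores': 0.2, 'labels': 4}, None)]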


class _ConvertIterToSingleton(beam.DoFn):
  """
  Internal only; No backwards compatibility.

  The MatchContinuously transform examines all files present in a given
  directory and returns those that have timestamps older than the
  pipeline's start time. This can produce an Iterable rather than a
  Singleton. This class returns the file path only when it is first
  encountered; the path is then cached as part of the side input caching
  mechanism. If the path is seen again, nothing is returned.
  By doing this, we can ensure that the output of this transform can be
  wrapped with beam.pvalue.AsSingleton().
  """
  COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

  def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
    counter = count_state.read()
    if counter == 0:
      count_state.add(1)
      yield element[1]
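  # Illustrative behavior (values are made up): if the keyed element
  # ('gs://bucket/model_v2.h5', ModelMetadata(...)) is processed twice, the
  # first call yields the ModelMetadata and bumps the per-key count to 1; the
  # second call sees a non-zero count and yields nothing, so each path is
  # emitted at most once and the output is safe to wrap with
  # beam.pvalue.AsSingleton().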


class _GetLatestFileByTimeStamp(beam.DoFn):
  """
  Internal only; No backwards compatibility.

  This DoFn checks the timestamps of files against the time that the pipeline
  began running. It returns the files that were modified after the pipeline
  started. If no such file is found, it falls back to emitting an entry with
  an empty model path.
  """
  TIME_STATE = CombiningValueStateSpec(
      'max', combine_fn=partial(max, default=_START_TIME_STAMP))

  def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)):
    _, file_metadata = element
    new_ts = file_metadata.last_updated_in_seconds
    old_ts = time_state.read()
    if new_ts > old_ts:
      time_state.clear()
      time_state.add(new_ts)
      model_path = file_metadata.path
    else:
      model_path = ''

    model_name = os.path.splitext(os.path.basename(model_path))[0]
    return [
        (model_path, ModelMetadata(model_id=model_path, model_name=model_name))
    ]
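  # Illustrative example (values are made up): if the pipeline started at
  # timestamp 1000 and a matched file 'gs://bucket/model_v2.h5' was last
  # updated at 1500, this yields
  #   ('gs://bucket/model_v2.h5',
  #    ModelMetadata(model_id='gs://bucket/model_v2.h5',
  #                  model_name='model_v2'))
  # and stores 1500 as the new maximum. A file last updated before the
  # pipeline started yields ('', ModelMetadata(model_id='', model_name=''))
  # instead.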


class WatchFilePattern(beam.PTransform):
  def __init__(
      self,
      file_pattern,
      interval=360,
      stop_timestamp=MAX_TIMESTAMP,
  ):
    """
    Watches a directory for updates to files matching a given file pattern.

    Args:
      file_pattern: The file path to read from as a local file path or a
        GCS ``gs://`` path. The path can contain glob characters
        (``*``, ``?``, and ``[...]`` sets).
      interval: Interval, in seconds, at which to check for files matching
        file_pattern.
      stop_timestamp: Timestamp after which no more files will be checked.

    **Note**:

    1. Any previously used filenames cannot be reused. If a file is added
        or updated to a previously used filename, this transform will ignore
        that update. To trigger a model update, always upload a file with a
        unique name.
    2. Initially, before the pipeline startup time, WatchFilePattern expects
        at least one file matching the file_pattern to be present.
    3. This transform is supported in streaming mode, since
        MatchContinuously produces an unbounded source. Running it in batch
        mode can lead to undesired results or leave the pipeline stuck.
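    Example (illustrative sketch, not an official snippet): the output of this
    transform is typically passed to RunInference as the ``model_metadata_pcoll``
    side input. Here ``model_handler`` and ``my_topic`` are assumed to be
    configured elsewhere, and the pipeline is assumed to run in streaming
    mode::

      from apache_beam.ml.inference.base import RunInference

      with beam.Pipeline() as p:
        side_input_pcoll = (
            p | 'WatchModelPattern' >> WatchFilePattern(
                file_pattern='gs://my-bucket/models/*.h5', interval=300))
        _ = (
            p
            | 'ReadExamples' >> beam.io.ReadFromPubSub(topic=my_topic)
            | 'RunInference' >> RunInference(
                model_handler=model_handler,
                model_metadata_pcoll=side_input_pcoll))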
    """
    self.file_pattern = file_pattern
    self.interval = interval
    self.stop_timestamp = stop_timestamp

  def expand(self, pcoll) -> beam.PCollection[ModelMetadata]:
    return (
        pcoll
        | 'MatchContinuously' >> MatchContinuously(
            file_pattern=self.file_pattern,
            interval=self.interval,
            stop_timestamp=self.stop_timestamp,
            empty_match_treatment=EmptyMatchTreatment.DISALLOW)
        | "AttachKey" >> beam.Map(lambda x: (x.path, x))
        | "GetLatestFileMetaData" >> beam.ParDo(_GetLatestFileByTimeStamp())
        | "AcceptNewSideInputOnly" >> beam.ParDo(_ConvertIterToSingleton())
        | 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))