github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/sideinputs.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Internal side input transforms and implementations.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  
    22  Important: this module is an implementation detail and should not be used
    23  directly by pipeline writers. Instead, users should use the helper methods
    24  AsSingleton, AsIter, AsList and AsDict in apache_beam.pvalue.
    25  """
    26  
    27  # pytype: skip-file
    28  
    29  import re
    30  from typing import TYPE_CHECKING
    31  from typing import Any
    32  from typing import Callable
    33  from typing import Dict
    34  
    35  from apache_beam.transforms import window
    36  
    37  if TYPE_CHECKING:
    38    from apache_beam import pvalue
    39  
    40  WindowMappingFn = Callable[[window.BoundedWindow], window.BoundedWindow]
    41  
    42  SIDE_INPUT_PREFIX = 'python_side_input'
    43  
    44  SIDE_INPUT_REGEX = SIDE_INPUT_PREFIX + '([0-9]+)(-.*)?$'
    45  
    46  
    47  # Top-level function so we can identify it later.
    48  def _global_window_mapping_fn(w, global_window=window.GlobalWindow()):
    49    # type: (...) -> window.GlobalWindow
    50    return global_window
    51  
    52  
    53  def default_window_mapping_fn(target_window_fn):
    54    # type: (window.WindowFn) -> WindowMappingFn
    55    if target_window_fn == window.GlobalWindows():
    56      return _global_window_mapping_fn
    57  
    58    if isinstance(target_window_fn, window.Sessions):
    59      raise RuntimeError("Sessions is not allowed in side inputs")
    60  
    61    def map_via_end(source_window):
    62      # type: (window.BoundedWindow) -> window.BoundedWindow
    63      return list(
    64          target_window_fn.assign(
    65              window.WindowFn.AssignContext(source_window.max_timestamp())))[-1]
    66  
    67    return map_via_end
    68  
    69  
    70  def get_sideinput_index(tag):
    71    # type: (str) -> int
    72    match = re.match(SIDE_INPUT_REGEX, tag, re.DOTALL)
    73    if match:
    74      return int(match.group(1))
    75    else:
    76      raise RuntimeError("Invalid tag %r" % tag)
    77  
    78  
    79  class SideInputMap(object):
    80    """Represents a mapping of windows to side input values."""
    81    def __init__(
    82        self,
    83        view_class,  # type: pvalue.AsSideInput
    84        view_options,
    85        iterable):
    86      self._window_mapping_fn = view_options.get(
    87          'window_mapping_fn', _global_window_mapping_fn)
    88      self._view_class = view_class
    89      self._view_options = view_options
    90      self._iterable = iterable
    91      self._cache = {}  # type: Dict[window.BoundedWindow, Any]
    92  
    93    def __getitem__(self, window):
    94      # type: (window.BoundedWindow) -> Any
    95      if window not in self._cache:
    96        target_window = self._window_mapping_fn(window)
    97        self._cache[window] = self._view_class._from_runtime_iterable(
    98            _FilteringIterable(self._iterable, target_window), self._view_options)
    99      return self._cache[window]
   100  
   101    def is_globally_windowed(self):
   102      # type: () -> bool
   103      return self._window_mapping_fn == _global_window_mapping_fn
   104  
   105  
   106  class _FilteringIterable(object):
   107    """An iterable containing only those values in the given window.
   108    """
   109    def __init__(self, iterable, target_window):
   110      self._iterable = iterable
   111      self._target_window = target_window
   112  
   113    def __iter__(self):
   114      for wv in self._iterable:
   115        if self._target_window in wv.windows:
   116          yield wv.value
   117  
   118    def __reduce__(self):
   119      # Pickle self as an already filtered list.
   120      return list, (list(self), )