#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Internal side input transforms and implementations.

For internal use only; no backwards-compatibility guarantees.

Important: this module is an implementation detail and should not be used
directly by pipeline writers. Instead, users should use the helper methods
AsSingleton, AsIter, AsList and AsDict in apache_beam.pvalue.
"""

# pytype: skip-file

import re
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict

from apache_beam.transforms import window

if TYPE_CHECKING:
  from apache_beam import pvalue

WindowMappingFn = Callable[[window.BoundedWindow], window.BoundedWindow]

SIDE_INPUT_PREFIX = 'python_side_input'

# Matches SIDE_INPUT_PREFIX followed by a decimal index (captured as group 1)
# and an optional '-'-introduced suffix running to the end of the tag.
SIDE_INPUT_REGEX = SIDE_INPUT_PREFIX + '([0-9]+)(-.*)?$'


# Top-level function so we can identify it later.
def _global_window_mapping_fn(w, global_window=window.GlobalWindow()):
  # type: (...) -> window.GlobalWindow

  """Maps any window to the global window.

  Args:
    w: the source window; it is ignored.
    global_window: the window to return. The default is a GlobalWindow
      instance created once, when this function is defined.

  Returns:
    ``global_window``.
  """
  return global_window


def default_window_mapping_fn(target_window_fn):
  # type: (window.WindowFn) -> WindowMappingFn

  """Returns the default main-window to side-input-window mapping.

  Args:
    target_window_fn: the WindowFn of the side input.

  Returns:
    ``_global_window_mapping_fn`` when ``target_window_fn`` compares equal to
    ``window.GlobalWindows()``; otherwise a function that maps a source window
    to the last window that ``target_window_fn`` assigns to the source
    window's ``max_timestamp()``.

  Raises:
    RuntimeError: if ``target_window_fn`` is a ``window.Sessions`` instance.
  """
  if target_window_fn == window.GlobalWindows():
    return _global_window_mapping_fn

  if isinstance(target_window_fn, window.Sessions):
    raise RuntimeError("Sessions is not allowed in side inputs")

  def map_via_end(source_window):
    # type: (window.BoundedWindow) -> window.BoundedWindow
    # Assign the source window's max timestamp with the target WindowFn and
    # keep the last of the windows it produces.
    return list(
        target_window_fn.assign(
            window.WindowFn.AssignContext(source_window.max_timestamp())))[-1]

  return map_via_end


def get_sideinput_index(tag):
  # type: (str) -> int

  """Extracts the integer side input index encoded in ``tag``.

  Args:
    tag: a string expected to match ``SIDE_INPUT_REGEX``, i.e.
      ``SIDE_INPUT_PREFIX`` followed by digits and an optional ``-`` suffix.

  Returns:
    The digits following the prefix, as an int.

  Raises:
    RuntimeError: if ``tag`` does not match ``SIDE_INPUT_REGEX``.
  """
  # re.DOTALL lets the optional suffix contain newlines.
  match = re.match(SIDE_INPUT_REGEX, tag, re.DOTALL)
  if match:
    return int(match.group(1))
  else:
    raise RuntimeError("Invalid tag %r" % tag)


class SideInputMap(object):
  """Represents a mapping of windows to side input values."""
  def __init__(
      self,
      view_class,  # type: pvalue.AsSideInput
      view_options,
      iterable):
    """Initializes the map.

    Args:
      view_class: object providing ``_from_runtime_iterable``, used to build
        the value for each window.
      view_options: mapping of view options; its ``'window_mapping_fn'``
        entry, if present, replaces the default global window mapping.
      iterable: iterable of elements exposing ``windows`` and ``value``
        attributes (see ``_FilteringIterable``).
    """
    self._window_mapping_fn = view_options.get(
        'window_mapping_fn', _global_window_mapping_fn)
    self._view_class = view_class
    self._view_options = view_options
    self._iterable = iterable
    # Values already built, keyed by the window they were requested for.
    self._cache = {}  # type: Dict[window.BoundedWindow, Any]

  def __getitem__(self, window):
    # type: (window.BoundedWindow) -> Any

    """Returns the side input value for ``window``, building it on first use.

    Note that the ``window`` parameter shadows the module-level ``window``
    import inside this method.
    """
    if window not in self._cache:
      target_window = self._window_mapping_fn(window)
      self._cache[window] = self._view_class._from_runtime_iterable(
          _FilteringIterable(self._iterable, target_window), self._view_options)
    return self._cache[window]

  def is_globally_windowed(self):
    # type: () -> bool

    """Returns whether the mapping fn is ``_global_window_mapping_fn``."""
    return self._window_mapping_fn == _global_window_mapping_fn


class _FilteringIterable(object):
  """An iterable containing only those values in the given window."""
  def __init__(self, iterable, target_window):
    """Stores the underlying iterable and the window to filter by."""
    self._iterable = iterable
    self._target_window = target_window

  def __iter__(self):
    """Yields ``value`` of each element whose ``windows`` has the target."""
    for wv in self._iterable:
      if self._target_window in wv.windows:
        yield wv.value

  def __reduce__(self):
    # Pickle self as an already filtered list.
    return list, (list(self), )