#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Support for user state in the BundleBasedDirectRunner."""
# pytype: skip-file

import copy
import itertools

from apache_beam.transforms import userstate
from apache_beam.transforms.trigger import _ListStateTag
from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
from apache_beam.transforms.trigger import _SetStateTag


class DirectRuntimeState(userstate.RuntimeState):
  """Base class for the runtime state objects defined in this module.

  Holds the state spec, the state tag and a zero-argument callable used to
  lazily fetch the currently stored (encoded) value(s).
  """
  def __init__(self, state_spec, state_tag, current_value_accessor):
    """Store the spec, tag and lazy accessor.

    Args:
      state_spec: the userstate spec; its ``coder`` is used by ``_encode`` and
        ``_decode``.
      state_tag: the tag object associated with this state cell.
      current_value_accessor: zero-argument callable returning the currently
        stored value(s). It is only invoked on first read by the subclasses.
    """
    self._state_spec = state_spec
    self._state_tag = state_tag
    self._current_value_accessor = current_value_accessor

  @staticmethod
  def for_spec(state_spec, state_tag, current_value_accessor):
    """Build the runtime state object matching the type of ``state_spec``.

    Args:
      state_spec: a ReadModifyWrite, Bag, CombiningValue or Set state spec.
      state_tag: forwarded unchanged to the constructed object.
      current_value_accessor: forwarded unchanged to the constructed object.

    Returns:
      A ReadModifyWriteRuntimeState, BagRuntimeState,
      CombiningValueRuntimeState or SetRuntimeState.

    Raises:
      ValueError: if ``state_spec`` is none of the supported spec types.
    """
    if isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
      return ReadModifyWriteRuntimeState(
          state_spec, state_tag, current_value_accessor)
    elif isinstance(state_spec, userstate.BagStateSpec):
      return BagRuntimeState(state_spec, state_tag, current_value_accessor)
    elif isinstance(state_spec, userstate.CombiningValueStateSpec):
      return CombiningValueRuntimeState(
          state_spec, state_tag, current_value_accessor)
    elif isinstance(state_spec, userstate.SetStateSpec):
      return SetRuntimeState(state_spec, state_tag, current_value_accessor)
    else:
      raise ValueError('Invalid state spec: %s' % state_spec)

  def _encode(self, value):
    """Encode ``value`` with the state spec's coder."""
    return self._state_spec.coder.encode(value)

  def _decode(self, value):
    """Decode ``value`` with the state spec's coder."""
    return self._state_spec.coder.decode(value)


# Sentinel designating an unread value.
UNREAD_VALUE = object()


class ReadModifyWriteRuntimeState(DirectRuntimeState,
                                  userstate.ReadModifyWriteRuntimeState):
  """Single-value state that buffers a write or clear until commit.

  The value is kept as a list: empty when there is no value, otherwise a
  one-element list holding the encoded value.
  """
  def __init__(self, state_spec, state_tag, current_value_accessor):
    super().__init__(state_spec, state_tag, current_value_accessor)
    self._value = UNREAD_VALUE
    self._cleared = False
    self._modified = False

  def read(self):
    """Return the decoded current value, or None if cleared or empty."""
    if self._cleared:
      return None
    if self._value is UNREAD_VALUE:
      # Fetch lazily, at most once; later reads reuse the cached result.
      self._value = self._current_value_accessor()
    if not self._value:
      return None
    # NOTE(review): assumes the accessor returns a sequence whose first item is
    # the encoded value, mirroring the list that write() builds.
    return self._decode(self._value[0])

  def write(self, value):
    """Buffer ``value`` (encoded) as the new state; undoes a prior clear."""
    self._cleared = False
    self._modified = True
    self._value = [self._encode(value)]

  def clear(self):
    """Mark the state as cleared and discard any buffered write."""
    self._cleared = True
    self._modified = False
    self._value = []

  def is_cleared(self):
    """Return True if the last mutation was clear()."""
    return self._cleared

  def is_modified(self):
    """Return True if a write() is buffered and not superseded by clear()."""
    return self._modified


class BagRuntimeState(DirectRuntimeState, userstate.BagRuntimeState):
  """Bag state that buffers appended (encoded) values and a clear flag."""
  def __init__(self, state_spec, state_tag, current_value_accessor):
    super().__init__(state_spec, state_tag, current_value_accessor)
    self._cached_value = UNREAD_VALUE
    self._cleared = False
    self._new_values = []

  def read(self):
    """Return a generator over the decoded elements of the bag.

    Previously stored values (unless the bag was cleared) come first, followed
    by values added through this object. The generator is lazy: elements are
    decoded as they are consumed.
    """
    if self._cached_value is UNREAD_VALUE:
      self._cached_value = self._current_value_accessor()
    if not self._cleared:
      encoded_values = itertools.chain(self._cached_value, self._new_values)
    else:
      encoded_values = self._new_values
    return (self._decode(v) for v in encoded_values)

  def add(self, value):
    """Buffer the encoded ``value`` to be appended on commit."""
    self._new_values.append(self._encode(value))

  def clear(self):
    """Mark the bag as cleared and drop cached and buffered values."""
    self._cleared = True
    self._cached_value = []
    self._new_values = []


class SetRuntimeState(DirectRuntimeState, userstate.SetRuntimeState):
  """Set state kept as a decoded Python set, loaded lazily on first use."""
  def __init__(self, state_spec, state_tag, current_value_accessor):
    super().__init__(state_spec, state_tag, current_value_accessor)
    self._current_accumulator = UNREAD_VALUE
    self._modified = False

  def _read_initial_value(self):
    """Populate the in-memory set from the accessor if not yet loaded."""
    if self._current_accumulator is UNREAD_VALUE:
      self._current_accumulator = {
          self._decode(a)
          for a in self._current_value_accessor()
      }

  def read(self):
    """Return the in-memory set itself (not a copy)."""
    self._read_initial_value()
    return self._current_accumulator

  def add(self, value):
    """Add ``value`` to the set and mark the state as modified."""
    self._read_initial_value()
    self._modified = True
    self._current_accumulator.add(value)

  def clear(self):
    """Replace the set with an empty one and mark the state as modified."""
    self._current_accumulator = set()
    self._modified = True

  def is_modified(self):
    """Return True if add() or clear() has been called."""
    return self._modified


class CombiningValueRuntimeState(DirectRuntimeState,
                                 userstate.CombiningValueRuntimeState):
  """Combining value state interface object passed to user code."""
  def __init__(self, state_spec, state_tag, current_value_accessor):
    super().__init__(state_spec, state_tag, current_value_accessor)
    self._current_accumulator = UNREAD_VALUE
    self._modified = False
    # Work on a private deep copy of the spec's combine_fn; setup() is called
    # here and the matching teardown() in finalize().
    self._combine_fn = copy.deepcopy(state_spec.combine_fn)
    self._combine_fn.setup()
    self._finalized = False

  def _read_initial_value(self):
    """Load the accumulator from stored state if not yet loaded.

    Stored accumulators are decoded and merged; if there are none, a fresh
    accumulator is created.
    """
    if self._current_accumulator is UNREAD_VALUE:
      existing_accumulators = [
          self._decode(a) for a in self._current_value_accessor()
      ]
      if existing_accumulators:
        self._current_accumulator = self._combine_fn.merge_accumulators(
            existing_accumulators)
      else:
        self._current_accumulator = self._combine_fn.create_accumulator()

  def read(self):
    """Return the combine_fn's output extracted from the accumulator."""
    self._read_initial_value()
    return self._combine_fn.extract_output(self._current_accumulator)

  def add(self, value):
    """Fold ``value`` into the accumulator and mark the state as modified."""
    self._read_initial_value()
    self._modified = True
    self._current_accumulator = self._combine_fn.add_input(
        self._current_accumulator, value)

  def clear(self):
    """Reset to a fresh accumulator and mark the state as modified."""
    self._modified = True
    self._current_accumulator = self._combine_fn.create_accumulator()

  def finalize(self):
    """Tear down the private combine_fn; only the first call has an effect."""
    if not self._finalized:
      self._combine_fn.teardown()
      self._finalized = True


class DirectUserStateContext(userstate.UserStateContext):
  """userstate.UserStateContext for the BundleBasedDirectRunner.

  The DirectUserStateContext buffers up updates that are to be committed
  by the TransformEvaluator after running a DoFn.
  """
  def __init__(self, step_context, dofn, key_coder):
    """Collect the DoFn's state/timer specs and assign a tag to each state.

    Args:
      step_context: object providing ``get_keyed_state(encoded_key)``.
      dofn: the DoFn whose specs are read via ``userstate.get_dofn_specs``.
      key_coder: coder used to encode keys for cache and state lookups.

    Raises:
      ValueError: if a state spec is of an unsupported type.
    """
    self.step_context = step_context
    self.dofn = dofn
    self.key_coder = key_coder

    self.all_state_specs, self.all_timer_specs = userstate.get_dofn_specs(dofn)
    self.state_tags = {}
    for state_spec in self.all_state_specs:
      state_key = 'user/%s' % state_spec.name
      if isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
        state_tag = _ReadModifyWriteStateTag(state_key)
      elif isinstance(state_spec, userstate.BagStateSpec):
        state_tag = _ListStateTag(state_key)
      elif isinstance(state_spec, userstate.CombiningValueStateSpec):
        # Combining state is stored under a list tag; commit() writes a single
        # encoded accumulator into it.
        state_tag = _ListStateTag(state_key)
      elif isinstance(state_spec, userstate.SetStateSpec):
        state_tag = _SetStateTag(state_key)
      else:
        raise ValueError('Invalid state spec: %s' % state_spec)
      self.state_tags[state_spec] = state_tag

    # Both caches are keyed by (encoded_key, window, spec).
    self.cached_states = {}
    self.cached_timers = {}

  def get_timer(
      self, timer_spec: userstate.TimerSpec, key, window, timestamp,
      pane) -> userstate.RuntimeTimer:
    """Return the cached RuntimeTimer for (key, window, timer_spec).

    A new RuntimeTimer is created on first request. ``timestamp`` and ``pane``
    are accepted but not used by this implementation.
    """
    assert timer_spec in self.all_timer_specs
    encoded_key = self.key_coder.encode(key)
    cache_key = (encoded_key, window, timer_spec)
    if cache_key not in self.cached_timers:
      self.cached_timers[cache_key] = userstate.RuntimeTimer()
    return self.cached_timers[cache_key]

  def get_state(self, state_spec, key, window):
    """Return the cached runtime state for (key, window, state_spec).

    On first request the state object is built with
    ``DirectRuntimeState.for_spec`` and given a lazy accessor, so the
    underlying state is only fetched when the state object calls it.
    """
    assert state_spec in self.all_state_specs
    encoded_key = self.key_coder.encode(key)
    cache_key = (encoded_key, window, state_spec)
    if cache_key not in self.cached_states:
      state_tag = self.state_tags[state_spec]
      value_accessor = (
          lambda: self._get_underlying_state(state_spec, key, window))
      self.cached_states[cache_key] = DirectRuntimeState.for_spec(
          state_spec, state_tag, value_accessor)
    return self.cached_states[cache_key]

  def _get_underlying_state(self, state_spec, key, window):
    """Fetch the stored value(s) for the spec's tag from the step context."""
    state_tag = self.state_tags[state_spec]
    encoded_key = self.key_coder.encode(key)
    return (
        self.step_context.get_keyed_state(encoded_key).get_state(
            window, state_tag))

  def commit(self):
    """Write all buffered state changes and timer operations to the step
    context's keyed state.

    Raises:
      ValueError: if a cached state has an unsupported spec type.
    """
    # Commit state modifications.
    for cache_key, runtime_state in self.cached_states.items():
      encoded_key, window, state_spec = cache_key
      state = self.step_context.get_keyed_state(encoded_key)
      state_tag = self.state_tags[state_spec]
      if isinstance(state_spec, userstate.BagStateSpec):
        if runtime_state._cleared:
          state.clear_state(window, state_tag)
        for new_value in runtime_state._new_values:
          state.add_state(window, state_tag, new_value)
      elif isinstance(state_spec, userstate.CombiningValueStateSpec):
        if runtime_state._modified:
          # Replace whatever was stored with the single current accumulator.
          state.clear_state(window, state_tag)
          state.add_state(
              window,
              state_tag,
              state_spec.coder.encode(runtime_state._current_accumulator))
      elif isinstance(state_spec, userstate.SetStateSpec):
        if runtime_state.is_modified():
          # Rewrite the full set: clear, then add every element re-encoded.
          state.clear_state(window, state_tag)
          for new_value in runtime_state._current_accumulator:
            state.add_state(
                window, state_tag, state_spec.coder.encode(new_value))
      elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
        if runtime_state.is_cleared():
          state.clear_state(window, state_tag)
        if runtime_state.is_modified():
          state.clear_state(window, state_tag)
          # _value is the one-element list built by write().
          state.add_state(window, state_tag, runtime_state._value)
      else:
        raise ValueError('Invalid state spec: %s' % state_spec)

    # Commit new timers.
    for cache_key, runtime_timer in self.cached_timers.items():
      encoded_key, window, timer_spec = cache_key
      state = self.step_context.get_keyed_state(encoded_key)
      timer_name = 'user/%s' % timer_spec.name
      for dynamic_timer_tag, timer in runtime_timer._timer_recordings.items():
        if timer.cleared:
          state.clear_timer(
              window,
              timer_name,
              timer_spec.time_domain,
              dynamic_timer_tag=dynamic_timer_tag)
        # Compare against None rather than relying on truthiness, so that a
        # timer set at a falsy timestamp (e.g. a raw 0) is not silently
        # dropped.
        if timer.timestamp is not None:
          # TODO(ccy): add corresponding watermark holds after the DirectRunner
          # allows for keyed watermark holds.
          state.set_timer(
              window,
              timer_name,
              timer_spec.time_domain,
              timer.timestamp,
              dynamic_timer_tag=dynamic_timer_tag)

  def reset(self):
    """Call finalize() on every cached state, then drop both caches."""
    for state in self.cached_states.values():
      state.finalize()
    self.cached_states = {}
    self.cached_timers = {}