github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/bundle_processor.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """SDK harness for executing Python Fns via the Fn API.""" 19 20 # pytype: skip-file 21 22 import base64 23 import bisect 24 import collections 25 import copy 26 import json 27 import logging 28 import random 29 import threading 30 from typing import TYPE_CHECKING 31 from typing import Any 32 from typing import Callable 33 from typing import Container 34 from typing import DefaultDict 35 from typing import Dict 36 from typing import FrozenSet 37 from typing import Iterable 38 from typing import Iterator 39 from typing import List 40 from typing import Mapping 41 from typing import Optional 42 from typing import Set 43 from typing import Tuple 44 from typing import Type 45 from typing import TypeVar 46 from typing import Union 47 from typing import cast 48 49 from google.protobuf import duration_pb2 50 from google.protobuf import timestamp_pb2 51 52 import apache_beam as beam 53 from apache_beam import coders 54 from apache_beam.coders import WindowedValueCoder 55 from apache_beam.coders import coder_impl 56 from apache_beam.internal import pickler 57 from 
apache_beam.io import iobase 58 from apache_beam.metrics import monitoring_infos 59 from apache_beam.portability import common_urns 60 from apache_beam.portability import python_urns 61 from apache_beam.portability.api import beam_fn_api_pb2 62 from apache_beam.portability.api import beam_runner_api_pb2 63 from apache_beam.runners import common 64 from apache_beam.runners import pipeline_context 65 from apache_beam.runners.worker import data_sampler 66 from apache_beam.runners.worker import operation_specs 67 from apache_beam.runners.worker import operations 68 from apache_beam.runners.worker import statesampler 69 from apache_beam.runners.worker.data_sampler import OutputSampler 70 from apache_beam.transforms import TimeDomain 71 from apache_beam.transforms import core 72 from apache_beam.transforms import environments 73 from apache_beam.transforms import sideinputs 74 from apache_beam.transforms import userstate 75 from apache_beam.transforms import window 76 from apache_beam.utils import counters 77 from apache_beam.utils import proto_utils 78 from apache_beam.utils import timestamp 79 from apache_beam.utils.windowed_value import WindowedValue 80 81 if TYPE_CHECKING: 82 from google.protobuf import message # pylint: disable=ungrouped-imports 83 from apache_beam import pvalue 84 from apache_beam.portability.api import metrics_pb2 85 from apache_beam.runners.sdf_utils import SplitResultPrimary 86 from apache_beam.runners.sdf_utils import SplitResultResidual 87 from apache_beam.runners.worker import data_plane 88 from apache_beam.runners.worker import sdk_worker 89 from apache_beam.transforms.core import Windowing 90 from apache_beam.transforms.window import BoundedWindow 91 from apache_beam.utils import windowed_value 92 93 T = TypeVar('T') 94 ConstructorFn = Callable[[ 95 'BeamTransformFactory', 96 Any, 97 beam_runner_api_pb2.PTransform, 98 Union['message.Message', bytes], 99 Dict[str, List[operations.Operation]] 100 ], 101 operations.Operation] 102 OperationT = 
TypeVar('OperationT', bound=operations.Operation) 103 FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState', 104 'CombiningValueRuntimeState', 105 'SynchronousSetRuntimeState', 106 'SynchronousBagRuntimeState'] 107 108 DATA_INPUT_URN = 'beam:runner:source:v1' 109 DATA_OUTPUT_URN = 'beam:runner:sink:v1' 110 SYNTHETIC_DATA_SAMPLING_URN = 'beam:internal:sampling:v1' 111 IDENTITY_DOFN_URN = 'beam:dofn:identity:0.1' 112 # TODO(vikasrk): Fix this once runner sends appropriate common_urns. 113 OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN = 'beam:dofn:javasdk:0.1' 114 OLD_DATAFLOW_RUNNER_HARNESS_READ_URN = 'beam:source:java:0.1' 115 URNS_NEEDING_PCOLLECTIONS = set([ 116 monitoring_infos.ELEMENT_COUNT_URN, monitoring_infos.SAMPLED_BYTE_SIZE_URN 117 ]) 118 119 _LOGGER = logging.getLogger(__name__) 120 121 122 class RunnerIOOperation(operations.Operation): 123 """Common baseclass for runner harness IO operations.""" 124 125 def __init__(self, 126 name_context, # type: common.NameContext 127 step_name, # type: Any 128 consumers, # type: Mapping[Any, Iterable[operations.Operation]] 129 counter_factory, # type: counters.CounterFactory 130 state_sampler, # type: statesampler.StateSampler 131 windowed_coder, # type: coders.Coder 132 transform_id, # type: str 133 data_channel # type: data_plane.DataChannel 134 ): 135 # type: (...) -> None 136 super().__init__(name_context, None, counter_factory, state_sampler) 137 self.windowed_coder = windowed_coder 138 self.windowed_coder_impl = windowed_coder.get_impl() 139 # transform_id represents the consumer for the bytes in the data plane for a 140 # DataInputOperation or a producer of these bytes for a DataOutputOperation. 
class DataInputOperation(RunnerIOOperation):
    """A source-like operation that gathers input from the runner.

    Encoded elements are handed to ``process_encoded``, decoded with the
    windowed coder and forwarded to the downstream consumers.

    The operation keeps three pieces of split-related state, all of which
    are mutated while holding ``splitting_lock`` once the object exists:

    * ``index``: index of the element currently being processed (``-1``
      before the first element).
    * ``stop``: exclusive upper bound on the indices that will be processed;
      ``try_split`` lowers it when a split is accepted.
    * ``started``: whether ``start`` has run and ``finish`` has not.
    """

    def __init__(self,
                 operation_name,  # type: common.NameContext
                 step_name,
                 consumers,  # type: Mapping[Any, List[operations.Operation]]
                 counter_factory,  # type: counters.CounterFactory
                 state_sampler,  # type: statesampler.StateSampler
                 windowed_coder,  # type: coders.Coder
                 transform_id,
                 data_channel  # type: data_plane.GrpcClientDataChannel
                 ):
        # type: (...) -> None
        super().__init__(
            operation_name,
            step_name,
            consumers,
            counter_factory,
            state_sampler,
            windowed_coder,
            transform_id=transform_id,
            data_channel=data_channel)

        # Only the first entry of the consumers mapping is used as the
        # consumer list for the single receiver built in setup().
        self.consumer = next(iter(consumers.values()))
        self.splitting_lock = threading.Lock()
        self.index = -1
        self.stop = float('inf')
        self.started = False

    def setup(self):
        """Run the base setup, then build the single output receiver."""
        super().setup()
        # We must do this manually as we don't have a spec or
        # spec.output_coders.
        self.receivers = [
            operations.ConsumerSet.create(
                counter_factory=self.counter_factory,
                step_name=self.name_context.step_name,
                output_index=0,
                consumers=self.consumer,
                coder=self.windowed_coder,
                producer_type_hints=self._get_runtime_performance_hints(),
                producer_batch_converter=self.get_output_batch_converter())
        ]

    def start(self):
        # type: () -> None
        """Start the operation and mark it as started for ``try_split``."""
        super().start()
        with self.splitting_lock:
            self.started = True

    def process(self, windowed_value):
        # type: (windowed_value.WindowedValue) -> None
        """Forward an already decoded windowed value downstream."""
        self.output(windowed_value)

    def process_encoded(self, encoded_windowed_values):
        # type: (bytes) -> None
        """Decode and output every element in the given encoded chunk.

        Returns early, leaving the rest of the chunk undecoded, as soon as
        the next element would be at or beyond ``stop``.
        """
        input_stream = coder_impl.create_InputStream(encoded_windowed_values)
        while input_stream.size() > 0:
            # The lock only guards the index/stop bookkeeping; decoding and
            # downstream processing happen outside of it.
            with self.splitting_lock:
                if self.index == self.stop - 1:
                    return
                self.index += 1
            decoded_value = self.windowed_coder_impl.decode_from_stream(
                input_stream, True)
            self.output(decoded_value)

    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
        # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
        """Return the base monitoring infos plus the data-channel read index."""
        all_monitoring_infos = super().monitoring_infos(
            transform_id, tag_to_pcollection_id)
        read_progress_info = monitoring_infos.int64_counter(
            monitoring_infos.DATA_CHANNEL_READ_INDEX,
            self.index,
            ptransform=transform_id)
        all_monitoring_infos[monitoring_infos.to_key(
            read_progress_info)] = read_progress_info
        return all_monitoring_infos

    # TODO(https://github.com/apache/beam/issues/19737): typing not compatible
    # with super type
    def try_split(  # type: ignore[override]
            self, fraction_of_remainder, total_buffer_size,
            allowed_split_points):
        # type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]]
        """Try to split the remaining input, lowering ``stop`` on success.

        Returns:
          ``None`` if the operation has not started or no split was found,
          otherwise the tuple produced by ``_compute_split``; its last item
          becomes the new ``stop``.
        """
        with self.splitting_lock:
            if not self.started:
                return None
            if self.index == -1:
                # We are "finished" with the (non-existent) previous element.
                current_element_progress = 1.0
            else:
                current_element_progress_object = (
                    self.receivers[0].current_element_progress())
                if current_element_progress_object is None:
                    # No progress is reported; 0.5 is used as the estimate.
                    current_element_progress = 0.5
                else:
                    current_element_progress = (
                        current_element_progress_object.fraction_completed)
            # Now figure out where to split.
            split = self._compute_split(
                self.index,
                current_element_progress,
                self.stop,
                fraction_of_remainder,
                total_buffer_size,
                allowed_split_points,
                self.receivers[0].try_split)
            if split:
                self.stop = split[-1]
            return split

    @staticmethod
    def _compute_split(
            index,
            current_element_progress,
            stop,
            fraction_of_remainder,
            total_buffer_size,
            allowed_split_points=(),
            try_split=lambda fraction: None):
        """Compute where to split without touching any instance state.

        Args:
          index: index of the element currently being processed.
          current_element_progress: completed fraction of that element.
          stop: current exclusive stop index.
          fraction_of_remainder: fraction of the remaining work to keep.
          total_buffer_size: estimated number of elements; clamped to the
            range ``[index + 1, stop]``.
          allowed_split_points: when non-empty, the only element indices at
            which a split may be made.
          try_split: callable given the fraction of the current element's
            remainder to keep; returns ``(primaries, residuals)`` or a falsy
            value when the element cannot be split.

        Returns:
          ``(last_primary_index, element_primaries, element_residuals,
          first_residual_index)``, or ``None`` if no valid split exists.
        """
        def is_valid_split_point(candidate):
            return (
                not allowed_split_points or candidate in allowed_split_points)

        if total_buffer_size < index + 1:
            total_buffer_size = index + 1
        elif total_buffer_size > stop:
            total_buffer_size = stop
        # The units here (except for keep_of_element_remainder) are all in
        # terms of number of (possibly fractional) elements.
        remainder = total_buffer_size - index - current_element_progress
        keep = remainder * fraction_of_remainder
        if current_element_progress < 1:
            keep_of_element_remainder = keep / (1 - current_element_progress)
            # If it's less than what's left of the current element,
            # try splitting at the current element.
            if (keep_of_element_remainder < 1 and
                    is_valid_split_point(index) and
                    is_valid_split_point(index + 1)):
                split = try_split(
                    keep_of_element_remainder
                )  # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]]
                if split:
                    element_primaries, element_residuals = split
                    return (
                        index - 1,
                        element_primaries,
                        element_residuals,
                        index + 1)
        # Otherwise, split at the closest element boundary.
        # pylint: disable=bad-option-value
        stop_index = index + max(
            1, int(round(current_element_progress + keep)))
        if allowed_split_points and stop_index not in allowed_split_points:
            # Choose the closest allowed split point.
            allowed_split_points = sorted(allowed_split_points)
            closest = bisect.bisect(allowed_split_points, stop_index)
            if closest == 0:
                stop_index = allowed_split_points[0]
            elif closest == len(allowed_split_points):
                stop_index = allowed_split_points[-1]
            else:
                # Named so that the builtin ``next`` is not shadowed.
                preceding = allowed_split_points[closest - 1]
                following = allowed_split_points[closest]
                if (index < preceding and
                        stop_index - preceding < following - stop_index):
                    stop_index = preceding
                else:
                    stop_index = following
        if index < stop_index < stop:
            return stop_index - 1, [], [], stop_index
        else:
            return None

    def finish(self):
        # type: () -> None
        """Finish the operation, advance ``index`` and clear ``started``."""
        super().finish()
        with self.splitting_lock:
            self.index += 1
            self.started = False

    def reset(self):
        # type: () -> None
        """Restore ``index`` and ``stop`` to their initial values."""
        with self.splitting_lock:
            self.index = -1
            self.stop = float('inf')
        super().reset()
class StateBackedSideInputMap(object):
    """Per-window cache of side input views backed by the state handler.

    Indexing with a window maps it through the side input's
    ``window_mapping_fn``, builds a raw view for the configured access
    pattern (iterable or multimap) and caches the result of applying
    ``view_fn`` to that raw view.
    """

    def __init__(self,
                 state_handler,  # type: sdk_worker.CachingStateHandler
                 transform_id,  # type: str
                 tag,  # type: Optional[str]
                 side_input_data,  # type: pvalue.SideInputData
                 coder  # type: WindowedValueCoder
                 ):
        # type: (...) -> None
        self._state_handler = state_handler
        self._transform_id = transform_id
        self._tag = tag
        self._side_input_data = side_input_data
        self._element_coder = coder.wrapped_value_coder
        self._target_window_coder = coder.window_coder
        # TODO(robertwb): Limit the cache size.
        self._cache = {}  # type: Dict[BoundedWindow, Any]

    def __getitem__(self, window):
        mapped_window = self._side_input_data.window_mapping_fn(window)
        # Cache hit: nothing to build.
        if mapped_window in self._cache:
            return self._cache[mapped_window]

        handler = self._state_handler
        pattern = self._side_input_data.access_pattern

        if pattern == common_urns.side_inputs.ITERABLE.urn:
            iterable_key = beam_fn_api_pb2.StateKey(
                iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                    transform_id=self._transform_id,
                    side_input_id=self._tag,
                    window=self._target_window_coder.encode(mapped_window)))
            raw_view = _StateBackedIterable(
                handler, iterable_key, self._element_coder)

        elif pattern == common_urns.side_inputs.MULTIMAP.urn:
            # Template key; each lookup copies it and fills in the user key.
            base_key = beam_fn_api_pb2.StateKey(
                multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                    transform_id=self._transform_id,
                    side_input_id=self._tag,
                    window=self._target_window_coder.encode(mapped_window),
                    key=b''))
            per_key_views = {}
            key_impl = self._element_coder.key_coder().get_impl()
            values_coder = self._element_coder.value_coder()

            class MultiMap(object):
                def __getitem__(self, key):
                    if key not in per_key_views:
                        lookup_key = beam_fn_api_pb2.StateKey()
                        lookup_key.CopyFrom(base_key)
                        lookup_key.multimap_side_input.key = (
                            key_impl.encode_nested(key))
                        per_key_views[key] = _StateBackedIterable(
                            handler, lookup_key, values_coder)
                    return per_key_views[key]

                def __reduce__(self):
                    # TODO(robertwb): Figure out how to support this.
                    raise TypeError(common_urns.side_inputs.MULTIMAP.urn)

            raw_view = MultiMap()

        else:
            raise ValueError("Unknown access pattern: '%s'" % pattern)

        view = self._side_input_data.view_fn(raw_view)
        self._cache[mapped_window] = view
        return view

    def is_globally_windowed(self):
        # type: () -> bool
        """Whether the window mapping fn is the global-window mapping fn."""
        mapping_fn = self._side_input_data.window_mapping_fn
        return mapping_fn == sideinputs._global_window_mapping_fn

    def reset(self):
        # type: () -> None
        """Drop every cached view."""
        # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens.
        self._cache = dict()
class _ConcatIterable(object):
    """An iterable that is the concatenation of two iterables.

    Unlike itertools.chain, this allows reiteration: every call to
    ``__iter__`` starts a fresh pass over ``first`` and then ``second``.
    Reiteration therefore only yields the same elements again if both
    underlying iterables can themselves be iterated more than once.
    """
    def __init__(self, first, second):
        # type: (Iterable[Any], Iterable[Any]) -> None
        self.first = first
        self.second = second

    def __iter__(self):
        # type: () -> Iterator[Any]
        # Delegate to each underlying iterable in turn instead of
        # re-yielding its elements one at a time by hand.
        yield from self.first
        yield from self.second
class SynchronousSetRuntimeState(userstate.SetRuntimeState):
    """Set user state that buffers additions locally until committed.

    Additions are collected in ``_added_elements``; ``_cleared`` records
    that the persisted contents must be ignored on read and cleared on
    commit.
    """

    def __init__(self,
                 state_handler,  # type: sdk_worker.CachingStateHandler
                 state_key,  # type: beam_fn_api_pb2.StateKey
                 value_coder  # type: coders.Coder
                 ):
        # type: (...) -> None
        self._state_handler = state_handler
        self._state_key = state_key
        self._value_coder = value_coder
        self._cleared = False
        self._added_elements = set()  # type: Set[Any]

    def _compact_data(self, rewrite=True):
        """Return persisted plus buffered elements, optionally rewriting state.

        When ``rewrite`` is true and the merged set is non-empty, the stored
        state is cleared and replaced with the merged set, and the local
        buffer is emptied.
        """
        if self._cleared:
            persisted = set()
        else:
            persisted = _StateBackedIterable(
                self._state_handler, self._state_key, self._value_coder)
        merged = set(_ConcatIterable(persisted, self._added_elements))

        if not (rewrite and merged):
            return merged

        self._state_handler.clear(self._state_key)
        self._state_handler.extend(
            self._state_key, self._value_coder.get_impl(), merged)
        # Everything has been handed to the state handler, so the local
        # buffer can safely be reinitialized here.
        self._added_elements = set()
        return merged

    def read(self):
        # type: () -> Set[Any]
        """Return the current contents without rewriting stored state."""
        return self._compact_data(rewrite=False)

    def add(self, value):
        # type: (Any) -> None
        """Buffer ``value``; compact when ``random.random()`` exceeds 0.5."""
        if self._cleared:
            # This is a good time to explicitly clear.
            self._state_handler.clear(self._state_key)
            self._cleared = False

        self._added_elements.add(value)
        if random.random() > 0.5:
            self._compact_data()

    def clear(self):
        # type: () -> None
        """Mark the state as cleared and drop buffered additions."""
        self._cleared = True
        self._added_elements = set()

    def commit(self):
        # type: () -> None
        """Send the pending clear and/or additions and wait for the last."""
        handler = self._state_handler
        key = self._state_key
        last_request = handler.clear(key) if self._cleared else None
        if self._added_elements:
            last_request = handler.extend(
                key, self._value_coder.get_impl(), self._added_elements)
        if last_request:
            # To commit, we need to wait on the last state request future
            # to complete.
            last_request.get()
state..""" 706 707 def __init__(self, 708 state_handler, # type: sdk_worker.CachingStateHandler 709 transform_id, # type: str 710 key_coder, # type: coders.Coder 711 window_coder, # type: coders.Coder 712 ): 713 # type: (...) -> None 714 715 """Initialize a ``FnApiUserStateContext``. 716 717 Args: 718 state_handler: A StateServicer object. 719 transform_id: The name of the PTransform that this context is associated. 720 key_coder: Coder for the key type. 721 window_coder: Coder for the window type. 722 """ 723 self._state_handler = state_handler 724 self._transform_id = transform_id 725 self._key_coder = key_coder 726 self._window_coder = window_coder 727 # A mapping of {timer_family_id: TimerInfo} 728 self._timers_info = {} # type: Dict[str, TimerInfo] 729 self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes] 730 731 def add_timer_info(self, timer_family_id, timer_info): 732 # type: (str, TimerInfo) -> None 733 self._timers_info[timer_family_id] = timer_info 734 735 def get_timer( 736 self, timer_spec: userstate.TimerSpec, key, window, timestamp, 737 pane) -> OutputTimer: 738 assert self._timers_info[timer_spec.name].output_stream is not None 739 timer_coder_impl = self._timers_info[timer_spec.name].timer_coder_impl 740 output_stream = self._timers_info[timer_spec.name].output_stream 741 return OutputTimer( 742 key, 743 window, 744 timestamp, 745 pane, 746 timer_spec.time_domain, 747 timer_spec.name, 748 timer_coder_impl, 749 output_stream) 750 751 def get_state(self, *args): 752 # type: (*Any) -> FnApiUserRuntimeStateTypes 753 state_handle = self._all_states.get(args) 754 if state_handle is None: 755 state_handle = self._all_states[args] = self._create_state(*args) 756 return state_handle 757 758 def _create_state(self, 759 state_spec, # type: userstate.StateSpec 760 key, 761 window # type: BoundedWindow 762 ): 763 # type: (...) 
def memoize(func):
    """Return a wrapper around ``func`` that caches results per argument tuple.

    Only positional arguments are supported and every argument must be
    hashable, because the argument tuple is used as the dictionary key.
    The cache is private to the returned wrapper, is never evicted, and
    lives as long as the wrapper does.  If ``func`` raises, nothing is
    cached and the exception propagates.

    Args:
      func: the callable to memoize.

    Returns:
      A callable with the same name and docstring as ``func``.
    """
    import functools

    cache = {}
    # Private sentinel so that results such as ``None`` are cached too.
    missing = object()

    # Preserve ``func``'s metadata so the wrapper is identifiable in logs
    # and tracebacks.
    @functools.wraps(func)
    def wrapper(*args):
        result = cache.get(args, missing)
        if result is missing:
            result = cache[args] = func(*args)
        return result

    return wrapper
def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor):
    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None

    """Raise if a transform's environment names a different SDK version.

    Looks up the environment of every transform in the descriptor and checks
    each capability that starts with the SDK version capability prefix
    against the capability reported for the running SDK.

    Raises:
      RuntimeError: on the first such capability that differs from the
        runtime one.
    """
    runtime_capability = environments.sdk_base_version_capability()
    for transform in process_bundle_descriptor.transforms.values():
        environment = process_bundle_descriptor.environments[
            transform.environment_id]
        for capability in environment.capabilities:
            # Only SDK version capabilities are compared.
            if not capability.startswith(
                    environments.SDK_VERSION_CAPABILITY_PREFIX):
                continue
            if capability == runtime_capability:
                continue
            raise RuntimeError(
                "Pipeline construction environment and pipeline runtime "
                "environment are not compatible. If you use a custom "
                "container image, check that the Python interpreter minor "
                "version "
                "and the Apache Beam version in your image match the versions "
                "used at pipeline construction time. "
                f"Submission environment: {capability}. "
                f"Runtime environment: {runtime_capability}.")

    # TODO: Consider warning on mismatches in versions of installed packages.
869 """ 870 self.process_bundle_descriptor = process_bundle_descriptor 871 self.state_handler = state_handler 872 self.data_channel_factory = data_channel_factory 873 self.data_sampler = data_sampler 874 self.current_instruction_id = None # type: Optional[str] 875 876 _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor) 877 # There is no guarantee that the runner only set 878 # timer_api_service_descriptor when having timers. So this field cannot be 879 # used as an indicator of timers. 880 if self.process_bundle_descriptor.timer_api_service_descriptor.url: 881 self.timer_data_channel = ( 882 data_channel_factory.create_data_channel_from_url( 883 self.process_bundle_descriptor.timer_api_service_descriptor.url)) 884 else: 885 self.timer_data_channel = None 886 887 # A mapping of 888 # {(transform_id, timer_family_id): TimerInfo} 889 # The mapping is empty when there is no timer_family_specs in the 890 # ProcessBundleDescriptor. 891 self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo] 892 893 # TODO(robertwb): Figure out the correct prefix to use for output counters 894 # from StateSampler. 895 self.counter_factory = counters.CounterFactory() 896 self.state_sampler = statesampler.StateSampler( 897 'fnapi-step-%s' % self.process_bundle_descriptor.id, 898 self.counter_factory) 899 900 if self.data_sampler: 901 self.add_data_sampling_operations(process_bundle_descriptor) 902 903 self.ops = self.create_execution_tree(self.process_bundle_descriptor) 904 for op in reversed(self.ops.values()): 905 op.setup() 906 self.splitting_lock = threading.Lock() 907 908 def add_data_sampling_operations(self, pbd): 909 # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None 910 911 """Adds a DataSamplingOperation to every PCollection. 912 913 Implementation note: the alternative to this, is to add modify each 914 Operation and forward a DataSampler to manually sample when an element is 915 processed. 
This gets messy very quickly and is not future-proof as new 916 operation types will need to be updated. This is the cleanest way of adding 917 new operations to the final execution tree. 918 """ 919 coder = coders.FastPrimitivesCoder() 920 921 for pcoll_id in pbd.pcollections: 922 transform_id = 'synthetic-data-sampling-transform-{}'.format(pcoll_id) 923 transform_proto: beam_runner_api_pb2.PTransform = pbd.transforms[ 924 transform_id] 925 transform_proto.unique_name = transform_id 926 transform_proto.spec.urn = SYNTHETIC_DATA_SAMPLING_URN 927 928 coder_id = pbd.pcollections[pcoll_id].coder_id 929 transform_proto.spec.payload = coder.encode((pcoll_id, coder_id)) 930 931 transform_proto.inputs['None'] = pcoll_id 932 933 def create_execution_tree( 934 self, 935 descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor 936 ): 937 # type: (...) -> collections.OrderedDict[str, operations.DoOperation] 938 transform_factory = BeamTransformFactory( 939 descriptor, 940 self.data_channel_factory, 941 self.counter_factory, 942 self.state_sampler, 943 self.state_handler, 944 self.data_sampler) 945 946 self.timers_info = transform_factory.extract_timers_info() 947 948 def is_side_input(transform_proto, tag): 949 if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: 950 return tag in proto_utils.parse_Bytes( 951 transform_proto.spec.payload, 952 beam_runner_api_pb2.ParDoPayload).side_inputs 953 954 pcoll_consumers = collections.defaultdict( 955 list) # type: DefaultDict[str, List[str]] 956 for transform_id, transform_proto in descriptor.transforms.items(): 957 for tag, pcoll_id in transform_proto.inputs.items(): 958 if not is_side_input(transform_proto, tag): 959 pcoll_consumers[pcoll_id].append(transform_id) 960 961 @memoize 962 def get_operation(transform_id): 963 # type: (str) -> operations.Operation 964 transform_consumers = { 965 tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] 966 for tag, 967 pcoll_id in 
def process_bundle(self, instruction_id):
  # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]

  """Runs every operation in ``self.ops`` for a single bundle.

  Output operations are bound to the output stream for ``instruction_id``,
  all operations are started, data and timer elements are read from each
  data channel and dispatched to the matching operation, and finally all
  operations are finished and the timer output streams are closed.

  Args:
    instruction_id: id of the instruction being processed. It is stored in
      ``self.current_instruction_id`` while the bundle runs and cleared
      again (under ``self.splitting_lock``) when the call ends.

  Returns:
    A pair of (the delayed bundle applications built from
    ``execution_context.delayed_applications``, the result of
    ``self.requires_finalization()``).
  """
  expected_input_ops = []  # type: List[DataInputOperation]

  for op in self.ops.values():
    if isinstance(op, DataOutputOperation):
      # TODO(robertwb): Is there a better way to pass the instruction id to
      # the operation?
      op.set_output_stream(
          op.data_channel.output_stream(instruction_id, op.transform_id))
    elif isinstance(op, DataInputOperation):
      # We must wait until we receive "end of stream" for each of these ops.
      expected_input_ops.append(op)

  try:
    execution_context = ExecutionContext()
    self.current_instruction_id = instruction_id
    self.state_sampler.start()
    # Start all operations, iterating self.ops in reverse order.
    for op in reversed(self.ops.values()):
      _LOGGER.debug('start %s', op)
      op.execution_context = execution_context
      op.start()

    # Each data_channel is mapped to a list of expected inputs which includes
    # both data input and timer input. The data input is identified by
    # transform_id. The timer input is identified by
    # (transform_id, timer_family_id).
    data_channels = collections.defaultdict(
        list
    )  # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]]

    # Add expected data inputs for each data channel.
    input_op_by_transform_id = {}
    for input_op in expected_input_ops:
      data_channels[input_op.data_channel].append(input_op.transform_id)
      input_op_by_transform_id[input_op.transform_id] = input_op

    # Update timer_data channel with expected timer inputs.
    if self.timer_data_channel:
      data_channels[self.timer_data_channel].extend(
          list(self.timers_info.keys()))

      # Set up timer output stream for DoOperation.
      for ((transform_id, timer_family_id),
           timer_info) in self.timers_info.items():
        output_stream = self.timer_data_channel.output_timer_stream(
            instruction_id, transform_id, timer_family_id)
        timer_info.output_stream = output_stream
        self.ops[transform_id].add_timer_info(timer_family_id, timer_info)

    # Process data and timer inputs
    for data_channel, expected_inputs in data_channels.items():
      for element in data_channel.input_elements(instruction_id,
                                                 expected_inputs):
        if isinstance(element, beam_fn_api_pb2.Elements.Timers):
          # Timers are decoded with the coder registered for their
          # (transform_id, timer_family_id) and handed to the owning op.
          timer_coder_impl = (
              self.timers_info[(
                  element.transform_id,
                  element.timer_family_id)].timer_coder_impl)
          for timer_data in timer_coder_impl.decode_all(element.timers):
            self.ops[element.transform_id].process_timer(
                element.timer_family_id, timer_data)
        elif isinstance(element, beam_fn_api_pb2.Elements.Data):
          # Data is passed, still encoded, to the input op for its transform.
          input_op_by_transform_id[element.transform_id].process_encoded(
              element.data)

    # Finish all operations.
    for op in self.ops.values():
      _LOGGER.debug('finish %s', op)
      op.finish()

    # Close every timer output stream
    for timer_info in self.timers_info.values():
      assert timer_info.output_stream is not None
      timer_info.output_stream.close()

    return ([
        self.delayed_bundle_application(op, residual) for op,
        residual in execution_context.delayed_applications
    ],
            self.requires_finalization())

  finally:
    # Ensure any in-flight split attempts complete.
    with self.splitting_lock:
      self.current_instruction_id = None
    self.state_sampler.stop_if_still_running()
def delayed_bundle_application(self,
                               op,  # type: operations.DoOperation
                               deferred_remainder  # type: SplitResultResidual
                              ):
  # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
  """Builds the DelayedBundleApplication proto for a deferred residual.

  Args:
    op: the operation the residual belongs to; its ``input_info`` must be set.
    deferred_remainder: a triple of (element_and_restriction,
      current_watermark, deferred_timestamp).

  Returns:
    A DelayedBundleApplication wrapping the bundle application for the
    residual. ``requested_time_delay`` is only populated when
    ``deferred_timestamp`` is truthy.
  """
  assert op.input_info is not None
  # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
  element_and_restriction, current_watermark, deferred_timestamp = (
      deferred_remainder)

  requested_delay = None  # type: Optional[duration_pb2.Duration]
  if deferred_timestamp:
    assert isinstance(deferred_timestamp, timestamp.Duration)
    requested_delay = proto_utils.from_micros(
        duration_pb2.Duration, deferred_timestamp.micros)

  residual_application = self.construct_bundle_application(
      op.input_info, current_watermark, element_and_restriction)
  return beam_fn_api_pb2.DelayedBundleApplication(
      requested_time_delay=requested_delay,
      application=residual_application)
class BeamTransformFactory(object):
  """Factory for turning transform_protos into executable operations.

  Constructor functions for individual transform URNs are registered with
  ``register_urn`` and looked up by ``create_operation``. The remaining
  methods resolve coders and windowing strategies from the
  ProcessBundleDescriptor on behalf of those constructors.
  """
  def __init__(self,
               descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
               data_channel_factory,  # type: data_plane.DataChannelFactory
               counter_factory,  # type: counters.CounterFactory
               state_sampler,  # type: statesampler.StateSampler
               state_handler,  # type: sdk_worker.CachingStateHandler
               data_sampler,  # type: Optional[data_sampler.DataSampler]
              ):
    """Stores the collaborators and builds the pipeline context.

    The pipeline context is created from ``descriptor`` with an
    ``iterable_state_read`` callback that wraps a state token in a
    _StateBackedIterable reading through ``state_handler``.
    """
    self.descriptor = descriptor
    self.data_channel_factory = data_channel_factory
    self.counter_factory = counter_factory
    self.state_sampler = state_sampler
    self.state_handler = state_handler
    self.context = pipeline_context.PipelineContext(
        descriptor,
        iterable_state_read=lambda token,
        element_coder_impl: _StateBackedIterable(
            state_handler,
            beam_fn_api_pb2.StateKey(
                runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
            element_coder_impl))
    self.data_sampler = data_sampler

  # Registry shared by the class: URN -> (constructor, payload type).
  _known_urns = {
  }  # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]]

  @classmethod
  def register_urn(
      cls,
      urn,  # type: str
      parameter_type  # type: Optional[Type[T]]
  ):
    # type: (...) -> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]]
    """Returns a decorator that registers a constructor for ``urn``.

    Args:
      urn: the transform URN the decorated function handles.
      parameter_type: the type handed to ``proto_utils.parse_Bytes`` when the
        transform payload is parsed in ``create_operation``.

    Returns:
      A decorator that stores ``(func, parameter_type)`` under ``urn`` and
      returns ``func`` unchanged.
    """
    def wrapper(func):
      cls._known_urns[urn] = func, parameter_type
      return func

    return wrapper

  def create_operation(self,
                       transform_id,  # type: str
                       consumers  # type: Dict[str, List[operations.Operation]]
                      ):
    # type: (...) -> operations.Operation
    """Creates the operation for ``transform_id`` via its registered URN.

    If the transform has no unique name, ``transform_id`` is written into the
    proto as its unique name.

    Raises:
      KeyError: if no constructor is registered for the transform's URN.
    """
    transform_proto = self.descriptor.transforms[transform_id]
    if not transform_proto.unique_name:
      # Pass the id as a logging argument so that the message is only
      # formatted when debug logging is actually enabled.
      _LOGGER.debug("No unique name set for transform %s", transform_id)
      transform_proto.unique_name = transform_id
    creator, parameter_type = self._known_urns[transform_proto.spec.urn]
    payload = proto_utils.parse_Bytes(
        transform_proto.spec.payload, parameter_type)
    return creator(self, transform_id, transform_proto, payload, consumers)

  def extract_timers_info(self):
    # type: () -> Dict[Tuple[str, str], TimerInfo]
    """Collects a TimerInfo for every timer family of every ParDo transform.

    Returns:
      A dict keyed by (transform_id, timer_family_id). It is empty when no
      ParDo in the descriptor declares timer_family_specs.
    """
    timers_info = {}
    for transform_id, transform_proto in self.descriptor.transforms.items():
      if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
        pardo_payload = proto_utils.parse_Bytes(
            transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for (timer_family_id,
             timer_family_spec) in pardo_payload.timer_family_specs.items():
          timer_coder_impl = self.get_coder(
              timer_family_spec.timer_family_coder_id).get_impl()
          # The output_stream should be updated when processing a bundle.
          timers_info[(transform_id, timer_family_id)] = TimerInfo(
              timer_coder_impl=timer_coder_impl)
    return timers_info

  def get_coder(self, coder_id):
    # type: (str) -> coders.Coder
    """Returns the coder registered in the descriptor under ``coder_id``.

    Raises:
      KeyError: if the descriptor has no coder with that id.
    """
    if coder_id not in self.descriptor.coders:
      raise KeyError("No such coder: %s" % coder_id)
    coder_proto = self.descriptor.coders[coder_id]
    if coder_proto.spec.urn:
      return self.context.coders.get_by_id(coder_id)
    else:
      # No URN, assume cloud object encoding json bytes.
      return operation_specs.get_coder_from_spec(
          json.loads(coder_proto.spec.payload.decode('utf-8')))

  def get_windowed_coder(self, pcoll_id):
    # type: (str) -> WindowedValueCoder
    """Returns a WindowedValueCoder for the PCollection ``pcoll_id``.

    If the PCollection's own coder is not already a WindowedValueCoder, it is
    wrapped together with the window coder of the PCollection's windowing
    strategy.
    """
    coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
    # TODO(robertwb): Remove this condition once all runners are consistent.
    if not isinstance(coder, WindowedValueCoder):
      windowing_strategy = self.descriptor.windowing_strategies[
          self.descriptor.pcollections[pcoll_id].windowing_strategy_id]
      return WindowedValueCoder(
          coder, self.get_coder(windowing_strategy.window_coder_id))
    else:
      return coder

  def get_output_coders(self, transform_proto):
    # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder]
    """Maps each output tag of the transform to its windowed coder."""
    return {
        tag: self.get_windowed_coder(pcoll_id)
        for tag,
        pcoll_id in transform_proto.outputs.items()
    }

  def get_only_output_coder(self, transform_proto):
    # type: (beam_runner_api_pb2.PTransform) -> coders.Coder
    """Returns the coder of the transform's output via ``only_element``."""
    return only_element(self.get_output_coders(transform_proto).values())

  def get_input_coders(self, transform_proto):
    # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder]
    """Maps each input tag of the transform to its windowed coder."""
    return {
        tag: self.get_windowed_coder(pcoll_id)
        for tag,
        pcoll_id in transform_proto.inputs.items()
    }

  def get_only_input_coder(self, transform_proto):
    # type: (beam_runner_api_pb2.PTransform) -> coders.Coder
    """Returns the coder of the transform's input via ``only_element``."""
    return only_element(list(self.get_input_coders(transform_proto).values()))

  def get_input_windowing(self, transform_proto):
    # type: (beam_runner_api_pb2.PTransform) -> Windowing
    """Returns the windowing strategy of the transform's input PCollection."""
    pcoll_id = only_element(transform_proto.inputs.values())
    windowing_strategy_id = self.descriptor.pcollections[
        pcoll_id].windowing_strategy_id
    return self.context.windowing_strategies.get_by_id(windowing_strategy_id)

  # TODO(robertwb): Update all operations to take these in the constructor.
  @staticmethod
  def augment_oldstyle_op(
      op,  # type: OperationT
      step_name,  # type: str
      consumers,  # type: Mapping[str, Iterable[operations.Operation]]
      tag_list=None  # type: Optional[List[str]]
  ):
    # type: (...) -> OperationT
    """Sets the step name on ``op`` and wires up its consumers.

    Args:
      op: the operation to augment; it is mutated and returned.
      step_name: value assigned to ``op.step_name``.
      consumers: consuming operations keyed by output tag.
      tag_list: when given (and non-empty), a consumer is registered at the
        position of its tag in this list; otherwise at index 0.

    Returns:
      The same ``op`` instance.
    """
    op.step_name = step_name
    for tag, op_consumers in consumers.items():
      for consumer in op_consumers:
        op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
    return op
@BeamTransformFactory.register_urn(
    common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload)
def create_deprecated_read(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    parameter,  # type: beam_runner_api_pb2.ReadPayload
    consumers  # type: Dict[str, List[operations.Operation]]
):
  # type: (...) -> operations.ReadOperation
  """Creates a ReadOperation for the deprecated READ primitive.

  The bounded source is rebuilt from ``parameter.source``, wrapped in an
  iobase.SourceBundle, and read with the source's default output coder
  wrapped in a WindowedValueCoder.
  """
  bounded_source = iobase.BoundedSource.from_runner_api(
      parameter.source, factory.context)
  source_bundle = iobase.SourceBundle(1.0, bounded_source, None, None)
  windowed_coder = WindowedValueCoder(bounded_source.default_output_coder())
  read_spec = operation_specs.WorkerRead(source_bundle, [windowed_coder])

  read_op = operations.ReadOperation(
      common.NameContext(transform_proto.unique_name, transform_id),
      read_spec,
      factory.counter_factory,
      factory.state_sampler)
  return factory.augment_oldstyle_op(
      read_op, transform_proto.unique_name, consumers)
@BeamTransformFactory.register_urn(
    common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
    beam_runner_api_pb2.ParDoPayload)
def create_split_and_size_restrictions(*args):
  """Creates the operation for the SPLIT_AND_SIZE_RESTRICTIONS component.

  ``args`` are the usual constructor arguments (factory, transform_id,
  transform_proto, parameter, consumers); they are forwarded unchanged to
  ``_create_sdf_operation`` together with the proxy DoFn defined here.
  """
  class SplitAndSizeRestrictions(beam.DoFn):
    def __init__(self, fn, restriction_provider, watermark_estimator_provider):
      # ``fn`` is accepted to match the proxy constructor signature but is
      # not stored.
      self.restriction_provider = restriction_provider
      self.watermark_estimator_provider = watermark_estimator_provider

    def process(self, element_restriction, *args, **kwargs):
      # Input is (element, (restriction, estimator_state)); the incoming
      # estimator state is discarded and recomputed for each part.
      element, restriction_and_state = element_restriction
      restriction, _ = restriction_and_state
      parts_with_sizes = self.restriction_provider.split_and_size(
          element, restriction)
      for sub_restriction, sub_size in parts_with_sizes:
        if sub_size < 0:
          raise ValueError('Expected size >= 0 but received %s.' % sub_size)
        sub_state = self.watermark_estimator_provider.initial_estimator_state(
            element, sub_restriction)
        yield ((element, (sub_restriction, sub_state)), sub_size)

  return _create_sdf_operation(SplitAndSizeRestrictions, *args)
def _create_sdf_operation(
    proxy_dofn,
    factory,
    transform_id,
    transform_proto,
    parameter,
    consumers,
    operation_cls=operations.DoOperation):
  """Creates a ParDo operation that runs ``proxy_dofn`` in place of the user fn.

  The pickled DoFn data in ``parameter.do_fn.payload`` is loaded, its first
  entry (the user DoFn) is replaced by
  ``proxy_dofn(dofn, restriction_provider, watermark_estimator_provider)``,
  and the result is re-pickled and passed to ``_create_pardo_operation``.
  """
  unpickled = pickler.loads(parameter.do_fn.payload)
  user_dofn = unpickled[0]

  # Both providers are read from the user DoFn's signature.
  restriction_provider = (
      common.DoFnSignature(user_dofn).get_restriction_provider())
  watermark_estimator_provider = (
      common.DoFnSignature(user_dofn).get_watermark_estimator_provider())

  proxy = proxy_dofn(
      user_dofn, restriction_provider, watermark_estimator_provider)
  # Every entry after the DoFn itself is carried over untouched.
  repickled_fn = pickler.dumps((proxy, ) + unpickled[1:])

  return _create_pardo_operation(
      factory,
      transform_id,
      transform_proto,
      consumers,
      repickled_fn,
      parameter,
      operation_cls=operation_cls)
def _create_pardo_operation(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    consumers,
    serialized_fn,
    pardo_proto=None,  # type: Optional[beam_runner_api_pb2.ParDoPayload]
    operation_cls=operations.DoOperation
):
  """Builds the operation that executes a (pickled) DoFn.

  Args:
    factory: supplies coders, windowing strategies, the state handler and
      the counter/state-sampler objects.
    transform_id: id of the transform being instantiated.
    transform_proto: the transform's proto (inputs, outputs, unique name).
    consumers: consuming operations keyed by output tag.
    serialized_fn: pickled DoFn data; its last entry is the windowing.
    pardo_proto: the ParDoPayload, when one is available. Side inputs, user
      state/timers and restriction handling are only set up when it is given.
    operation_cls: the operation class to instantiate.

  Returns:
    The constructed operation, with its receivers wired up. ``input_info`` is
    additionally set when ``pardo_proto`` has a restriction coder id.
  """

  if pardo_proto and pardo_proto.side_inputs:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag,
        si in pardo_proto.side_inputs.items()
    ]
    # Side input maps are ordered by the index derived from their tag.
    tagged_side_inputs.sort(
        key=lambda tag_si: sideinputs.get_sideinput_index(tag_si[0]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag]) for tag,
        si in tagged_side_inputs
    ]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    # Take it from the single input that is neither a side input nor a timer
    # family, and re-pickle the DoFn data with that windowing appended.
    if pardo_proto:
      other_input_tags = set.union(
          set(pardo_proto.side_inputs),
          set(pardo_proto.timer_family_specs))  # type: Container[str]
    else:
      other_input_tags = ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in other_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, ))

  if pardo_proto and (pardo_proto.timer_family_specs or pardo_proto.state_specs
                      or pardo_proto.restriction_coder_id):
    # main_input_tag and main_input_coder are only bound inside this branch;
    # they are read again below under the restriction_coder_id check.
    found_input_coder = None
    for tag, pcoll_id in transform_proto.inputs.items():
      if tag in pardo_proto.side_inputs:
        pass
      else:
        # Must be the main input
        assert found_input_coder is None
        main_input_tag = tag
        found_input_coder = factory.get_windowed_coder(pcoll_id)
    assert found_input_coder is not None
    main_input_coder = found_input_coder

    if pardo_proto.timer_family_specs or pardo_proto.state_specs:
      user_state_context = FnApiUserStateContext(
          factory.state_handler,
          transform_id,
          main_input_coder.key_coder(),
          main_input_coder.window_coder
      )  # type: Optional[FnApiUserStateContext]
    else:
      user_state_context = None
  else:
    user_state_context = None

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=output_tags,
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  result = factory.augment_oldstyle_op(
      operation_cls(
          common.NameContext(transform_proto.unique_name, transform_id),
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps,
          user_state_context),
      transform_proto.unique_name,
      consumers,
      output_tags)
  if pardo_proto and pardo_proto.restriction_coder_id:
    result.input_info = operations.OpInputInfo(
        transform_id,
        main_input_tag,
        main_input_coder,
        transform_proto.outputs.keys())
  return result
@BeamTransformFactory.register_urn(
    common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn,
    beam_runner_api_pb2.CombinePayload)
def create_combine_per_key_precombine(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    payload,  # type: beam_runner_api_pb2.CombinePayload
    consumers  # type: Dict[str, List[operations.Operation]]
):
  # type: (...) -> operations.PGBKCVOperation
  """Creates the PGBKCVOperation for the COMBINE_PER_KEY_PRECOMBINE URN.

  The CombineFn is rebuilt from ``payload.combine_fn`` and pickled together
  with an empty argument list and an empty keyword dict. The operation is
  given the transform's only output coder and its input windowing.
  """
  combine_fn = beam.CombineFn.from_runner_api(
      payload.combine_fn, factory.context)
  pickled_combine_fn = pickler.dumps((combine_fn, [], {}))

  name_context = common.NameContext(transform_proto.unique_name, transform_id)
  pgbk_spec = operation_specs.WorkerPartialGroupByKey(
      pickled_combine_fn,
      None, [factory.get_only_output_coder(transform_proto)])
  precombine_op = operations.PGBKCVOperation(
      name_context,
      pgbk_spec,
      factory.counter_factory,
      factory.state_sampler,
      factory.get_input_windowing(transform_proto))

  return factory.augment_oldstyle_op(
      precombine_op, transform_proto.unique_name, consumers)
create_combine_per_key_convert_to_accumulators( 1849 factory, # type: BeamTransformFactory 1850 transform_id, # type: str 1851 transform_proto, # type: beam_runner_api_pb2.PTransform 1852 payload, # type: beam_runner_api_pb2.CombinePayload 1853 consumers # type: Dict[str, List[operations.Operation]] 1854 ): 1855 return _create_combine_phase_operation( 1856 factory, transform_id, transform_proto, payload, consumers, 'convert') 1857 1858 1859 @BeamTransformFactory.register_urn( 1860 common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, 1861 beam_runner_api_pb2.CombinePayload) 1862 def create_combine_grouped_values( 1863 factory, # type: BeamTransformFactory 1864 transform_id, # type: str 1865 transform_proto, # type: beam_runner_api_pb2.PTransform 1866 payload, # type: beam_runner_api_pb2.CombinePayload 1867 consumers # type: Dict[str, List[operations.Operation]] 1868 ): 1869 return _create_combine_phase_operation( 1870 factory, transform_id, transform_proto, payload, consumers, 'all') 1871 1872 1873 def _create_combine_phase_operation( 1874 factory, transform_id, transform_proto, payload, consumers, phase): 1875 # type: (...) 
-> operations.CombineOperation 1876 serialized_combine_fn = pickler.dumps(( 1877 beam.CombineFn.from_runner_api(payload.combine_fn, 1878 factory.context), [], {})) 1879 return factory.augment_oldstyle_op( 1880 operations.CombineOperation( 1881 common.NameContext(transform_proto.unique_name, transform_id), 1882 operation_specs.WorkerCombineFn( 1883 serialized_combine_fn, 1884 phase, 1885 None, [factory.get_only_output_coder(transform_proto)]), 1886 factory.counter_factory, 1887 factory.state_sampler), 1888 transform_proto.unique_name, 1889 consumers) 1890 1891 1892 @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) 1893 def create_flatten( 1894 factory, # type: BeamTransformFactory 1895 transform_id, # type: str 1896 transform_proto, # type: beam_runner_api_pb2.PTransform 1897 payload, 1898 consumers # type: Dict[str, List[operations.Operation]] 1899 ): 1900 # type: (...) -> operations.FlattenOperation 1901 return factory.augment_oldstyle_op( 1902 operations.FlattenOperation( 1903 common.NameContext(transform_proto.unique_name, transform_id), 1904 operation_specs.WorkerFlatten( 1905 None, [factory.get_only_output_coder(transform_proto)]), 1906 factory.counter_factory, 1907 factory.state_sampler), 1908 transform_proto.unique_name, 1909 consumers) 1910 1911 1912 @BeamTransformFactory.register_urn( 1913 common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) 1914 def create_map_windows( 1915 factory, # type: BeamTransformFactory 1916 transform_id, # type: str 1917 transform_proto, # type: beam_runner_api_pb2.PTransform 1918 mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec 1919 consumers # type: Dict[str, List[operations.Operation]] 1920 ): 1921 assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN 1922 window_mapping_fn = pickler.loads(mapping_fn_spec.payload) 1923 1924 class MapWindows(beam.DoFn): 1925 def process(self, element): 1926 key, window = element 1927 return [(key, window_mapping_fn(window))] 
1928 1929 return _create_simple_pardo_operation( 1930 factory, transform_id, transform_proto, consumers, MapWindows()) 1931 1932 1933 @BeamTransformFactory.register_urn( 1934 common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) 1935 def create_merge_windows( 1936 factory, # type: BeamTransformFactory 1937 transform_id, # type: str 1938 transform_proto, # type: beam_runner_api_pb2.PTransform 1939 mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec 1940 consumers # type: Dict[str, List[operations.Operation]] 1941 ): 1942 assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN 1943 window_fn = pickler.loads(mapping_fn_spec.payload) 1944 1945 class MergeWindows(beam.DoFn): 1946 def process(self, element): 1947 nonce, windows = element 1948 1949 original_windows = set(windows) # type: Set[window.BoundedWindow] 1950 merged_windows = collections.defaultdict( 1951 set 1952 ) # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821 1953 1954 class RecordingMergeContext(window.WindowFn.MergeContext): 1955 def merge( 1956 self, 1957 to_be_merged, # type: Iterable[window.BoundedWindow] 1958 merge_result, # type: window.BoundedWindow 1959 ): 1960 originals = merged_windows[merge_result] 1961 for window in to_be_merged: 1962 if window in original_windows: 1963 originals.add(window) 1964 original_windows.remove(window) 1965 else: 1966 originals.update(merged_windows.pop(window)) 1967 1968 window_fn.merge(RecordingMergeContext(windows)) 1969 yield nonce, (original_windows, merged_windows.items()) 1970 1971 return _create_simple_pardo_operation( 1972 factory, transform_id, transform_proto, consumers, MergeWindows()) 1973 1974 1975 @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) 1976 def create_to_string_fn( 1977 factory, # type: BeamTransformFactory 1978 transform_id, # type: str 1979 transform_proto, # type: beam_runner_api_pb2.PTransform 1980 mapping_fn_spec, # type: 
beam_runner_api_pb2.FunctionSpec 1981 consumers # type: Dict[str, List[operations.Operation]] 1982 ): 1983 class ToString(beam.DoFn): 1984 def process(self, element): 1985 key, value = element 1986 return [(key, str(value))] 1987 1988 return _create_simple_pardo_operation( 1989 factory, transform_id, transform_proto, consumers, ToString()) 1990 1991 1992 class DataSamplingOperation(operations.Operation): 1993 """Operation that samples incoming elements.""" 1994 1995 def __init__( 1996 self, 1997 name_context, # type: common.NameContext 1998 counter_factory, # type: counters.CounterFactory 1999 state_sampler, # type: statesampler.StateSampler 2000 pcoll_id, # type: str 2001 sample_coder, # type: coders.Coder 2002 data_sampler, # type: data_sampler.DataSampler 2003 ): 2004 # type: (...) -> None 2005 super().__init__(name_context, None, counter_factory, state_sampler) 2006 self._coder = sample_coder # type: coders.Coder 2007 self._pcoll_id = pcoll_id # type: str 2008 2009 self._sampler: OutputSampler = data_sampler.sample_output( 2010 self._pcoll_id, sample_coder) 2011 2012 def process(self, windowed_value): 2013 # type: (windowed_value.WindowedValue) -> None 2014 self._sampler.sample(windowed_value) 2015 2016 2017 @BeamTransformFactory.register_urn(SYNTHETIC_DATA_SAMPLING_URN, (bytes)) 2018 def create_data_sampling_op( 2019 factory, # type: BeamTransformFactory 2020 transform_id, # type: str 2021 transform_proto, # type: beam_runner_api_pb2.PTransform 2022 pcoll_and_coder_id, # type: bytes 2023 consumers, # type: Dict[str, List[operations.Operation]] 2024 ): 2025 # Creating this operation should only occur when data sampling is enabled. 
2026 data_sampler = factory.data_sampler 2027 assert data_sampler is not None 2028 2029 coder = coders.FastPrimitivesCoder() 2030 pcoll_id, coder_id = coder.decode(pcoll_and_coder_id) 2031 return DataSamplingOperation( 2032 common.NameContext(transform_proto.unique_name, transform_id), 2033 factory.counter_factory, 2034 factory.state_sampler, 2035 pcoll_id, 2036 factory.get_coder(coder_id), 2037 data_sampler, 2038 )