github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_plane.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Implementation of ``DataChannel``s to communicate across the data plane."""

# pytype: skip-file
# mypy: disallow-untyped-defs

import abc
import collections
import json
import logging
import queue
import threading
import time
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import DefaultDict
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

import grpc

from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor

if TYPE_CHECKING:
  import apache_beam.coders.slow_stream

  OutputStream = apache_beam.coders.slow_stream.OutputStream
  DataOrTimers = Union[beam_fn_api_pb2.Elements.Data,
                       beam_fn_api_pb2.Elements.Timers]
else:
  OutputStream = type(coder_impl.create_OutputStream())

_LOGGER = logging.getLogger(__name__)

_DEFAULT_SIZE_FLUSH_THRESHOLD = 10 << 20  # 10MB
_DEFAULT_TIME_FLUSH_THRESHOLD_MS = 0  # disable time-based flush by default

# Keep a set of completed instructions to discard late received data. The set
# can have up to _MAX_CLEANED_INSTRUCTIONS items. See _GrpcDataChannel.
_MAX_CLEANED_INSTRUCTIONS = 10000

# Retry on transient UNAVAILABLE grpc errors from data channels.
_GRPC_SERVICE_CONFIG = json.dumps({
    "methodConfig": [{
        "name": [{
            "service": "org.apache.beam.model.fn_execution.v1.BeamFnData"
        }],
        "retryPolicy": {
            "maxAttempts": 5,
            "initialBackoff": "0.1s",
            "maxBackoff": "5s",
            "backoffMultiplier": 2,
            "retryableStatusCodes": ["UNAVAILABLE"],
        },
    }]
})
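
# A minimal sketch of how this retry config is applied as an ordinary gRPC
# channel option (the address below is hypothetical; the factory at the
# bottom of this file passes it the same way):
#
#   options = [("grpc.service_config", _GRPC_SERVICE_CONFIG)]
#   channel = grpc.insecure_channel("localhost:12345", options=options)
#
# With maxAttempts=5, a call that keeps failing with UNAVAILABLE is retried
# up to four times, backing off roughly 0.1s, 0.2s, 0.4s and 0.8s (capped
# at 5s).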


class ClosableOutputStream(OutputStream):
  """An OutputStream, for use with CoderImpls, that has a close() method."""
  def __init__(
      self,
      close_callback=None  # type: Optional[Callable[[bytes], None]]
  ):
    # type: (...) -> None
    super().__init__()
    self._close_callback = close_callback

  def close(self):
    # type: () -> None
    if self._close_callback:
      self._close_callback(self.get())

  def maybe_flush(self):
    # type: () -> None
    pass

  def flush(self):
    # type: () -> None
    pass

  @staticmethod
  def create(
      close_callback,  # type: Optional[Callable[[bytes], None]]
      flush_callback,  # type: Optional[Callable[[bytes], None]]
      data_buffer_time_limit_ms  # type: int
  ):
    # type: (...) -> ClosableOutputStream
    if data_buffer_time_limit_ms > 0:
      return TimeBasedBufferingClosableOutputStream(
          close_callback,
          flush_callback=flush_callback,
          time_flush_threshold_ms=data_buffer_time_limit_ms)
    else:
      return SizeBasedBufferingClosableOutputStream(
          close_callback, flush_callback=flush_callback)


class SizeBasedBufferingClosableOutputStream(ClosableOutputStream):
  """A size-based buffering OutputStream."""

  def __init__(
      self,
      close_callback=None,  # type: Optional[Callable[[bytes], None]]
      flush_callback=None,  # type: Optional[Callable[[bytes], None]]
      size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD  # type: int
  ):
    super().__init__(close_callback)
    self._flush_callback = flush_callback
    self._size_flush_threshold = size_flush_threshold

  # This must be called explicitly to avoid flushing partial elements.
  def maybe_flush(self):
    # type: () -> None
    if self.size() > self._size_flush_threshold:
      self.flush()

  def flush(self):
    # type: () -> None
    if self._flush_callback:
      self._flush_callback(self.get())
      self._clear()
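
# A minimal usage sketch for the size-based stream, assuming a flush callback
# that simply collects the buffered bytes (the names below are illustrative):
#
#   chunks = []
#   out = SizeBasedBufferingClosableOutputStream(
#       close_callback=chunks.append,
#       flush_callback=chunks.append,
#       size_flush_threshold=1024)
#   out.write(b'x' * 2048)  # buffered only; nothing is sent yet
#   out.maybe_flush()       # size() exceeds the threshold, so this flushes
#   out.close()             # hands remaining buffered bytes to close_callback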


class TimeBasedBufferingClosableOutputStream(
    SizeBasedBufferingClosableOutputStream):
  """A buffering OutputStream with both time-based and size-based flushing."""

  _periodic_flusher = None  # type: Optional[PeriodicThread]

  def __init__(
      self,
      close_callback=None,  # type: Optional[Callable[[bytes], None]]
      flush_callback=None,  # type: Optional[Callable[[bytes], None]]
      size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD,  # type: int
      time_flush_threshold_ms=_DEFAULT_TIME_FLUSH_THRESHOLD_MS  # type: int
  ):
    # type: (...) -> None
    super().__init__(close_callback, flush_callback, size_flush_threshold)
    assert time_flush_threshold_ms > 0
    self._time_flush_threshold_ms = time_flush_threshold_ms
    self._flush_lock = threading.Lock()
    self._schedule_lock = threading.Lock()
    self._closed = False
    self._schedule_periodic_flush()

  def flush(self):
    # type: () -> None
    with self._flush_lock:
      super().flush()

  def close(self):
    # type: () -> None
    with self._schedule_lock:
      self._closed = True
      if self._periodic_flusher:
        self._periodic_flusher.cancel()
        self._periodic_flusher = None
    super().close()

  def _schedule_periodic_flush(self):
    # type: () -> None
    def _flush():
      # type: () -> None
      with self._schedule_lock:
        if not self._closed:
          self.flush()

    self._periodic_flusher = PeriodicThread(
        self._time_flush_threshold_ms / 1000.0, _flush)
    self._periodic_flusher.daemon = True
    self._periodic_flusher.start()


class PeriodicThread(threading.Thread):
  """Calls a function periodically at the specified interval in seconds."""

  def __init__(
      self,
      interval,  # type: float
      function,  # type: Callable
      args=None,  # type: Optional[Iterable]
      kwargs=None  # type: Optional[Mapping[str, Any]]
  ):
    # type: (...) -> None
    threading.Thread.__init__(self)
    self._interval = interval
    self._function = function
    self._args = args if args is not None else []
    self._kwargs = kwargs if kwargs is not None else {}
    self._finished = threading.Event()

  def run(self):
    # type: () -> None
    next_call = time.time() + self._interval
    while not self._finished.wait(next_call - time.time()):
      next_call = next_call + self._interval
      self._function(*self._args, **self._kwargs)

  def cancel(self):
    # type: () -> None

    """Stop the thread if it hasn't finished yet."""
    self._finished.set()
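
# A small sketch of PeriodicThread in isolation (interval and sleep values
# are illustrative):
#
#   ticks = []
#   t = PeriodicThread(0.1, lambda: ticks.append(time.time()))
#   t.daemon = True
#   t.start()
#   time.sleep(1)  # roughly ten ticks accumulate while we sleep
#   t.cancel()     # no further invocations after this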


class DataChannel(metaclass=abc.ABCMeta):
  """Represents a channel for reading and writing data over the data plane.

  Read data and timers from this channel with the input_elements method::

    for elements_data in data_channel.input_elements(
        instruction_id, transform_ids, timers):
      [process elements_data]

  Write data to this channel using the output_stream method::

    out1 = data_channel.output_stream(instruction_id, transform_id)
    out1.write(...)
    out1.close()

  Write timers to this channel using the output_timer_stream method::

    out1 = data_channel.output_timer_stream(instruction_id,
                                            transform_id,
                                            timer_family_id)
    out1.write(...)
    out1.close()

  When all data and timers for all instructions have been written, close the
  channel::

    data_channel.close()
  """
  @abc.abstractmethod
  def input_elements(
      self,
      instruction_id,  # type: str
      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]

    """Returns an iterable of all Element.Data and Element.Timers bundles for
    instruction_id.

    This iterable terminates only once the full set of data has been received
    for each of the expected transforms. It may block waiting for more data.

    Args:
      instruction_id: which instruction the results must belong to
      expected_inputs: which transforms to wait on for completion
      abort_callback: a callback, invoked while blocked waiting for data, that
        returns whether to abort before consuming all the data
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def output_stream(
      self,
      instruction_id,  # type: str
      transform_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream

    """Returns an output stream writing elements to transform_id.

    Args:
      instruction_id: which instruction this stream belongs to
      transform_id: the transform_id of the returned stream
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream

    """Returns an output stream writing timers to transform_id.

    Args:
      instruction_id: which instruction this stream belongs to
      transform_id: the transform_id of the returned stream
      timer_family_id: the timer family of the written timers
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    # type: () -> None

    """Closes this channel, indicating that all data has been written.

    Data can continue to be read.

    If this channel is shared by many instructions, this should only be called
    on worker shutdown.
    """
    raise NotImplementedError(type(self))
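
# A sketch of a typical consumer of input_elements (ids are hypothetical).
# Data inputs are identified by a transform_id, timer inputs by a
# (transform_id, timer_family_id) pair, matching expected_inputs above:
#
#   for element in data_channel.input_elements(
#       'instruction-1', ['transform-a', ('transform-b', 'timer-family-b')]):
#     if isinstance(element, beam_fn_api_pb2.Elements.Timers):
#       ...  # decode element.timers bytes
#     else:
#       ...  # decode element.data bytes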


class InMemoryDataChannel(DataChannel):
  """An in-memory implementation of a DataChannel.

  This channel is two-sided: what is written to one side is read by the other.
  The inverse() method returns the other side of an instance.
  """
  def __init__(self, inverse=None, data_buffer_time_limit_ms=0):
    # type: (Optional[InMemoryDataChannel], int) -> None
    self._inputs = []  # type: List[DataOrTimers]
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    self._inverse = inverse or InMemoryDataChannel(
        self, data_buffer_time_limit_ms=data_buffer_time_limit_ms)

  def inverse(self):
    # type: () -> InMemoryDataChannel
    return self._inverse

  def input_elements(
      self,
      instruction_id,  # type: str
      unused_expected_inputs,  # type: Any
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]
    other_inputs = []
    for element in self._inputs:
      if element.instruction_id == instruction_id:
        if isinstance(element, beam_fn_api_pb2.Elements.Timers):
          if not element.is_last:
            yield element
        if isinstance(element, beam_fn_api_pb2.Elements.Data):
          if element.data or element.is_last:
            yield element
      else:
        other_inputs.append(element)
    self._inputs = other_inputs

  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream
    def add_to_inverse_output(timer):
      # type: (bytes) -> None
      if timer:
        self._inverse._inputs.append(
            beam_fn_api_pb2.Elements.Timers(
                instruction_id=instruction_id,
                transform_id=transform_id,
                timer_family_id=timer_family_id,
                timers=timer,
                is_last=False))

    def close_stream(timer):
      # type: (bytes) -> None
      add_to_inverse_output(timer)
      self._inverse._inputs.append(
          beam_fn_api_pb2.Elements.Timers(
              instruction_id=instruction_id,
              transform_id=transform_id,
              timer_family_id='',
              is_last=True))

    return ClosableOutputStream.create(
        add_to_inverse_output, close_stream, self._data_buffer_time_limit_ms)

  def output_stream(self, instruction_id, transform_id):
    # type: (str, str) -> ClosableOutputStream
    def add_to_inverse_output(data):
      # type: (bytes) -> None
      self._inverse._inputs.append(  # pylint: disable=protected-access
          beam_fn_api_pb2.Elements.Data(
              instruction_id=instruction_id,
              transform_id=transform_id,
              data=data))

    return ClosableOutputStream.create(
        add_to_inverse_output,
        add_to_inverse_output,
        self._data_buffer_time_limit_ms)

  def close(self):
    # type: () -> None
    pass
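
# A round-trip sketch: bytes written to one side of an InMemoryDataChannel
# are read back from its inverse (ids are hypothetical; expected inputs are
# unused by this implementation):
#
#   channel = InMemoryDataChannel()
#   out = channel.output_stream('instruction-1', 'transform-a')
#   out.write(b'payload')
#   out.close()
#   for element in channel.inverse().input_elements('instruction-1', None):
#     element.data  # -> b'payload'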


class _GrpcDataChannel(DataChannel):
  """Base class for implementing a BeamFnData-based DataChannel."""

  _WRITES_FINISHED = object()

  def __init__(self, data_buffer_time_limit_ms=0):
    # type: (int) -> None
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    self._to_send = queue.Queue()  # type: queue.Queue[DataOrTimers]
    self._received = collections.defaultdict(
        lambda: queue.Queue(maxsize=5)
    )  # type: DefaultDict[str, queue.Queue[DataOrTimers]]

    # Keep a cache of completed instructions. Data for completed instructions
    # must be discarded. See input_elements() and _clean_receiving_queue().
    # An OrderedDict is used as a FIFO set, with the value always being `True`.
    self._cleaned_instruction_ids = collections.OrderedDict(
    )  # type: collections.OrderedDict[str, bool]

    self._receive_lock = threading.Lock()
    self._reads_finished = threading.Event()
    self._closed = False
    self._exception = None  # type: Optional[Exception]

  def close(self):
    # type: () -> None
    self._to_send.put(self._WRITES_FINISHED)  # type: ignore[arg-type]
    self._closed = True

  def wait(self, timeout=None):
    # type: (Optional[int]) -> None
    self._reads_finished.wait(timeout)

  def _receiving_queue(self, instruction_id):
    # type: (str) -> Optional[queue.Queue[DataOrTimers]]

    """
    Gets or creates the queue for an instruction_id, or returns None if the
    instruction_id has already been cleaned up. This is best-effort, as only
    a limited number of cleaned-up instructions is tracked.
    """
    with self._receive_lock:
      if instruction_id in self._cleaned_instruction_ids:
        return None
      return self._received[instruction_id]

  def _clean_receiving_queue(self, instruction_id):
    # type: (str) -> None

    """
    Removes the queue and adds the instruction_id to the cleaned-up list. The
    instruction_id cannot be reused for a new queue.
    """
    with self._receive_lock:
      self._received.pop(instruction_id)
      self._cleaned_instruction_ids[instruction_id] = True
      while len(self._cleaned_instruction_ids) > _MAX_CLEANED_INSTRUCTIONS:
        self._cleaned_instruction_ids.popitem(last=False)

  def input_elements(
      self,
      instruction_id,  # type: str
      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]

    """
    Generator to retrieve elements for an instruction_id. input_elements
    should be called only once per instruction_id.

    Args:
      instruction_id(str): instruction_id for which data is read
      expected_inputs(collection): expected inputs, including both data and
        timers.
    """
    received = self._receiving_queue(instruction_id)
    if received is None:
      raise RuntimeError('Instruction cleaned up already %s' % instruction_id)
    done_inputs = set()  # type: Set[Union[str, Tuple[str, str]]]
    abort_callback = abort_callback or (lambda: False)
    try:
      while len(done_inputs) < len(expected_inputs):
        try:
          element = received.get(timeout=1)
        except queue.Empty:
          if self._closed:
            raise RuntimeError('Channel closed prematurely.')
          if abort_callback():
            return
          if self._exception:
            raise self._exception from None
        else:
          if isinstance(element, beam_fn_api_pb2.Elements.Timers):
            if element.is_last:
              done_inputs.add((element.transform_id, element.timer_family_id))
            else:
              yield element
          elif isinstance(element, beam_fn_api_pb2.Elements.Data):
            if element.is_last:
              done_inputs.add(element.transform_id)
            else:
              assert element.transform_id not in done_inputs
              yield element
          else:
            raise ValueError('Unexpected input element type %s' % type(element))
    finally:
      # Instruction ids are not reusable, so clean up the queue once we are
      # done with an instruction_id.
      self._clean_receiving_queue(instruction_id)
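
  # The bounded FIFO set used for _cleaned_instruction_ids, shown in
  # isolation (a capacity of 2 stands in for _MAX_CLEANED_INSTRUCTIONS):
  #
  #   seen = collections.OrderedDict()
  #   for instruction_id in ('a', 'b', 'c'):
  #     seen[instruction_id] = True
  #     while len(seen) > 2:
  #       seen.popitem(last=False)  # evict the oldest entry first
  #   # 'c' in seen and 'b' in seen are True; 'a' has been evicted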

  def output_stream(self, instruction_id, transform_id):
    # type: (str, str) -> ClosableOutputStream
    def add_to_send_queue(data):
      # type: (bytes) -> None
      if data:
        self._to_send.put(
            beam_fn_api_pb2.Elements.Data(
                instruction_id=instruction_id,
                transform_id=transform_id,
                data=data))

    def close_callback(data):
      # type: (bytes) -> None
      add_to_send_queue(data)
      # End of stream marker.
      self._to_send.put(
          beam_fn_api_pb2.Elements.Data(
              instruction_id=instruction_id,
              transform_id=transform_id,
              is_last=True))

    return ClosableOutputStream.create(
        close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)

  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream
    def add_to_send_queue(timer):
      # type: (bytes) -> None
      if timer:
        self._to_send.put(
            beam_fn_api_pb2.Elements.Timers(
                instruction_id=instruction_id,
                transform_id=transform_id,
                timer_family_id=timer_family_id,
                timers=timer,
                is_last=False))

    def close_callback(timer):
      # type: (bytes) -> None
      add_to_send_queue(timer)
      self._to_send.put(
          beam_fn_api_pb2.Elements.Timers(
              instruction_id=instruction_id,
              transform_id=transform_id,
              timer_family_id=timer_family_id,
              is_last=True))

    return ClosableOutputStream.create(
        close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)

  def _write_outputs(self):
    # type: () -> Iterator[beam_fn_api_pb2.Elements]
    stream_done = False
    while not stream_done:
      streams = [self._to_send.get()]
      try:
        # Coalesce up to 100 other items.
        for _ in range(100):
          streams.append(self._to_send.get_nowait())
      except queue.Empty:
        pass
      if streams[-1] is self._WRITES_FINISHED:
        stream_done = True
        streams.pop()
      if streams:
        data_stream = []
        timer_stream = []
        for stream in streams:
          if isinstance(stream, beam_fn_api_pb2.Elements.Timers):
            timer_stream.append(stream)
          elif isinstance(stream, beam_fn_api_pb2.Elements.Data):
            data_stream.append(stream)
          else:
            raise ValueError('Unexpected output element type %s' % type(stream))
        yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream)
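
  # Each message yielded by _write_outputs coalesces up to 101 queued items
  # (one blocking get plus up to 100 non-blocking gets) into a single
  # Elements proto, split by type (field values below are illustrative):
  #
  #   beam_fn_api_pb2.Elements(
  #       data=[
  #           beam_fn_api_pb2.Elements.Data(
  #               instruction_id='instruction-1',
  #               transform_id='transform-a',
  #               data=b'...')],
  #       timers=[
  #           beam_fn_api_pb2.Elements.Timers(
  #               instruction_id='instruction-1',
  #               transform_id='transform-b',
  #               timer_family_id='timer-family-b',
  #               timers=b'...')])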

  def _read_inputs(self, elements_iterator):
    # type: (Iterable[beam_fn_api_pb2.Elements]) -> None

    next_discard_log_time = 0  # type: float

    def _put_queue(instruction_id, element):
      # type: (str, Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> None

      """
      Puts an element on the queue of the instruction_id, or discards it if
      the instruction_id has already been cleaned up.
      """
      nonlocal next_discard_log_time
      start_time = time.time()
      next_waiting_log_time = start_time + 300
      while True:
        input_queue = self._receiving_queue(instruction_id)
        if input_queue is None:
          current_time = time.time()
          if next_discard_log_time <= current_time:
            # Log every 10 seconds across all _put_queue calls.
            _LOGGER.info(
                'Discard inputs for cleaned up instruction: %s',
                instruction_id)
            next_discard_log_time = current_time + 10
          return
        try:
          input_queue.put(element, timeout=1)
          return
        except queue.Full:
          current_time = time.time()
          if next_waiting_log_time <= current_time:
            # Log every 5 mins in each _put_queue call.
            _LOGGER.info(
                'Waiting on input queue of instruction: %s for %.2f seconds',
                instruction_id,
                current_time - start_time)
            next_waiting_log_time = current_time + 300

    try:
      for elements in elements_iterator:
        for timer in elements.timers:
          _put_queue(timer.instruction_id, timer)
        for data in elements.data:
          _put_queue(data.instruction_id, data)
    except Exception as e:
      if not self._closed:
        _LOGGER.exception('Failed to read inputs in the data plane.')
        self._exception = e
        raise
    finally:
      self._closed = True
      self._reads_finished.set()

  def set_inputs(self, elements_iterator):
    # type: (Iterable[beam_fn_api_pb2.Elements]) -> None
    reader = threading.Thread(
        target=lambda: self._read_inputs(elements_iterator),
        name='read_grpc_client_inputs')
    reader.daemon = True
    reader.start()


class GrpcClientDataChannel(_GrpcDataChannel):
  """A DataChannel wrapping the client side of a BeamFnData connection."""

  def __init__(
      self,
      data_stub,  # type: beam_fn_api_pb2_grpc.BeamFnDataStub
      data_buffer_time_limit_ms=0  # type: int
  ):
    # type: (...) -> None
    super().__init__(data_buffer_time_limit_ms)
    self.set_inputs(data_stub.Data(self._write_outputs()))


class BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer):
  """Implementation of BeamFnDataServicer for any number of clients."""
  def __init__(
      self,
      data_buffer_time_limit_ms=0  # type: int
  ):
    self._lock = threading.Lock()
    self._connections_by_worker_id = collections.defaultdict(
        lambda: _GrpcDataChannel(data_buffer_time_limit_ms)
    )  # type: DefaultDict[str, _GrpcDataChannel]

  def get_conn_by_worker_id(self, worker_id):
    # type: (str) -> _GrpcDataChannel
    with self._lock:
      return self._connections_by_worker_id[worker_id]

  def Data(
      self,
      elements_iterator,  # type: Iterable[beam_fn_api_pb2.Elements]
      context  # type: Any
  ):
    # type: (...) -> Iterator[beam_fn_api_pb2.Elements]
    worker_id = dict(context.invocation_metadata())['worker_id']
    data_conn = self.get_conn_by_worker_id(worker_id)
    data_conn.set_inputs(elements_iterator)
    for elements in data_conn._write_outputs():
      yield elements
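
# A sketch of hosting BeamFnDataServicer on a gRPC server (assumes
# `from concurrent import futures`; the worker count and port choice are
# illustrative):
#
#   server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
#   beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
#       BeamFnDataServicer(), server)
#   port = server.add_insecure_port('[::]:0')  # 0 picks a free port
#   server.start()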


class DataChannelFactory(metaclass=abc.ABCMeta):
  """An abstract factory for creating ``DataChannel``s."""
  @abc.abstractmethod
  def create_data_channel(self, remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel

    """Returns a ``DataChannel`` from the given RemoteGrpcPort."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def create_data_channel_from_url(self, url):
    # type: (str) -> Optional[GrpcClientDataChannel]

    """Returns a ``DataChannel`` from the given url."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    # type: () -> None

    """Close all channels that this factory owns."""
    raise NotImplementedError(type(self))


class GrpcClientDataChannelFactory(DataChannelFactory):
  """A factory for ``GrpcClientDataChannel``.

  Caches the created channels by ``data descriptor url``.
  """

  def __init__(
      self,
      credentials=None,  # type: Any
      worker_id=None,  # type: Optional[str]
      data_buffer_time_limit_ms=0  # type: int
  ):
    # type: (...) -> None
    self._data_channel_cache = {}  # type: Dict[str, GrpcClientDataChannel]
    self._lock = threading.Lock()
    self._credentials = None
    self._worker_id = worker_id
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    if credentials is not None:
      _LOGGER.info('Using secure channel creds.')
      self._credentials = credentials

  def create_data_channel_from_url(self, url):
    # type: (str) -> Optional[GrpcClientDataChannel]
    if not url:
      return None
    if url not in self._data_channel_cache:
      with self._lock:
        if url not in self._data_channel_cache:
          _LOGGER.info('Creating client data channel for %s', url)
          # Options to have no limits (-1) on the size of the messages
          # received or sent over the data plane. The actual buffer size
          # is controlled in a layer above.
          channel_options = [("grpc.max_receive_message_length", -1),
                             ("grpc.max_send_message_length", -1),
                             ("grpc.service_config", _GRPC_SERVICE_CONFIG)]
          grpc_channel = None
          if self._credentials is None:
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=channel_options)
          else:
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=channel_options)
          # Add workerId to the grpc channel.
          grpc_channel = grpc.intercept_channel(
              grpc_channel, WorkerIdInterceptor(self._worker_id))
          self._data_channel_cache[url] = GrpcClientDataChannel(
              beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel),
              self._data_buffer_time_limit_ms)

    return self._data_channel_cache[url]

  def create_data_channel(self, remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel
    url = remote_grpc_port.api_service_descriptor.url
    # TODO(https://github.com/apache/beam/issues/19737): this can return None
    # if url is falsey, but this seems incorrect, as code that calls this
    # method seems to always expect non-Optional values.
    return self.create_data_channel_from_url(url)  # type: ignore[return-value]

  def close(self):
    # type: () -> None
    _LOGGER.info('Closing all cached grpc data channels.')
    for _, channel in self._data_channel_cache.items():
      channel.close()
    self._data_channel_cache.clear()
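
# A usage sketch for the factory (the descriptor URL is hypothetical):
#
#   factory = GrpcClientDataChannelFactory(worker_id='worker-1')
#   port = beam_fn_api_pb2.RemoteGrpcPort()
#   port.api_service_descriptor.url = 'localhost:12345'
#   channel = factory.create_data_channel(port)  # cached by URL
#   ...
#   factory.close()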


class InMemoryDataChannelFactory(DataChannelFactory):
  """A singleton factory for ``InMemoryDataChannel``."""
  def __init__(self, in_memory_data_channel):
    # type: (GrpcClientDataChannel) -> None
    self._in_memory_data_channel = in_memory_data_channel

  def create_data_channel(self, unused_remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel
    return self._in_memory_data_channel

  def create_data_channel_from_url(self, url):
    # type: (Any) -> GrpcClientDataChannel
    return self._in_memory_data_channel

  def close(self):
    # type: () -> None
    pass