github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/test_stream_impl.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """The TestStream implementation for the DirectRunner
    19  
    20  The DirectRunner implements TestStream as the _TestStream class which is used
    21  to store the events in memory, the _WatermarkController which is used to set the
    22  watermark and emit events, and the multiplexer which sends events to the correct
    23  tagged PCollection.
    24  """
    25  
    26  # pytype: skip-file
    27  
    28  import itertools
    29  import logging
    30  from queue import Empty as EmptyException
    31  from queue import Queue
    32  from threading import Thread
    33  
    34  import grpc
    35  
    36  from apache_beam import ParDo
    37  from apache_beam import coders
    38  from apache_beam import pvalue
    39  from apache_beam.portability.api import beam_runner_api_pb2
    40  from apache_beam.portability.api import beam_runner_api_pb2_grpc
    41  from apache_beam.testing.test_stream import ElementEvent
    42  from apache_beam.testing.test_stream import ProcessingTimeEvent
    43  from apache_beam.testing.test_stream import WatermarkEvent
    44  from apache_beam.transforms import PTransform
    45  from apache_beam.transforms import core
    46  from apache_beam.transforms import window
    47  from apache_beam.transforms.window import TimestampedValue
    48  from apache_beam.utils import timestamp
    49  from apache_beam.utils.timestamp import Duration
    50  from apache_beam.utils.timestamp import Timestamp
    51  
    52  _LOGGER = logging.getLogger(__name__)
    53  
    54  
    55  class _EndOfStream:
    56    pass
    57  
    58  
    59  class _WatermarkController(PTransform):
    60    """A runner-overridable PTransform Primitive to control the watermark.
    61  
    62    Expected implementation behavior:
    63     - If the instance recieves a WatermarkEvent, it sets its output watermark to
    64       the specified value then drops the event.
    65     - If the instance receives an ElementEvent, it emits all specified elements
    66       to the Global Window with the event time set to the element's timestamp.
    67    """
    68    def __init__(self, output_tag):
    69      self.output_tag = output_tag
    70  
    71    def get_windowing(self, _):
    72      return core.Windowing(window.GlobalWindows())
    73  
    74    def expand(self, pcoll):
    75      ret = pvalue.PCollection.from_(pcoll)
    76      ret.tag = self.output_tag
    77      return ret
    78  
    79  
    80  class _ExpandableTestStream(PTransform):
    81    def __init__(self, test_stream):
    82      self.test_stream = test_stream
    83  
    84    def expand(self, pbegin):
    85      """Expands the TestStream into the DirectRunner implementation.
    86  
    87      Takes the TestStream transform and creates a _TestStream -> multiplexer ->
    88      _WatermarkController.
    89      """
    90  
    91      assert isinstance(pbegin, pvalue.PBegin)
    92  
    93      # If there is only one tag there is no need to add the multiplexer.
    94      if len(self.test_stream.output_tags) == 1:
    95        return (
    96            pbegin
    97            | _TestStream(
    98                self.test_stream.output_tags,
    99                events=self.test_stream._events,
   100                coder=self.test_stream.coder,
   101                endpoint=self.test_stream._endpoint)
   102            | _WatermarkController(list(self.test_stream.output_tags)[0]))
   103  
   104      # Multiplex to the correct PCollection based upon the event tag.
   105      def mux(event):
   106        if event.tag:
   107          yield pvalue.TaggedOutput(event.tag, event)
   108        else:
   109          yield event
   110  
   111      mux_output = (
   112          pbegin
   113          | _TestStream(
   114              self.test_stream.output_tags,
   115              events=self.test_stream._events,
   116              coder=self.test_stream.coder,
   117              endpoint=self.test_stream._endpoint)
   118          | 'TestStream Multiplexer' >> ParDo(mux).with_outputs())
   119  
   120      # Apply a way to control the watermark per output. It is necessary to
   121      # have an individual _WatermarkController per PCollection because the
   122      # calculation of the input watermark of a transform is based on the event
   123      # timestamp of the elements flowing through it. Meaning, it is impossible
   124      # to control the output watermarks of the individual PCollections solely
   125      # on the event timestamps.
   126      outputs = {}
   127      for tag in self.test_stream.output_tags:
   128        label = '_WatermarkController[{}]'.format(tag)
   129        outputs[tag] = (mux_output[tag] | label >> _WatermarkController(tag))
   130  
   131      return outputs
   132  
   133  
   134  class _TestStream(PTransform):
   135    """Test stream that generates events on an unbounded PCollection of elements.
   136  
   137    Each event emits elements, advances the watermark or advances the processing
   138    time.  After all of the specified elements are emitted, ceases to produce
   139    output.
   140  
   141    Expected implementation behavior:
   142     - If the instance receives a WatermarkEvent with the WATERMARK_CONTROL_TAG
   143       then the instance sets its own watermark hold at the specified value and
   144       drops the event.
   145     - If the instance receives any other WatermarkEvent or ElementEvent, it
   146       passes it to the consumer.
   147    """
   148  
   149    # This tag is used on WatermarkEvents to control the watermark at the root
   150    # TestStream.
   151    WATERMARK_CONTROL_TAG = '_TestStream_Watermark'
   152  
   153    def __init__(
   154        self,
   155        output_tags,
   156        coder=coders.FastPrimitivesCoder(),
   157        events=None,
   158        endpoint=None):
   159      assert coder is not None
   160      self.coder = coder
   161      self._raw_events = events
   162      self._events = self._add_watermark_advancements(output_tags, events)
   163      self.output_tags = output_tags
   164      self.endpoint = endpoint
   165  
   166    def _watermark_starts(self, output_tags):
   167      """Sentinel values to hold the watermark of outputs to -inf.
   168  
   169      The output watermarks of the output PCollections (fake unbounded sources) in
   170      a TestStream are controlled by watermark holds. This sets the hold of each
   171      output PCollection so that the individual holds can be controlled by the
   172      given events.
   173      """
   174      return [WatermarkEvent(timestamp.MIN_TIMESTAMP, tag) for tag in output_tags]
   175  
   176    def _watermark_stops(self, output_tags):
   177      """Sentinel values to close the watermark of outputs."""
   178      return [WatermarkEvent(timestamp.MAX_TIMESTAMP, tag) for tag in output_tags]
   179  
   180    def _test_stream_start(self):
   181      """Sentinel value to move the watermark hold of the TestStream to +inf.
   182  
   183      This sets a hold to +inf such that the individual holds of the output
   184      PCollections are allowed to modify their individial output watermarks with
   185      their holds. This is because the calculation of the output watermark is a
   186      min over all input watermarks.
   187      """
   188      return [
   189          WatermarkEvent(
   190              timestamp.MAX_TIMESTAMP - timestamp.TIME_GRANULARITY,
   191              _TestStream.WATERMARK_CONTROL_TAG)
   192      ]
   193  
   194    def _test_stream_stop(self):
   195      """Sentinel value to close the watermark of the TestStream."""
   196      return [
   197          WatermarkEvent(
   198              timestamp.MAX_TIMESTAMP, _TestStream.WATERMARK_CONTROL_TAG)
   199      ]
   200  
   201    def _test_stream_init(self):
   202      """Sentinel value to hold the watermark of the TestStream to -inf.
   203  
   204      This sets a hold to ensure that the output watermarks of the output
   205      PCollections do not advance to +inf before their watermark holds are set.
   206      """
   207      return [
   208          WatermarkEvent(
   209              timestamp.MIN_TIMESTAMP, _TestStream.WATERMARK_CONTROL_TAG)
   210      ]
   211  
   212    def _set_up(self, output_tags):
   213      return (
   214          self._test_stream_init() + self._watermark_starts(output_tags) +
   215          self._test_stream_start())
   216  
   217    def _tear_down(self, output_tags):
   218      return self._watermark_stops(output_tags) + self._test_stream_stop()
   219  
   220    def _add_watermark_advancements(self, output_tags, events):
   221      """Adds watermark advancements to the given events.
   222  
   223      The following watermark advancements can be done on the runner side.
   224      However, it makes the logic on the runner side much more complicated than
   225      it needs to be.
   226  
   227      In order for watermarks to be properly advanced in a TestStream, a specific
   228      sequence of watermark holds must be sent:
   229  
   230      1. Hold the root watermark at -inf (this prevents the pipeline from
   231         immediately returning).
   232      2. Hold the watermarks at the WatermarkControllerss at -inf (this prevents
   233         the pipeline from immediately returning).
   234      3. Advance the root watermark to +inf - 1 (this allows the downstream
   235         WatermarkControllers to control their watermarks via holds).
   236      4. Advance watermarks as normal.
   237      5. Advance WatermarkController watermarks to +inf
   238      6. Advance root watermark to +inf.
   239      """
   240      if not events:
   241        return []
   242  
   243      return self._set_up(output_tags) + events + self._tear_down(output_tags)
   244  
   245    def get_windowing(self, unused_inputs):
   246      return core.Windowing(window.GlobalWindows())
   247  
   248    def expand(self, pcoll):
   249      return pvalue.PCollection(pcoll.pipeline, is_bounded=False)
   250  
   251    def _infer_output_coder(self, input_type=None, input_coder=None):
   252      return self.coder
   253  
   254    @staticmethod
   255    def events_from_script(events):
   256      """Yields the in-memory events.
   257      """
   258      return itertools.chain(events)
   259  
   260    @staticmethod
   261    def _stream_events_from_rpc(endpoint, output_tags, coder, channel, is_alive):
   262      """Yields the events received from the given endpoint.
   263  
   264      This is the producer thread that reads events from the TestStreamService and
   265      puts them onto the shared queue. At the end of the stream, an _EndOfStream
   266      is placed on the channel to signify a successful end.
   267      """
   268      stub_channel = grpc.insecure_channel(endpoint)
   269      stub = beam_runner_api_pb2_grpc.TestStreamServiceStub(stub_channel)
   270  
   271      # Request the PCollections that we are looking for from the service.
   272      event_request = beam_runner_api_pb2.EventsRequest(
   273          output_ids=[str(tag) for tag in output_tags])
   274  
   275      event_stream = stub.Events(event_request)
   276      try:
   277        for e in event_stream:
   278          channel.put(_TestStream.test_stream_payload_to_events(e, coder))
   279          if not is_alive():
   280            return
   281      except grpc.RpcError as e:
   282        # Do not raise an exception in the non-error status codes. These can occur
   283        # when the Python interpreter shuts down or when in a notebook environment
   284        # when the kernel is interrupted.
   285        if e.code() in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE):
   286          return
   287        raise e
   288      finally:
   289        # Gracefully stop the job if there is an exception.
   290        channel.put(_EndOfStream())
   291  
   292    @staticmethod
   293    def events_from_rpc(endpoint, output_tags, coder, evaluation_context):
   294      """Yields the events received from the given endpoint.
   295  
   296      This method starts a new thread that reads from the TestStreamService and
   297      puts the events onto a shared queue. This method then yields all elements
   298      from the queue. Unfortunately, this is necessary because the GRPC API does
   299      not allow for non-blocking calls when utilizing a streaming RPC. It is
   300      officially suggested from the docs to use a producer/consumer pattern to
   301      handle streaming RPCs. By doing so, this gives this method control over when
   302      to cancel reading from the RPC if the server takes too long to respond.
   303      """
   304      # Shared variable with the producer queue. This shuts down the producer if
   305      # the consumer exits early.
   306      shutdown_requested = False
   307  
   308      def is_alive():
   309        return not (shutdown_requested or evaluation_context.shutdown_requested)
   310  
   311      # The shared queue that allows the producer and consumer to communicate.
   312      channel = Queue(
   313      )  # type: Queue[Union[test_stream.Event, _EndOfStream]] # noqa: F821
   314      event_stream = Thread(
   315          target=_TestStream._stream_events_from_rpc,
   316          args=(endpoint, output_tags, coder, channel, is_alive))
   317      event_stream.setDaemon(True)
   318      event_stream.start()
   319  
   320      # This pumps the shared queue for events until the _EndOfStream sentinel is
   321      # reached. If the TestStreamService takes longer than expected, the queue
   322      # will timeout and an EmptyException will be raised. This also sets the
   323      # shared is_alive sentinel to shut down the producer.
   324      while True:
   325        try:
   326          # Raise an EmptyException if there are no events during the last timeout
   327          # period.
   328          event = channel.get(timeout=30)
   329          if isinstance(event, _EndOfStream):
   330            break
   331          yield event
   332        except EmptyException as e:
   333          _LOGGER.warning(
   334              'TestStream timed out waiting for new events from service.'
   335              ' Stopping pipeline.')
   336          shutdown_requested = True
   337          raise e
   338  
   339    @staticmethod
   340    def test_stream_payload_to_events(payload, coder):
   341      """Returns a TestStream Python event object from a TestStream event Proto.
   342      """
   343      if payload.HasField('element_event'):
   344        element_event = payload.element_event
   345        elements = [
   346            TimestampedValue(
   347                coder.decode(e.encoded_element), Timestamp(micros=e.timestamp))
   348            for e in element_event.elements
   349        ]
   350        return ElementEvent(timestamped_values=elements, tag=element_event.tag)
   351  
   352      if payload.HasField('watermark_event'):
   353        watermark_event = payload.watermark_event
   354        return WatermarkEvent(
   355            Timestamp(micros=watermark_event.new_watermark),
   356            tag=watermark_event.tag)
   357  
   358      if payload.HasField('processing_time_event'):
   359        processing_time_event = payload.processing_time_event
   360        return ProcessingTimeEvent(
   361            Duration(micros=processing_time_event.advance_duration))
   362  
   363      raise RuntimeError(
   364          'Received a proto without the specified fields: {}'.format(payload))