github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/evaluation_context.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """EvaluationContext tracks global state, triggers and watermarks."""
    19  
    20  # pytype: skip-file
    21  
    22  import collections
    23  import threading
    24  from typing import TYPE_CHECKING
    25  from typing import Any
    26  from typing import DefaultDict
    27  from typing import Dict
    28  from typing import Iterable
    29  from typing import List
    30  from typing import Optional
    31  from typing import Tuple
    32  from typing import Union
    33  
    34  from apache_beam.runners.direct.direct_metrics import DirectMetrics
    35  from apache_beam.runners.direct.executor import TransformExecutor
    36  from apache_beam.runners.direct.watermark_manager import WatermarkManager
    37  from apache_beam.transforms import sideinputs
    38  from apache_beam.transforms.trigger import InMemoryUnmergedState
    39  from apache_beam.utils import counters
    40  
    41  if TYPE_CHECKING:
    42    from apache_beam import pvalue
    43    from apache_beam.pipeline import AppliedPTransform
    44    from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle
    45    from apache_beam.runners.direct.util import TimerFiring
    46    from apache_beam.runners.direct.util import TransformResult
    47    from apache_beam.runners.direct.watermark_manager import _TransformWatermarks
    48    from apache_beam.utils.timestamp import Timestamp
    49  
    50  
    51  class _ExecutionContext(object):
    52    """Contains the context for the execution of a single PTransform.
    53  
    54    It holds the watermarks for that transform, as well as keyed states.
    55    """
    56    def __init__(
    57        self,
    58        watermarks,  # type: _TransformWatermarks
    59        keyed_states):
    60      self.watermarks = watermarks
    61      self.keyed_states = keyed_states
    62  
    63      self._step_context = None
    64  
    65    def get_step_context(self):
    66      if not self._step_context:
    67        self._step_context = DirectStepContext(self.keyed_states)
    68      return self._step_context
    69  
    70    def reset(self):
    71      self._step_context = None
    72  
    73  
    74  class _SideInputView(object):
    75    def __init__(self, view):
    76      self._view = view
    77      self.blocked_tasks = collections.deque()
    78      self.elements = []
    79      self.value = None
    80      self.watermark = None
    81  
    82    def __repr__(self):
    83      elements_string = (
    84          ', '.join(str(elm) for elm in self.elements) if self.elements else '[]')
    85      return '_SideInputView(elements=%s)' % elements_string
    86  
    87  
    88  class _SideInputsContainer(object):
    89    """An in-process container for side inputs.
    90  
    91    It provides methods for blocking until a side-input is available and writing
    92    to a side input.
    93    """
    94    def __init__(self, side_inputs):
    95      # type: (Iterable[pvalue.AsSideInput]) -> None
    96      self._lock = threading.Lock()
    97      self._views = {}  # type: Dict[pvalue.AsSideInput, _SideInputView]
    98      self._transform_to_side_inputs = collections.defaultdict(
    99          list
   100      )  # type: DefaultDict[Optional[AppliedPTransform], List[pvalue.AsSideInput]]
   101      # this appears unused:
   102      self._side_input_to_blocked_tasks = collections.defaultdict(list)  # type: ignore
   103  
   104      for side in side_inputs:
   105        self._views[side] = _SideInputView(side)
   106        self._transform_to_side_inputs[side.pvalue.producer].append(side)
   107  
   108    def __repr__(self):
   109      views_string = (
   110          ', '.join(str(elm)
   111                    for elm in self._views.values()) if self._views else '[]')
   112      return '_SideInputsContainer(_views=%s)' % views_string
   113  
   114    def get_value_or_block_until_ready(self,
   115                                       side_input,
   116                                       task,  # type: TransformExecutor
   117                                       block_until  # type: Timestamp
   118                                      ):
   119      # type: (...) -> Any
   120  
   121      """Returns the value of a view whose task is unblocked or blocks its task.
   122  
   123      It gets the value of a view whose watermark has been updated and
   124      surpasses a given value.
   125  
   126      Args:
   127        side_input: ``_UnpickledSideInput`` value.
   128        task: ``TransformExecutor`` task waiting on a side input.
   129        block_until: Timestamp after which the task gets unblocked.
   130  
   131      Returns:
   132        The ``SideInputMap`` value of a view when the tasks it blocks are
   133        unblocked. Otherwise, None.
   134      """
   135      with self._lock:
   136        view = self._views[side_input]
   137        if view.watermark and view.watermark.output_watermark >= block_until:
   138          view.value = self._pvalue_to_value(side_input, view.elements)
   139          return view.value
   140        else:
   141          view.blocked_tasks.append((task, block_until))
   142          task.blocked = True
   143  
   144    def add_values(self, side_input, values):
   145      with self._lock:
   146        view = self._views[side_input]
   147        view.elements.extend(values)
   148  
   149    def update_watermarks_for_transform_and_unblock_tasks(
   150        self, ptransform, watermark):
   151      # type: (...) -> List[Tuple[TransformExecutor, Timestamp]]
   152  
   153      """Updates _SideInputsContainer after a watermark update and unbloks tasks.
   154  
   155      It traverses the list of side inputs per PTransform and calls
   156      _update_watermarks_for_side_input_and_unblock_tasks to unblock tasks.
   157  
   158      Args:
   159        ptransform: Value of a PTransform.
   160        watermark: Value of the watermark after an update for a PTransform.
   161  
   162      Returns:
   163        Tasks that get unblocked as a result of the watermark advancing.
   164      """
   165      unblocked_tasks = []
   166      for side in self._transform_to_side_inputs[ptransform]:
   167        unblocked_tasks.extend(
   168            self._update_watermarks_for_side_input_and_unblock_tasks(
   169                side, watermark))
   170      return unblocked_tasks
   171  
   172    def _update_watermarks_for_side_input_and_unblock_tasks(
   173        self, side_input, watermark):
   174      # type: (...) -> List[Tuple[TransformExecutor, Timestamp]]
   175  
   176      """Helps update _SideInputsContainer after a watermark update.
   177  
   178      For each view of the side input, it updates the value of the watermark
   179      recorded when the watermark moved and unblocks tasks accordingly.
   180  
   181      Args:
   182        side_input: ``_UnpickledSideInput`` value.
   183        watermark: Value of the watermark after an update for a PTransform.
   184  
   185      Returns:
   186        Tasks that get unblocked as a result of the watermark advancing.
   187      """
   188      with self._lock:
   189        view = self._views[side_input]
   190        view.watermark = watermark
   191  
   192        unblocked_tasks = []
   193        tasks_just_unblocked = []
   194        for task, block_until in view.blocked_tasks:
   195          if watermark.output_watermark >= block_until:
   196            view.value = self._pvalue_to_value(side_input, view.elements)
   197            unblocked_tasks.append(task)
   198            tasks_just_unblocked.append((task, block_until))
   199            task.blocked = False
   200        for task in tasks_just_unblocked:
   201          view.blocked_tasks.remove(task)
   202        return unblocked_tasks
   203  
   204    def _pvalue_to_value(self, side_input, values):
   205      """Given a side input, returns the associated value in its requested form.
   206  
   207      Args:
   208        side_input: _UnpickledSideInput object.
   209        values: Iterable values associated with the side input.
   210  
   211      Returns:
   212        The side input in its requested form.
   213  
   214      Raises:
   215        ValueError: If values cannot be converted into the requested form.
   216      """
   217      return sideinputs.SideInputMap(
   218          type(side_input), side_input._view_options(), values)
   219  
   220  
   221  class EvaluationContext(object):
   222    """Evaluation context with the global state information of the pipeline.
   223  
   224    The evaluation context for a specific pipeline being executed by the
   225    DirectRunner. Contains state shared within the execution across all
   226    transforms.
   227  
   228    EvaluationContext contains shared state for an execution of the
   229    DirectRunner that can be used while evaluating a PTransform. This
   230    consists of views into underlying state and watermark implementations, access
   231    to read and write side inputs, and constructing counter sets and
   232    execution contexts. This includes executing callbacks asynchronously when
   233    state changes to the appropriate point (e.g. when a side input is
   234    requested and known to be empty).
   235  
   236    EvaluationContext also handles results by committing finalizing
   237    bundles based on the current global state and updating the global state
   238    appropriately. This includes updating the per-(step,key) state, updating
   239    global watermarks, and executing any callbacks that can be executed.
   240    """
   241  
   242    def __init__(self,
   243                 pipeline_options,
   244                 bundle_factory,  # type: BundleFactory
   245                 root_transforms,
   246                 value_to_consumers,
   247                 step_names,
   248                 views,  # type: Iterable[pvalue.AsSideInput]
   249                 clock
   250                ):
   251      self.pipeline_options = pipeline_options
   252      self._bundle_factory = bundle_factory
   253      self._root_transforms = root_transforms
   254      self._value_to_consumers = value_to_consumers
   255      self._step_names = step_names
   256      self.views = views
   257      self._pcollection_to_views = collections.defaultdict(
   258          list)  # type: DefaultDict[pvalue.PValue, List[pvalue.AsSideInput]]
   259      for view in views:
   260        self._pcollection_to_views[view.pvalue].append(view)
   261      self._transform_keyed_states = self._initialize_keyed_states(
   262          root_transforms, value_to_consumers)
   263      self._side_inputs_container = _SideInputsContainer(views)
   264      self._watermark_manager = WatermarkManager(
   265          clock,
   266          root_transforms,
   267          value_to_consumers,
   268          self._transform_keyed_states)
   269      self._pending_unblocked_tasks = [
   270      ]  # type: List[Tuple[TransformExecutor, Timestamp]]
   271      self._counter_factory = counters.CounterFactory()
   272      self._metrics = DirectMetrics()
   273  
   274      self._lock = threading.Lock()
   275      self.shutdown_requested = False
   276  
   277    def _initialize_keyed_states(self, root_transforms, value_to_consumers):
   278      """Initialize user state dicts.
   279  
   280      These dicts track user state per-key, per-transform and per-window.
   281      """
   282      transform_keyed_states = {}
   283      for transform in root_transforms:
   284        transform_keyed_states[transform] = {}
   285      for consumers in value_to_consumers.values():
   286        for consumer in consumers:
   287          transform_keyed_states[consumer] = {}
   288      return transform_keyed_states
   289  
   290    def metrics(self):
   291      # TODO. Should this be made a @property?
   292      return self._metrics
   293  
   294    def is_root_transform(self, applied_ptransform):
   295      # type: (AppliedPTransform) -> bool
   296      return applied_ptransform in self._root_transforms
   297  
   298    def handle_result(self,
   299                      completed_bundle,  # type: _Bundle
   300                      completed_timers,
   301                      result  # type: TransformResult
   302                     ):
   303      """Handle the provided result produced after evaluating the input bundle.
   304  
   305      Handle the provided TransformResult, produced after evaluating
   306      the provided committed bundle (potentially None, if the result of a root
   307      PTransform).
   308  
   309      The result is the output of running the transform contained in the
   310      TransformResult on the contents of the provided bundle.
   311  
   312      Args:
   313        completed_bundle: the bundle that was processed to produce the result.
   314        completed_timers: the timers that were delivered to produce the
   315                          completed_bundle.
   316        result: the ``TransformResult`` of evaluating the input bundle
   317  
   318      Returns:
   319        the committed bundles contained within the handled result.
   320      """
   321      with self._lock:
   322        committed_bundles, unprocessed_bundles = self._commit_bundles(
   323            result.uncommitted_output_bundles,
   324            result.unprocessed_bundles)
   325  
   326        self._metrics.commit_logical(
   327            completed_bundle, result.logical_metric_updates)
   328  
   329        # If the result is for a view, update side inputs container.
   330        self._update_side_inputs_container(committed_bundles, result)
   331  
   332        # Tasks generated from unblocked side inputs as the watermark progresses.
   333        tasks = self._watermark_manager.update_watermarks(
   334            completed_bundle,
   335            result.transform,
   336            completed_timers,
   337            committed_bundles,
   338            unprocessed_bundles,
   339            result.keyed_watermark_holds,
   340            self._side_inputs_container)
   341        self._pending_unblocked_tasks.extend(tasks)
   342  
   343        if result.counters:
   344          for counter in result.counters:
   345            merged_counter = self._counter_factory.get_counter(
   346                counter.name, counter.combine_fn)
   347            merged_counter.accumulator.merge([counter.accumulator])
   348  
   349        # Commit partial GBK states
   350        existing_keyed_state = self._transform_keyed_states[result.transform]
   351        for k, v in result.partial_keyed_state.items():
   352          existing_keyed_state[k] = v
   353        return committed_bundles
   354  
   355    def _update_side_inputs_container(self,
   356                                      committed_bundles,  # type: Iterable[_Bundle]
   357                                      result  # type: TransformResult
   358                                     ):
   359      """Update the side inputs container if we are outputting into a side input.
   360  
   361      Look at the result, and if it's outputing into a PCollection that we have
   362      registered as a PCollectionView, we add the result to the PCollectionView.
   363      """
   364      if (result.uncommitted_output_bundles and
   365          result.uncommitted_output_bundles[0].pcollection in
   366          self._pcollection_to_views):
   367        for view in self._pcollection_to_views[
   368            result.uncommitted_output_bundles[0].pcollection]:
   369          for committed_bundle in committed_bundles:
   370            # side_input must be materialized.
   371            self._side_inputs_container.add_values(
   372                view, committed_bundle.get_elements_iterable(make_copy=True))
   373  
   374    def get_aggregator_values(self, aggregator_or_name):
   375      return self._counter_factory.get_aggregator_values(aggregator_or_name)
   376  
   377    def schedule_pending_unblocked_tasks(self, executor_service):
   378      if self._pending_unblocked_tasks:
   379        with self._lock:
   380          for task in self._pending_unblocked_tasks:
   381            executor_service.submit(task)
   382          self._pending_unblocked_tasks = []
   383  
   384    def _commit_bundles(self,
   385                        uncommitted_bundles,  # type: Iterable[_Bundle]
   386                        unprocessed_bundles  # type: Iterable[_Bundle]
   387                       ):
   388      # type: (...) -> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]]
   389  
   390      """Commits bundles and returns a immutable set of committed bundles."""
   391      for in_progress_bundle in uncommitted_bundles:
   392        producing_applied_ptransform = in_progress_bundle.pcollection.producer
   393        watermarks = self._watermark_manager.get_watermarks(
   394            producing_applied_ptransform)
   395        in_progress_bundle.commit(watermarks.synchronized_processing_output_time)
   396  
   397      for unprocessed_bundle in unprocessed_bundles:
   398        unprocessed_bundle.commit(None)
   399      return tuple(uncommitted_bundles), tuple(unprocessed_bundles)
   400  
   401    def get_execution_context(self, applied_ptransform):
   402      # type: (AppliedPTransform) -> _ExecutionContext
   403      return _ExecutionContext(
   404          self._watermark_manager.get_watermarks(applied_ptransform),
   405          self._transform_keyed_states[applied_ptransform])
   406  
   407    def create_bundle(self, output_pcollection):
   408      # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle
   409  
   410      """Create an uncommitted bundle for the specified PCollection."""
   411      return self._bundle_factory.create_bundle(output_pcollection)
   412  
   413    def create_empty_committed_bundle(self, output_pcollection):
   414      # type: (pvalue.PCollection) -> _Bundle
   415  
   416      """Create empty bundle useful for triggering evaluation."""
   417      return self._bundle_factory.create_empty_committed_bundle(
   418          output_pcollection)
   419  
   420    def extract_all_timers(self):
   421      # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool]
   422      return self._watermark_manager.extract_all_timers()
   423  
   424    def is_done(self, transform=None):
   425      # type: (Optional[AppliedPTransform]) -> bool
   426  
   427      """Checks completion of a step or the pipeline.
   428  
   429      Args:
   430        transform: AppliedPTransform to check for completion.
   431  
   432      Returns:
   433        True if the step will not produce additional output. If transform is None
   434        returns true if all steps are done.
   435      """
   436      if transform:
   437        return self._is_transform_done(transform)
   438  
   439      for applied_ptransform in self._step_names:
   440        if not self._is_transform_done(applied_ptransform):
   441          return False
   442      return True
   443  
   444    def _is_transform_done(self, transform):
   445      # type: (AppliedPTransform) -> bool
   446      tw = self._watermark_manager.get_watermarks(transform)
   447      return tw.output_watermark == WatermarkManager.WATERMARK_POS_INF
   448  
   449    def get_value_or_block_until_ready(self, side_input, task, block_until):
   450      assert isinstance(task, TransformExecutor)
   451      return self._side_inputs_container.get_value_or_block_until_ready(
   452          side_input, task, block_until)
   453  
   454    def shutdown(self):
   455      self.shutdown_requested = True
   456  
   457  
   458  class DirectUnmergedState(InMemoryUnmergedState):
   459    """UnmergedState implementation for the DirectRunner."""
   460    def __init__(self):
   461      super().__init__(defensive_copy=False)
   462  
   463  
   464  class DirectStepContext(object):
   465    """Context for the currently-executing step."""
   466    def __init__(self, existing_keyed_state):
   467      self.existing_keyed_state = existing_keyed_state
   468      # In order to avoid partial writes of a bundle, every time
   469      # existing_keyed_state is accessed, a copy of the state is made
   470      # to be transferred to the bundle state once the bundle is committed.
   471      self.partial_keyed_state = {}
   472  
   473    def get_keyed_state(self, key):
   474      if not self.existing_keyed_state.get(key):
   475        self.existing_keyed_state[key] = DirectUnmergedState()
   476      if not self.partial_keyed_state.get(key):
   477        self.partial_keyed_state[key] = self.existing_keyed_state[key].copy()
   478      return self.partial_keyed_state[key]