github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/pipeline_context.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Utility class for serializing pipelines via the runner API.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  """
    22  
    23  # pytype: skip-file
    24  # mypy: disallow-untyped-defs
    25  
    26  from typing import TYPE_CHECKING
    27  from typing import Any
    28  from typing import Dict
    29  from typing import FrozenSet
    30  from typing import Generic
    31  from typing import Iterable
    32  from typing import Mapping
    33  from typing import Optional
    34  from typing import Type
    35  from typing import TypeVar
    36  from typing import Union
    37  
    38  from typing_extensions import Protocol
    39  
    40  from apache_beam import coders
    41  from apache_beam import pipeline
    42  from apache_beam import pvalue
    43  from apache_beam.internal import pickler
    44  from apache_beam.pipeline import ComponentIdMap
    45  from apache_beam.portability.api import beam_fn_api_pb2
    46  from apache_beam.portability.api import beam_runner_api_pb2
    47  from apache_beam.transforms import core
    48  from apache_beam.transforms import environments
    49  from apache_beam.transforms.resources import merge_resource_hints
    50  from apache_beam.typehints import native_type_compatibility
    51  
    52  if TYPE_CHECKING:
    53    from google.protobuf import message  # pylint: disable=ungrouped-imports
    54    from apache_beam.coders.coder_impl import IterableStateReader
    55    from apache_beam.coders.coder_impl import IterableStateWriter
    56    from apache_beam.transforms import ptransform
    57  
# Type variable for the kind of object a _PipelineContextMap holds; it is
# bounded by the PortableObject protocol (referenced by name because the
# protocol class is defined further down in this module).
PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject')
    59  
    60  
    61  class PortableObject(Protocol):
    62    def to_runner_api(self, __context):
    63      # type: (PipelineContext) -> Any
    64      pass
    65  
    66    @classmethod
    67    def from_runner_api(cls, __proto, __context):
    68      # type: (Any, PipelineContext) -> Any
    69      pass
    70  
    71  
    72  class _PipelineContextMap(Generic[PortableObjectT]):
    73    """This is a bi-directional map between objects and ids.
    74  
    75    Under the hood it encodes and decodes these objects into runner API
    76    representations.
    77    """
    78    def __init__(self,
    79                 context,  # type: PipelineContext
    80                 obj_type,  # type: Type[PortableObjectT]
    81                 namespace,  # type: str
    82                 proto_map=None  # type: Optional[Mapping[str, message.Message]]
    83                ):
    84      # type: (...) -> None
    85      self._pipeline_context = context
    86      self._obj_type = obj_type
    87      self._namespace = namespace
    88      self._obj_to_id = {}  # type: Dict[Any, str]
    89      self._id_to_obj = {}  # type: Dict[str, Any]
    90      self._id_to_proto = dict(proto_map) if proto_map else {}
    91  
    92    def populate_map(self, proto_map):
    93      # type: (Mapping[str, message.Message]) -> None
    94      for id, proto in self._id_to_proto.items():
    95        proto_map[id].CopyFrom(proto)
    96  
    97    def get_id(self, obj, label=None):
    98      # type: (PortableObjectT, Optional[str]) -> str
    99      if obj not in self._obj_to_id:
   100        id = self._pipeline_context.component_id_map.get_or_assign(
   101            obj, self._obj_type, label)
   102        self._id_to_obj[id] = obj
   103        self._obj_to_id[obj] = id
   104        self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context)
   105      return self._obj_to_id[obj]
   106  
   107    def get_proto(self, obj, label=None):
   108      # type: (PortableObjectT, Optional[str]) -> message.Message
   109      return self._id_to_proto[self.get_id(obj, label)]
   110  
   111    def get_by_id(self, id):
   112      # type: (str) -> PortableObjectT
   113      if id not in self._id_to_obj:
   114        self._id_to_obj[id] = self._obj_type.from_runner_api(
   115            self._id_to_proto[id], self._pipeline_context)
   116      return self._id_to_obj[id]
   117  
   118    def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False):
   119      # type: (message.Message, Optional[str], bool) -> str
   120      # TODO: this method may not be safe for arbitrary protos due to
   121      #  xlang concerns, hence limiting usage to the only current use-case it has.
   122      #  See: https://github.com/apache/beam/pull/14390#discussion_r616062377
   123      assert isinstance(maybe_new_proto, beam_runner_api_pb2.Environment)
   124      obj = self._obj_type.from_runner_api(
   125          maybe_new_proto, self._pipeline_context)
   126  
   127      if deduplicate:
   128        if obj in self._obj_to_id:
   129          return self._obj_to_id[obj]
   130  
   131        for id, proto in self._id_to_proto.items():
   132          if proto == maybe_new_proto:
   133            return id
   134      return self.put_proto(
   135          self._pipeline_context.component_id_map.get_or_assign(
   136              obj=obj, obj_type=self._obj_type, label=label),
   137          maybe_new_proto)
   138  
   139    def get_id_to_proto_map(self):
   140      # type: () -> Dict[str, message.Message]
   141      return self._id_to_proto
   142  
   143    def get_proto_from_id(self, id):
   144      # type: (str) -> message.Message
   145      return self.get_id_to_proto_map()[id]
   146  
   147    def put_proto(self, id, proto, ignore_duplicates=False):
   148      # type: (str, message.Message, bool) -> str
   149      if not ignore_duplicates and id in self._id_to_proto:
   150        raise ValueError("Id '%s' is already taken." % id)
   151      elif (ignore_duplicates and id in self._id_to_proto and
   152            self._id_to_proto[id] != proto):
   153        raise ValueError(
   154            'Cannot insert different protos %r and %r with the same ID %r',
   155            self._id_to_proto[id],
   156            proto,
   157            id)
   158      self._id_to_proto[id] = proto
   159      return id
   160  
   161    def __getitem__(self, id):
   162      # type: (str) -> Any
   163      return self.get_by_id(id)
   164  
   165    def __contains__(self, id):
   166      # type: (str) -> bool
   167      return id in self._id_to_proto
   168  
   169  
   170  class PipelineContext(object):
   171    """For internal use only; no backwards-compatibility guarantees.
   172  
   173    Used for accessing and constructing the referenced objects of a Pipeline.
   174    """
   175  
   176    def __init__(self,
   177                 proto=None,  # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]]
   178                 component_id_map=None,  # type: Optional[pipeline.ComponentIdMap]
   179                 default_environment=None,  # type: Optional[environments.Environment]
   180                 use_fake_coders=False,  # type: bool
   181                 iterable_state_read=None,  # type: Optional[IterableStateReader]
   182                 iterable_state_write=None,  # type: Optional[IterableStateWriter]
   183                 namespace='ref',  # type: str
   184                 requirements=(),  # type: Iterable[str]
   185                ):
   186      # type: (...) -> None
   187      if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
   188        proto = beam_runner_api_pb2.Components(
   189            coders=dict(proto.coders.items()),
   190            windowing_strategies=dict(proto.windowing_strategies.items()),
   191            environments=dict(proto.environments.items()))
   192  
   193      self.component_id_map = component_id_map or ComponentIdMap(namespace)
   194      assert self.component_id_map.namespace == namespace
   195  
   196      # TODO(https://github.com/apache/beam/issues/20827) Initialize
   197      # component_id_map with objects from proto.
   198      self.transforms = _PipelineContextMap(
   199          self,
   200          pipeline.AppliedPTransform,
   201          namespace,
   202          proto.transforms if proto is not None else None)
   203      self.pcollections = _PipelineContextMap(
   204          self,
   205          pvalue.PCollection,
   206          namespace,
   207          proto.pcollections if proto is not None else None)
   208      self.coders = _PipelineContextMap(
   209          self,
   210          coders.Coder,
   211          namespace,
   212          proto.coders if proto is not None else None)
   213      self.windowing_strategies = _PipelineContextMap(
   214          self,
   215          core.Windowing,
   216          namespace,
   217          proto.windowing_strategies if proto is not None else None)
   218      self.environments = _PipelineContextMap(
   219          self,
   220          environments.Environment,
   221          namespace,
   222          proto.environments if proto is not None else None)
   223  
   224      if default_environment is None:
   225        default_environment = environments.DefaultEnvironment()
   226  
   227      self._default_environment_id = self.environments.get_id(
   228          default_environment, label='default_environment')  # type: str
   229  
   230      self.use_fake_coders = use_fake_coders
   231      self.deterministic_coder_map = {
   232      }  # type: Mapping[coders.Coder, coders.Coder]
   233      self.iterable_state_read = iterable_state_read
   234      self.iterable_state_write = iterable_state_write
   235      self._requirements = set(requirements)
   236  
   237    def add_requirement(self, requirement):
   238      # type: (str) -> None
   239      self._requirements.add(requirement)
   240  
   241    def requirements(self):
   242      # type: () -> FrozenSet[str]
   243      return frozenset(self._requirements)
   244  
   245    # If fake coders are requested, return a pickled version of the element type
   246    # rather than an actual coder. The element type is required for some runners,
   247    # as well as performing a round-trip through protos.
   248    # TODO(https://github.com/apache/beam/issues/18490): Remove once this is no
   249    # longer needed.
   250    def coder_id_from_element_type(
   251        self, element_type, requires_deterministic_key_coder=None):
   252      # type: (Any, Optional[str]) -> str
   253      if self.use_fake_coders:
   254        return pickler.dumps(element_type).decode('ascii')
   255      else:
   256        coder = coders.registry.get_coder(element_type)
   257        if requires_deterministic_key_coder:
   258          coder = coders.TupleCoder([
   259              self.deterministic_coder(
   260                  coder.key_coder(), requires_deterministic_key_coder),
   261              coder.value_coder()
   262          ])
   263        return self.coders.get_id(coder)
   264  
   265    def deterministic_coder(self, coder, msg):
   266      # type: (coders.Coder, str) -> coders.Coder
   267      if coder not in self.deterministic_coder_map:
   268        self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg)  # type: ignore
   269      return self.deterministic_coder_map[coder]
   270  
   271    def element_type_from_coder_id(self, coder_id):
   272      # type: (str) -> Any
   273      if self.use_fake_coders or coder_id not in self.coders:
   274        return pickler.loads(coder_id)
   275      else:
   276        return native_type_compatibility.convert_to_beam_type(
   277            self.coders[coder_id].to_type_hint())
   278  
   279    @staticmethod
   280    def from_runner_api(proto):
   281      # type: (beam_runner_api_pb2.Components) -> PipelineContext
   282      return PipelineContext(proto)
   283  
   284    def to_runner_api(self):
   285      # type: () -> beam_runner_api_pb2.Components
   286      context_proto = beam_runner_api_pb2.Components()
   287  
   288      self.transforms.populate_map(context_proto.transforms)
   289      self.pcollections.populate_map(context_proto.pcollections)
   290      self.coders.populate_map(context_proto.coders)
   291      self.windowing_strategies.populate_map(context_proto.windowing_strategies)
   292      self.environments.populate_map(context_proto.environments)
   293  
   294      return context_proto
   295  
   296    def default_environment_id(self):
   297      # type: () -> str
   298      return self._default_environment_id
   299  
   300    def get_environment_id_for_resource_hints(
   301        self, hints):  # type: (Dict[str, bytes]) -> str
   302      """Returns an environment id that has necessary resource hints."""
   303      if not hints:
   304        return self.default_environment_id()
   305  
   306      def get_or_create_environment_with_resource_hints(
   307          template_env_id,
   308          resource_hints,
   309      ):  # type: (str, Dict[str, bytes]) -> str
   310        """Creates an environment that has necessary hints and returns its id."""
   311        template_env = self.environments.get_proto_from_id(template_env_id)
   312        cloned_env = beam_runner_api_pb2.Environment()
   313        # (TODO https://github.com/apache/beam/issues/25615)
   314        # Remove the suppress warning for type once mypy is updated to 0.941 or
   315        # higher.
   316        #  mypy 0.790 throws the warning below but 0.941 doesn't.
   317        #  error: Argument 1 to "CopyFrom" of "Message" has incompatible type
   318        #  "Message"; expected "Environment"  [arg-type]
   319        # Here, Environment is a subclass of Message but mypy still
   320        # throws an error.
   321        cloned_env.CopyFrom(template_env)  # type: ignore[arg-type]
   322        cloned_env.resource_hints.clear()
   323        cloned_env.resource_hints.update(resource_hints)
   324  
   325        return self.environments.get_by_proto(
   326            cloned_env, label='environment_with_resource_hints', deduplicate=True)
   327  
   328      default_env_id = self.default_environment_id()
   329      env_hints = self.environments.get_by_id(default_env_id).resource_hints()
   330      hints = merge_resource_hints(outer_hints=env_hints, inner_hints=hints)
   331      maybe_new_env_id = get_or_create_environment_with_resource_hints(
   332          default_env_id, hints)
   333  
   334      return maybe_new_env_id