# github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/pipeline_context.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utility class for serializing pipelines via the runner API.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file
# mypy: disallow-untyped-defs

from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam import pipeline
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.pipeline import ComponentIdMap
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms.resources import merge_resource_hints
from apache_beam.typehints import native_type_compatibility

if TYPE_CHECKING:
  from google.protobuf import message  # pylint: disable=ungrouped-imports
  from apache_beam.coders.coder_impl import IterableStateReader
  from apache_beam.coders.coder_impl import IterableStateWriter
  from apache_beam.transforms import ptransform

PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject')


class PortableObject(Protocol):
  """Structural type for objects that round-trip through the runner API."""
  def to_runner_api(self, __context):
    # type: (PipelineContext) -> Any
    pass

  @classmethod
  def from_runner_api(cls, __proto, __context):
    # type: (Any, PipelineContext) -> Any
    pass


class _PipelineContextMap(Generic[PortableObjectT]):
  """This is a bi-directional map between objects and ids.

  Under the hood it encodes and decodes these objects into runner API
  representations.
  """
  def __init__(self,
               context,  # type: PipelineContext
               obj_type,  # type: Type[PortableObjectT]
               namespace,  # type: str
               proto_map=None  # type: Optional[Mapping[str, message.Message]]
              ):
    # type: (...) -> None
    self._pipeline_context = context
    self._obj_type = obj_type
    self._namespace = namespace
    self._obj_to_id = {}  # type: Dict[Any, str]
    self._id_to_obj = {}  # type: Dict[str, Any]
    # Copy the incoming mapping so later insertions do not touch the caller's
    # container.
    self._id_to_proto = dict(proto_map) if proto_map else {}

  def populate_map(self, proto_map):
    # type: (Mapping[str, message.Message]) -> None
    """Copies every stored proto into ``proto_map`` under its id."""
    for id, proto in self._id_to_proto.items():
      proto_map[id].CopyFrom(proto)

  def get_id(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> str
    """Returns the id for ``obj``, assigning and encoding it on first use.

    ``label`` is only forwarded to the component id map when a new id has to
    be assigned.
    """
    if obj not in self._obj_to_id:
      id = self._pipeline_context.component_id_map.get_or_assign(
          obj, self._obj_type, label)
      # Both lookup tables are filled in before to_runner_api() is called.
      self._id_to_obj[id] = obj
      self._obj_to_id[obj] = id
      self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context)
    return self._obj_to_id[obj]

  def get_proto(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> message.Message
    """Returns the proto stored for ``obj``, registering ``obj`` if needed."""
    return self._id_to_proto[self.get_id(obj, label)]

  def get_by_id(self, id):
    # type: (str) -> PortableObjectT
    """Returns the object for ``id``, decoding its proto on first access.

    Raises:
      KeyError: if ``id`` has neither a cached object nor a stored proto.
    """
    if id not in self._id_to_obj:
      self._id_to_obj[id] = self._obj_type.from_runner_api(
          self._id_to_proto[id], self._pipeline_context)
    return self._id_to_obj[id]

  def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False):
    # type: (message.Message, Optional[str], bool) -> str
    """Returns an id for ``maybe_new_proto``, storing the proto if needed.

    With ``deduplicate`` set, the id of an already known equal object or equal
    proto is returned instead of storing the proto again.
    """
    # TODO: this method may not be safe for arbitrary protos due to
    # xlang concerns, hence limiting usage to the only current use-case it has.
    # See: https://github.com/apache/beam/pull/14390#discussion_r616062377
    assert isinstance(maybe_new_proto, beam_runner_api_pb2.Environment)
    obj = self._obj_type.from_runner_api(
        maybe_new_proto, self._pipeline_context)

    if deduplicate:
      if obj in self._obj_to_id:
        return self._obj_to_id[obj]

      for id, proto in self._id_to_proto.items():
        if proto == maybe_new_proto:
          return id
    return self.put_proto(
        self._pipeline_context.component_id_map.get_or_assign(
            obj=obj, obj_type=self._obj_type, label=label),
        maybe_new_proto)

  def get_id_to_proto_map(self):
    # type: () -> Dict[str, message.Message]
    """Returns the internal id -> proto dict itself (not a copy)."""
    return self._id_to_proto

  def get_proto_from_id(self, id):
    # type: (str) -> message.Message
    """Returns the proto stored under ``id``; raises KeyError if absent."""
    return self.get_id_to_proto_map()[id]

  def put_proto(self, id, proto, ignore_duplicates=False):
    # type: (str, message.Message, bool) -> str
    """Stores ``proto`` under ``id`` and returns ``id``.

    Args:
      id: the id to store the proto under.
      proto: the proto to store.
      ignore_duplicates: if True, re-inserting an equal proto under an
        existing id is allowed.

    Raises:
      ValueError: if ``id`` is taken and ``ignore_duplicates`` is False, or
        if ``id`` is taken by a proto that differs from ``proto``.
    """
    if id in self._id_to_proto:
      if not ignore_duplicates:
        raise ValueError("Id '%s' is already taken." % id)
      if self._id_to_proto[id] != proto:
        # The values must be interpolated here: passing them as extra
        # ValueError arguments leaves the %r placeholders unformatted.
        raise ValueError(
            'Cannot insert different protos %r and %r with the same ID %r' %
            (self._id_to_proto[id], proto, id))
    self._id_to_proto[id] = proto
    return id

  def __getitem__(self, id):
    # type: (str) -> Any
    return self.get_by_id(id)

  def __contains__(self, id):
    # type: (str) -> bool
    # Membership is decided by the proto table, not the decoded-object cache.
    return id in self._id_to_proto


class PipelineContext(object):
  """For internal use only; no backwards-compatibility guarantees.

  Used for accessing and constructing the referenced objects of a Pipeline.
  """

  def __init__(self,
               proto=None,  # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]]
               component_id_map=None,  # type: Optional[pipeline.ComponentIdMap]
               default_environment=None,  # type: Optional[environments.Environment]
               use_fake_coders=False,  # type: bool
               iterable_state_read=None,  # type: Optional[IterableStateReader]
               iterable_state_write=None,  # type: Optional[IterableStateWriter]
               namespace='ref',  # type: str
               requirements=(),  # type: Iterable[str]
              ):
    # type: (...) -> None
    # Only coders, windowing strategies and environments are carried over
    # from a ProcessBundleDescriptor.
    if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
      proto = beam_runner_api_pb2.Components(
          coders=dict(proto.coders.items()),
          windowing_strategies=dict(proto.windowing_strategies.items()),
          environments=dict(proto.environments.items()))

    self.component_id_map = component_id_map or ComponentIdMap(namespace)
    assert self.component_id_map.namespace == namespace

    # TODO(https://github.com/apache/beam/issues/20827) Initialize
    # component_id_map with objects from proto.
    self.transforms = _PipelineContextMap(
        self,
        pipeline.AppliedPTransform,
        namespace,
        proto.transforms if proto is not None else None)
    self.pcollections = _PipelineContextMap(
        self,
        pvalue.PCollection,
        namespace,
        proto.pcollections if proto is not None else None)
    self.coders = _PipelineContextMap(
        self,
        coders.Coder,
        namespace,
        proto.coders if proto is not None else None)
    self.windowing_strategies = _PipelineContextMap(
        self,
        core.Windowing,
        namespace,
        proto.windowing_strategies if proto is not None else None)
    self.environments = _PipelineContextMap(
        self,
        environments.Environment,
        namespace,
        proto.environments if proto is not None else None)

    if default_environment is None:
      default_environment = environments.DefaultEnvironment()

    self._default_environment_id = self.environments.get_id(
        default_environment, label='default_environment')  # type: str

    self.use_fake_coders = use_fake_coders
    self.deterministic_coder_map = {
    }  # type: Mapping[coders.Coder, coders.Coder]
    self.iterable_state_read = iterable_state_read
    self.iterable_state_write = iterable_state_write
    self._requirements = set(requirements)

  def add_requirement(self, requirement):
    # type: (str) -> None
    """Adds ``requirement`` to this context's requirement set."""
    self._requirements.add(requirement)

  def requirements(self):
    # type: () -> FrozenSet[str]
    """Returns an immutable snapshot of the requirements added so far."""
    return frozenset(self._requirements)

  # If fake coders are requested, return a pickled version of the element type
  # rather than an actual coder. The element type is required for some runners,
  # as well as performing a round-trip through protos.
  # TODO(https://github.com/apache/beam/issues/18490): Remove once this is no
  # longer needed.
  def coder_id_from_element_type(
      self, element_type, requires_deterministic_key_coder=None):
    # type: (Any, Optional[str]) -> str
    if self.use_fake_coders:
      return pickler.dumps(element_type).decode('ascii')
    else:
      coder = coders.registry.get_coder(element_type)
      if requires_deterministic_key_coder:
        # Swap in a deterministic key coder; the value coder stays as is.
        coder = coders.TupleCoder([
            self.deterministic_coder(
                coder.key_coder(), requires_deterministic_key_coder),
            coder.value_coder()
        ])
      return self.coders.get_id(coder)

  def deterministic_coder(self, coder, msg):
    # type: (coders.Coder, str) -> coders.Coder
    """Returns the deterministic variant of ``coder``, cached per coder.

    ``msg`` is passed to ``as_deterministic_coder`` only the first time a
    given coder is seen.
    """
    if coder not in self.deterministic_coder_map:
      self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg)  # type: ignore
    return self.deterministic_coder_map[coder]

  def element_type_from_coder_id(self, coder_id):
    # type: (str) -> Any
    """Returns the element type for ``coder_id``.

    When fake coders are in use, or the id is not a known coder, the id itself
    is treated as a pickled element type.
    """
    if self.use_fake_coders or coder_id not in self.coders:
      # NOTE(review): this unpickles the id string itself; presumably ids
      # only ever come from pipelines this SDK built - confirm for callers
      # that receive protos from elsewhere.
      return pickler.loads(coder_id)
    else:
      return native_type_compatibility.convert_to_beam_type(
          self.coders[coder_id].to_type_hint())

  @staticmethod
  def from_runner_api(proto):
    # type: (beam_runner_api_pb2.Components) -> PipelineContext
    """Builds a context from ``proto`` with all other arguments defaulted."""
    return PipelineContext(proto)

  def to_runner_api(self):
    # type: () -> beam_runner_api_pb2.Components
    """Returns a new Components proto populated from the five component maps."""
    context_proto = beam_runner_api_pb2.Components()

    self.transforms.populate_map(context_proto.transforms)
    self.pcollections.populate_map(context_proto.pcollections)
    self.coders.populate_map(context_proto.coders)
    self.windowing_strategies.populate_map(context_proto.windowing_strategies)
    self.environments.populate_map(context_proto.environments)

    return context_proto

  def default_environment_id(self):
    # type: () -> str
    """Returns the id assigned to the default environment at construction."""
    return self._default_environment_id

  def get_environment_id_for_resource_hints(
      self, hints):  # type: (Dict[str, bytes]) -> str
    """Returns an environment id that has necessary resource hints."""
    if not hints:
      return self.default_environment_id()

    def get_or_create_environment_with_resource_hints(
        template_env_id,
        resource_hints,
    ):  # type: (str, Dict[str, bytes]) -> str
      """Creates an environment that has necessary hints and returns its id."""
      template_env = self.environments.get_proto_from_id(template_env_id)
      cloned_env = beam_runner_api_pb2.Environment()
      # (TODO https://github.com/apache/beam/issues/25615)
      # Remove the suppress warning for type once mypy is updated to 0.941 or
      # higher.
      # mypy 0.790 throws the warning below but 0.941 doesn't.
      # error: Argument 1 to "CopyFrom" of "Message" has incompatible type
      # "Message"; expected "Environment" [arg-type]
      # Here, Environment is a subclass of Message but mypy still
      # throws an error.
      cloned_env.CopyFrom(template_env)  # type: ignore[arg-type]
      cloned_env.resource_hints.clear()
      cloned_env.resource_hints.update(resource_hints)

      return self.environments.get_by_proto(
          cloned_env, label='environment_with_resource_hints', deduplicate=True)

    default_env_id = self.default_environment_id()
    env_hints = self.environments.get_by_id(default_env_id).resource_hints()
    hints = merge_resource_hints(outer_hints=env_hints, inner_hints=hints)
    maybe_new_env_id = get_or_create_environment_with_resource_hints(
        default_env_id, hints)

    return maybe_new_env_id