github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/external.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Defines Transform whose expansion is implemented elsewhere.""" 19 # pytype: skip-file 20 21 import contextlib 22 import copy 23 import functools 24 import glob 25 import logging 26 import threading 27 from collections import OrderedDict 28 from collections import namedtuple 29 from typing import Dict 30 31 import grpc 32 33 from apache_beam import pvalue 34 from apache_beam.coders import RowCoder 35 from apache_beam.portability import common_urns 36 from apache_beam.portability.api import beam_artifact_api_pb2_grpc 37 from apache_beam.portability.api import beam_expansion_api_pb2 38 from apache_beam.portability.api import beam_expansion_api_pb2_grpc 39 from apache_beam.portability.api import beam_runner_api_pb2 40 from apache_beam.portability.api import external_transforms_pb2 41 from apache_beam.runners import pipeline_context 42 from apache_beam.runners.portability import artifact_service 43 from apache_beam.transforms import ptransform 44 from apache_beam.typehints import WithTypeHints 45 from apache_beam.typehints import native_type_compatibility 46 from apache_beam.typehints import row_type 47 from apache_beam.typehints.schemas import named_fields_to_schema 48 from apache_beam.typehints.schemas import named_tuple_from_schema 49 from apache_beam.typehints.schemas import named_tuple_to_schema 50 from apache_beam.typehints.trivial_inference import instance_to_type 51 from apache_beam.typehints.typehints import Union 52 from apache_beam.typehints.typehints import UnionConstraint 53 from apache_beam.utils import subprocess_server 54 55 DEFAULT_EXPANSION_SERVICE = 'localhost:8097' 56 57 58 def convert_to_typing_type(type_): 59 if isinstance(type_, row_type.RowTypeConstraint): 60 return named_tuple_from_schema(named_fields_to_schema(type_._fields)) 61 else: 62 return native_type_compatibility.convert_to_typing_type(type_) 63 64 65 def _is_optional_or_none(typehint): 66 return ( 67 type(None) in typehint.union_types if isinstance( 68 typehint, UnionConstraint) else typehint is type(None)) 69 70 71 def _strip_optional(typehint): 72 if not _is_optional_or_none(typehint): 73 return typehint 74 new_types = typehint.union_types.difference({type(None)}) 75 if len(new_types) == 1: 76 return list(new_types)[0] 77 return Union[new_types] 78 79 80 def iter_urns(coder, context=None): 81 yield coder.to_runner_api_parameter(context)[0] 82 for child in coder._get_component_coders(): 83 for urn in iter_urns(child, context): 84 yield urn 85 86 87 class PayloadBuilder(object): 88 """ 89 Abstract base class for building payloads to pass to ExternalTransform. 90 """ 91 def build(self): 92 """ 93 :return: ExternalConfigurationPayload 94 """ 95 raise NotImplementedError 96 97 def payload(self): 98 """ 99 The serialized ExternalConfigurationPayload 100 101 :return: bytes 102 """ 103 return self.build().SerializeToString() 104 105 def _get_schema_proto_and_payload(self, **kwargs): 106 named_fields = [] 107 fields_to_values = OrderedDict() 108 109 for key, value in kwargs.items(): 110 if not key: 111 raise ValueError('Parameter name cannot be empty') 112 if value is None: 113 raise ValueError( 114 'Received value None for key %s. None values are currently not ' 115 'supported' % key) 116 named_fields.append( 117 (key, convert_to_typing_type(instance_to_type(value)))) 118 fields_to_values[key] = value 119 120 schema_proto = named_fields_to_schema(named_fields) 121 row = named_tuple_from_schema(schema_proto)(**fields_to_values) 122 schema = named_tuple_to_schema(type(row)) 123 124 payload = RowCoder(schema).encode(row) 125 return (schema_proto, payload) 126 127 128 class SchemaBasedPayloadBuilder(PayloadBuilder): 129 """ 130 Base class for building payloads based on a schema that provides 131 type information for each configuration value to encode. 132 """ 133 def _get_named_tuple_instance(self): 134 raise NotImplementedError() 135 136 def build(self): 137 row = self._get_named_tuple_instance() 138 schema = named_tuple_to_schema(type(row)) 139 return external_transforms_pb2.ExternalConfigurationPayload( 140 schema=schema, payload=RowCoder(schema).encode(row)) 141 142 143 class ImplicitSchemaPayloadBuilder(SchemaBasedPayloadBuilder): 144 """ 145 Build a payload that generates a schema from the provided values. 146 """ 147 def __init__(self, values): 148 self._values = values 149 150 def _get_named_tuple_instance(self): 151 # omit fields with value=None since we can't infer their type 152 values = { 153 key: value 154 for key, value in self._values.items() if value is not None 155 } 156 157 schema = named_fields_to_schema([ 158 (key, convert_to_typing_type(instance_to_type(value))) for key, 159 value in values.items() 160 ]) 161 return named_tuple_from_schema(schema)(**values) 162 163 164 class NamedTupleBasedPayloadBuilder(SchemaBasedPayloadBuilder): 165 """ 166 Build a payload based on a NamedTuple schema. 167 """ 168 def __init__(self, tuple_instance): 169 """ 170 :param tuple_instance: an instance of a typing.NamedTuple 171 """ 172 super().__init__() 173 self._tuple_instance = tuple_instance 174 175 def _get_named_tuple_instance(self): 176 return self._tuple_instance 177 178 179 class SchemaTransformPayloadBuilder(PayloadBuilder): 180 def __init__(self, identifier, **kwargs): 181 self._identifier = identifier 182 self._kwargs = kwargs 183 184 def build(self): 185 schema_proto, payload = self._get_schema_proto_and_payload(**self._kwargs) 186 payload = external_transforms_pb2.SchemaTransformPayload( 187 identifier=self._identifier, 188 configuration_schema=schema_proto, 189 configuration_row=payload) 190 return payload 191 192 193 class JavaClassLookupPayloadBuilder(PayloadBuilder): 194 """ 195 Builds a payload for directly instantiating a Java transform using a 196 constructor and builder methods. 197 """ 198 199 IGNORED_ARG_FORMAT = 'ignore%d' 200 201 def __init__(self, class_name): 202 """ 203 :param class_name: fully qualified name of the transform class. 204 """ 205 if not class_name: 206 raise ValueError('Class name must not be empty') 207 208 self._class_name = class_name 209 self._constructor_method = None 210 self._constructor_param_args = None 211 self._constructor_param_kwargs = None 212 self._builder_methods_and_params = OrderedDict() 213 214 def _args_to_named_fields(self, args): 215 next_field_id = 0 216 named_fields = OrderedDict() 217 for value in args: 218 if value is None: 219 raise ValueError( 220 'Received value None. None values are currently not supported') 221 named_fields[( 222 JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT % 223 next_field_id)] = value 224 next_field_id += 1 225 return named_fields 226 227 def build(self): 228 all_constructor_param_kwargs = self._args_to_named_fields( 229 self._constructor_param_args) 230 if self._constructor_param_kwargs: 231 all_constructor_param_kwargs.update(self._constructor_param_kwargs) 232 constructor_schema, constructor_payload = ( 233 self._get_schema_proto_and_payload(**all_constructor_param_kwargs)) 234 payload = external_transforms_pb2.JavaClassLookupPayload( 235 class_name=self._class_name, 236 constructor_schema=constructor_schema, 237 constructor_payload=constructor_payload) 238 if self._constructor_method: 239 payload.constructor_method = self._constructor_method 240 241 for builder_method_name, params in self._builder_methods_and_params.items(): 242 builder_method_args, builder_method_kwargs = params 243 all_builder_method_kwargs = self._args_to_named_fields( 244 builder_method_args) 245 if builder_method_kwargs: 246 all_builder_method_kwargs.update(builder_method_kwargs) 247 builder_method_schema, builder_method_payload = ( 248 self._get_schema_proto_and_payload(**all_builder_method_kwargs)) 249 builder_method = external_transforms_pb2.BuilderMethod( 250 name=builder_method_name, 251 schema=builder_method_schema, 252 payload=builder_method_payload) 253 builder_method.name = builder_method_name 254 payload.builder_methods.append(builder_method) 255 return payload 256 257 def with_constructor(self, *args, **kwargs): 258 """ 259 Specifies the Java constructor to use. 260 Arguments provided using args and kwargs will be applied to the Java 261 transform constructor in the specified order. 262 263 :param args: parameter values of the constructor. 264 :param kwargs: parameter names and values of the constructor. 265 """ 266 if self._has_constructor(): 267 raise ValueError( 268 'Constructor or constructor method can only be specified once') 269 270 self._constructor_param_args = args 271 self._constructor_param_kwargs = kwargs 272 273 def with_constructor_method(self, method_name, *args, **kwargs): 274 """ 275 Specifies the Java constructor method to use. 276 Arguments provided using args and kwargs will be applied to the Java 277 transform constructor method in the specified order. 278 279 :param method_name: name of the constructor method. 280 :param args: parameter values of the constructor method. 281 :param kwargs: parameter names and values of the constructor method. 282 """ 283 if self._has_constructor(): 284 raise ValueError( 285 'Constructor or constructor method can only be specified once') 286 287 self._constructor_method = method_name 288 self._constructor_param_args = args 289 self._constructor_param_kwargs = kwargs 290 291 def add_builder_method(self, method_name, *args, **kwargs): 292 """ 293 Specifies a Java builder method to be invoked after instantiating the Java 294 transform class. Specified builder method will be applied in order. 295 Arguments provided using args and kwargs will be applied to the Java 296 transform builder method in the specified order. 297 298 :param method_name: name of the builder method. 299 :param args: parameter values of the builder method. 300 :param kwargs: parameter names and values of the builder method. 301 """ 302 self._builder_methods_and_params[method_name] = (args, kwargs) 303 304 def _has_constructor(self): 305 return ( 306 self._constructor_method or self._constructor_param_args or 307 self._constructor_param_kwargs) 308 309 310 # Information regarding a SchemaTransform available in an external SDK. 311 SchemaTransformsConfig = namedtuple( 312 'SchemaTransformsConfig', 313 ['identifier', 'configuration_schema', 'inputs', 'outputs']) 314 315 316 class SchemaAwareExternalTransform(ptransform.PTransform): 317 """A proxy transform for SchemaTransforms implemented in external SDKs. 318 319 This allows Python pipelines to directly use existing SchemaTransforms 320 available to the expansion service without adding additional code in external 321 SDKs. 322 323 :param identifier: unique identifier of the SchemaTransform. 324 :param expansion_service: an expansion service to use. This should already be 325 available and the Schema-aware transforms to be used must already be 326 deployed. 327 :param rearrange_based_on_discovery: if this flag is set, the input kwargs 328 will be rearranged to match the order of fields in the external 329 SchemaTransform configuration. A discovery call will be made to fetch 330 the configuration. 331 :param classpath: (Optional) A list paths to additional jars to place on the 332 expansion service classpath. 333 :kwargs: field name to value mapping for configuring the schema transform. 334 keys map to the field names of the schema of the SchemaTransform 335 (in-order). 336 """ 337 def __init__( 338 self, 339 identifier, 340 expansion_service, 341 rearrange_based_on_discovery=False, 342 classpath=None, 343 **kwargs): 344 self._expansion_service = expansion_service 345 self._kwargs = kwargs 346 self._classpath = classpath 347 348 _kwargs = kwargs 349 if rearrange_based_on_discovery: 350 _kwargs = self._rearrange_kwargs(identifier) 351 352 self._payload_builder = SchemaTransformPayloadBuilder(identifier, **_kwargs) 353 354 def _rearrange_kwargs(self, identifier): 355 # discover and fetch the external SchemaTransform configuration then 356 # use it to build an appropriate payload 357 schematransform_config = SchemaAwareExternalTransform.discover_config( 358 self._expansion_service, identifier) 359 360 external_config_fields = schematransform_config.configuration_schema._fields 361 ordered_kwargs = OrderedDict() 362 missing_fields = [] 363 364 for field in external_config_fields: 365 if field not in self._kwargs: 366 missing_fields.append(field) 367 else: 368 ordered_kwargs[field] = self._kwargs[field] 369 370 extra_fields = list(set(self._kwargs.keys()) - set(external_config_fields)) 371 if missing_fields: 372 raise ValueError( 373 'Input parameters are missing the following SchemaTransform config ' 374 'fields: %s' % missing_fields) 375 elif extra_fields: 376 raise ValueError( 377 'Input parameters include the following extra fields that are not ' 378 'found in the SchemaTransform config schema: %s' % extra_fields) 379 380 return ordered_kwargs 381 382 def expand(self, pcolls): 383 # Expand the transform using the expansion service. 384 return pcolls | ExternalTransform( 385 common_urns.schematransform_based_expand.urn, 386 self._payload_builder, 387 self._expansion_service) 388 389 @staticmethod 390 def discover(expansion_service): 391 """Discover all SchemaTransforms available to the given expansion service. 392 393 :return: a list of SchemaTransformsConfigs that represent the discovered 394 SchemaTransforms. 395 """ 396 397 with ExternalTransform.service(expansion_service) as service: 398 discover_response = service.DiscoverSchemaTransform( 399 beam_expansion_api_pb2.DiscoverSchemaTransformRequest()) 400 401 for identifier in discover_response.schema_transform_configs: 402 proto_config = discover_response.schema_transform_configs[identifier] 403 schema = named_tuple_from_schema(proto_config.config_schema) 404 405 yield SchemaTransformsConfig( 406 identifier=identifier, 407 configuration_schema=schema, 408 inputs=proto_config.input_pcollection_names, 409 outputs=proto_config.output_pcollection_names) 410 411 @staticmethod 412 def discover_config(expansion_service, name): 413 """Discover one SchemaTransform by name in the given expansion service. 414 415 :return: one SchemaTransformsConfig that represents the discovered 416 SchemaTransform 417 418 :raises: 419 ValueError: if more than one SchemaTransform is discovered, or if none 420 are discovered 421 """ 422 423 schematransforms = SchemaAwareExternalTransform.discover(expansion_service) 424 matched = [] 425 426 for st in schematransforms: 427 if name in st.identifier: 428 matched.append(st) 429 430 if not matched: 431 raise ValueError( 432 "Did not discover any SchemaTransforms resembling the name '%s'" % 433 name) 434 elif len(matched) > 1: 435 raise ValueError( 436 "Found multiple SchemaTransforms with the name '%s':\n%s\n" % 437 (name, [st.identifier for st in matched])) 438 439 return matched[0] 440 441 442 class JavaExternalTransform(ptransform.PTransform): 443 """A proxy for Java-implemented external transforms. 444 445 One builds these transforms just as one would in Java, e.g.:: 446 447 transform = JavaExternalTransform('fully.qualified.ClassName' 448 )(contructorArg, ... ).builderMethod(...) 449 450 or:: 451 452 JavaExternalTransform('fully.qualified.ClassName').staticConstructor( 453 ...).builderMethod1(...).builderMethod2(...) 454 455 :param class_name: fully qualified name of the java class 456 :param expansion_service: (Optional) an expansion service to use. If none is 457 provided, a default expansion service will be started. 458 :param classpath: (Optional) A list paths to additional jars to place on the 459 expansion service classpath. 460 """ 461 def __init__(self, class_name, expansion_service=None, classpath=None): 462 if expansion_service and classpath: 463 raise ValueError( 464 f'Only one of expansion_service ({expansion_service}) ' 465 f'or classpath ({classpath}) may be provided.') 466 self._payload_builder = JavaClassLookupPayloadBuilder(class_name) 467 self._classpath = classpath 468 self._expansion_service = expansion_service 469 # Beam explicitly looks for following attributes. Hence adding 470 # 'None' values here to prevent '__getattr__' from being called. 471 self.inputs = None 472 self._fn_api_payload = None 473 474 def __call__(self, *args, **kwargs): 475 self._payload_builder.with_constructor(*args, **kwargs) 476 return self 477 478 def __getattr__(self, name): 479 # Don't try to emulate special methods. 480 if name.startswith('__') and name.endswith('__'): 481 return super().__getattr__(name) 482 else: 483 return self[name] 484 485 def __getitem__(self, name): 486 # Use directly for keywords or attribute conflicts. 487 def construct(*args, **kwargs): 488 if self._payload_builder._has_constructor(): 489 builder_method = self._payload_builder.add_builder_method 490 else: 491 builder_method = self._payload_builder.with_constructor_method 492 builder_method(name, *args, **kwargs) 493 return self 494 495 return construct 496 497 def expand(self, pcolls): 498 if self._expansion_service is None: 499 self._expansion_service = BeamJarExpansionService( 500 ':sdks:java:expansion-service:app:shadowJar', 501 extra_args=['{{PORT}}', '--javaClassLookupAllowlistFile=*'], 502 classpath=self._classpath) 503 return pcolls | ExternalTransform( 504 common_urns.java_class_lookup.urn, 505 self._payload_builder, 506 self._expansion_service) 507 508 509 class AnnotationBasedPayloadBuilder(SchemaBasedPayloadBuilder): 510 """ 511 Build a payload based on an external transform's type annotations. 512 """ 513 def __init__(self, transform, **values): 514 """ 515 :param transform: a PTransform instance or class. type annotations will 516 be gathered from its __init__ method 517 :param values: values to encode 518 """ 519 self._transform = transform 520 self._values = values 521 522 def _get_named_tuple_instance(self): 523 schema = named_fields_to_schema([ 524 (k, convert_to_typing_type(v)) for k, 525 v in self._transform.__init__.__annotations__.items() 526 if k in self._values 527 ]) 528 return named_tuple_from_schema(schema)(**self._values) 529 530 531 class DataclassBasedPayloadBuilder(SchemaBasedPayloadBuilder): 532 """ 533 Build a payload based on an external transform that uses dataclasses. 534 """ 535 def __init__(self, transform): 536 """ 537 :param transform: a dataclass-decorated PTransform instance from which to 538 gather type annotations and values 539 """ 540 self._transform = transform 541 542 def _get_named_tuple_instance(self): 543 import dataclasses 544 schema = named_fields_to_schema([ 545 (field.name, convert_to_typing_type(field.type)) 546 for field in dataclasses.fields(self._transform) 547 ]) 548 return named_tuple_from_schema(schema)( 549 **dataclasses.asdict(self._transform)) 550 551 552 class ExternalTransform(ptransform.PTransform): 553 """ 554 External provides a cross-language transform via expansion services in 555 foreign SDKs. 556 """ 557 _namespace_counter = 0 558 559 # Variable name _namespace conflicts with DisplayData._namespace so we use 560 # name _external_namespace here. 561 _external_namespace = threading.local() 562 563 _IMPULSE_PREFIX = 'impulse' 564 565 def __init__(self, urn, payload, expansion_service=None): 566 """Wrapper for an external transform with the given urn and payload. 567 568 :param urn: the unique beam identifier for this transform 569 :param payload: the payload, either as a byte string or a PayloadBuilder 570 :param expansion_service: an expansion service implementing the beam 571 ExpansionService protocol, either as an object with an Expand method 572 or an address (as a str) to a grpc server that provides this method. 573 """ 574 expansion_service = expansion_service or DEFAULT_EXPANSION_SERVICE 575 if not urn and isinstance(payload, JavaClassLookupPayloadBuilder): 576 urn = common_urns.java_class_lookup.urn 577 self._urn = urn 578 self._payload = ( 579 payload.payload() if isinstance(payload, PayloadBuilder) else payload) 580 self._expansion_service = expansion_service 581 self._external_namespace = self._fresh_namespace() 582 self._inputs = {} # type: Dict[str, pvalue.PCollection] 583 self._outputs = {} # type: Dict[str, pvalue.PCollection] 584 585 def with_output_types(self, *args, **kwargs): 586 return WithTypeHints.with_output_types(self, *args, **kwargs) 587 588 def replace_named_inputs(self, named_inputs): 589 self._inputs = named_inputs 590 591 def replace_named_outputs(self, named_outputs): 592 self._outputs = named_outputs 593 594 def __post_init__(self, expansion_service): 595 """ 596 This will only be invoked if ExternalTransform is used as a base class 597 for a class decorated with dataclasses.dataclass 598 """ 599 ExternalTransform.__init__( 600 self, self.URN, DataclassBasedPayloadBuilder(self), expansion_service) 601 602 def default_label(self): 603 return '%s(%s)' % (self.__class__.__name__, self._urn) 604 605 @classmethod 606 def get_local_namespace(cls): 607 return getattr(cls._external_namespace, 'value', 'external') 608 609 @classmethod 610 @contextlib.contextmanager 611 def outer_namespace(cls, namespace): 612 prev = cls.get_local_namespace() 613 cls._external_namespace.value = namespace 614 yield 615 cls._external_namespace.value = prev 616 617 @classmethod 618 def _fresh_namespace(cls): 619 # type: () -> str 620 ExternalTransform._namespace_counter += 1 621 return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter) 622 623 def expand(self, pvalueish): 624 # type: (pvalue.PCollection) -> pvalue.PCollection 625 if isinstance(pvalueish, pvalue.PBegin): 626 self._inputs = {} 627 elif isinstance(pvalueish, (list, tuple)): 628 self._inputs = {str(ix): pvalue for ix, pvalue in enumerate(pvalueish)} 629 elif isinstance(pvalueish, dict): 630 self._inputs = pvalueish 631 else: 632 self._inputs = {'input': pvalueish} 633 pipeline = ( 634 next(iter(self._inputs.values())).pipeline 635 if self._inputs else pvalueish.pipeline) 636 context = pipeline_context.PipelineContext( 637 component_id_map=pipeline.component_id_map) 638 transform_proto = beam_runner_api_pb2.PTransform( 639 unique_name=pipeline._current_transform().full_label, 640 spec=beam_runner_api_pb2.FunctionSpec( 641 urn=self._urn, payload=self._payload)) 642 for tag, pcoll in self._inputs.items(): 643 transform_proto.inputs[tag] = context.pcollections.get_id(pcoll) 644 # Conversion to/from proto assumes producers. 645 # TODO: Possibly loosen this. 646 context.transforms.put_proto( 647 '%s_%s' % (self._IMPULSE_PREFIX, tag), 648 beam_runner_api_pb2.PTransform( 649 unique_name='%s_%s' % (self._IMPULSE_PREFIX, tag), 650 spec=beam_runner_api_pb2.FunctionSpec( 651 urn=common_urns.primitives.IMPULSE.urn), 652 outputs={'out': transform_proto.inputs[tag]})) 653 output_coders = None 654 if self._type_hints.output_types: 655 if self._type_hints.output_types[0]: 656 output_coders = dict( 657 (str(k), context.coder_id_from_element_type(v)) 658 for (k, v) in enumerate(self._type_hints.output_types[0])) 659 elif self._type_hints.output_types[1]: 660 output_coders = { 661 k: context.coder_id_from_element_type(v) 662 for (k, v) in self._type_hints.output_types[1].items() 663 } 664 components = context.to_runner_api() 665 request = beam_expansion_api_pb2.ExpansionRequest( 666 components=components, 667 namespace=self._external_namespace, 668 transform=transform_proto, 669 output_coder_requests=output_coders) 670 671 with ExternalTransform.service(self._expansion_service) as service: 672 response = service.Expand(request) 673 if response.error: 674 raise RuntimeError(response.error) 675 self._expanded_components = response.components 676 if any(env.dependencies 677 for env in self._expanded_components.environments.values()): 678 self._expanded_components = self._resolve_artifacts( 679 self._expanded_components, 680 service.artifact_service(), 681 pipeline.local_tempdir) 682 683 self._expanded_transform = response.transform 684 self._expanded_requirements = response.requirements 685 result_context = pipeline_context.PipelineContext(response.components) 686 687 def fix_output(pcoll, tag): 688 pcoll.pipeline = pipeline 689 pcoll.tag = tag 690 return pcoll 691 692 self._outputs = { 693 tag: fix_output(result_context.pcollections.get_by_id(pcoll_id), tag) 694 for tag, 695 pcoll_id in self._expanded_transform.outputs.items() 696 } 697 698 return self._output_to_pvalueish(self._outputs) 699 700 @staticmethod 701 @contextlib.contextmanager 702 def service(expansion_service): 703 if isinstance(expansion_service, str): 704 channel_options = [("grpc.max_receive_message_length", -1), 705 ("grpc.max_send_message_length", -1)] 706 if hasattr(grpc, 'local_channel_credentials'): 707 # Some environments may not support insecure channels. Hence use a 708 # secure channel with local credentials here. 709 # TODO: update this to support secure non-local channels. 710 channel_factory_fn = functools.partial( 711 grpc.secure_channel, 712 expansion_service, 713 grpc.local_channel_credentials(), 714 options=channel_options) 715 else: 716 # local_channel_credentials is an experimental API which is unsupported 717 # by older versions of grpc which may be pulled in due to other project 718 # dependencies. 719 channel_factory_fn = functools.partial( 720 grpc.insecure_channel, expansion_service, options=channel_options) 721 with channel_factory_fn() as channel: 722 yield ExpansionAndArtifactRetrievalStub(channel) 723 elif hasattr(expansion_service, 'Expand'): 724 yield expansion_service 725 else: 726 with expansion_service as stub: 727 yield stub 728 729 def _resolve_artifacts(self, components, service, dest): 730 for env in components.environments.values(): 731 if env.dependencies: 732 resolved = list( 733 artifact_service.resolve_artifacts(env.dependencies, service, dest)) 734 del env.dependencies[:] 735 env.dependencies.extend(resolved) 736 return components 737 738 def _output_to_pvalueish(self, output_dict): 739 if len(output_dict) == 1: 740 return next(iter(output_dict.values())) 741 else: 742 return output_dict 743 744 def to_runner_api_transform(self, context, full_label): 745 pcoll_renames = {} 746 renamed_tag_seen = False 747 for tag, pcoll in self._inputs.items(): 748 if tag not in self._expanded_transform.inputs: 749 if renamed_tag_seen: 750 raise RuntimeError( 751 'Ambiguity due to non-preserved tags: %s vs %s' % ( 752 sorted(self._expanded_transform.inputs.keys()), 753 sorted(self._inputs.keys()))) 754 else: 755 renamed_tag_seen = True 756 tag, = self._expanded_transform.inputs.keys() 757 pcoll_renames[self._expanded_transform.inputs[tag]] = ( 758 context.pcollections.get_id(pcoll)) 759 for tag, pcoll in self._outputs.items(): 760 pcoll_renames[self._expanded_transform.outputs[tag]] = ( 761 context.pcollections.get_id(pcoll)) 762 763 def _equivalent(coder1, coder2): 764 return coder1 == coder2 or _normalize(coder1) == _normalize(coder2) 765 766 def _normalize(coder_proto): 767 normalized = copy.copy(coder_proto) 768 normalized.spec.environment_id = '' 769 # TODO(robertwb): Normalize components as well. 770 return normalized 771 772 for id, proto in self._expanded_components.coders.items(): 773 if id.startswith(self._external_namespace): 774 context.coders.put_proto(id, proto) 775 elif id in context.coders: 776 if not _equivalent(context.coders._id_to_proto[id], proto): 777 raise RuntimeError( 778 'Re-used coder id: %s\n%s\n%s' % 779 (id, context.coders._id_to_proto[id], proto)) 780 else: 781 context.coders.put_proto(id, proto) 782 for id, proto in self._expanded_components.windowing_strategies.items(): 783 if id.startswith(self._external_namespace): 784 context.windowing_strategies.put_proto(id, proto) 785 for id, proto in self._expanded_components.environments.items(): 786 if id.startswith(self._external_namespace): 787 context.environments.put_proto(id, proto) 788 for id, proto in self._expanded_components.pcollections.items(): 789 id = pcoll_renames.get(id, id) 790 if id not in context.pcollections._id_to_obj.keys(): 791 context.pcollections.put_proto(id, proto) 792 793 for id, proto in self._expanded_components.transforms.items(): 794 if id.startswith(self._IMPULSE_PREFIX): 795 # Our fake inputs. 796 continue 797 assert id.startswith( 798 self._external_namespace), (id, self._external_namespace) 799 new_proto = beam_runner_api_pb2.PTransform( 800 unique_name=proto.unique_name, 801 # If URN is not set this is an empty spec. 802 spec=proto.spec if proto.spec.urn else None, 803 subtransforms=proto.subtransforms, 804 inputs={ 805 tag: pcoll_renames.get(pcoll, pcoll) 806 for tag, 807 pcoll in proto.inputs.items() 808 }, 809 outputs={ 810 tag: pcoll_renames.get(pcoll, pcoll) 811 for tag, 812 pcoll in proto.outputs.items() 813 }, 814 display_data=proto.display_data, 815 environment_id=proto.environment_id) 816 context.transforms.put_proto(id, new_proto) 817 818 for requirement in self._expanded_requirements: 819 context.add_requirement(requirement) 820 821 return beam_runner_api_pb2.PTransform( 822 unique_name=full_label, 823 spec=self._expanded_transform.spec, 824 subtransforms=self._expanded_transform.subtransforms, 825 inputs={ 826 tag: pcoll_renames.get(pcoll, pcoll) 827 for tag, 828 pcoll in self._expanded_transform.inputs.items() 829 }, 830 outputs={ 831 tag: pcoll_renames.get(pcoll, pcoll) 832 for tag, 833 pcoll in self._expanded_transform.outputs.items() 834 }, 835 environment_id=self._expanded_transform.environment_id) 836 837 838 class ExpansionAndArtifactRetrievalStub( 839 beam_expansion_api_pb2_grpc.ExpansionServiceStub): 840 def __init__(self, channel, **kwargs): 841 self._channel = channel 842 self._kwargs = kwargs 843 super().__init__(channel, **kwargs) 844 845 def artifact_service(self): 846 return beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub( 847 self._channel, **self._kwargs) 848 849 def ready(self, timeout_sec): 850 grpc.channel_ready_future(self._channel).result(timeout=timeout_sec) 851 852 853 class JavaJarExpansionService(object): 854 """An expansion service based on an Java Jar file. 855 856 This can be passed into an ExternalTransform as the expansion_service 857 argument which will spawn a subprocess using this jar to expand the 858 transform. 859 860 Args: 861 path_to_jar: the path to a locally available executable jar file to be used 862 to start up the expansion service. 863 extra_args: arguments to be provided when starting up the 864 expansion service using the jar file. These arguments will replace the 865 default arguments. 866 classpath: Additional dependencies to be added to the classpath. 867 append_args: arguments to be provided when starting up the 868 expansion service using the jar file. These arguments will be appended to 869 the default arguments. 870 """ 871 def __init__( 872 self, path_to_jar, extra_args=None, classpath=None, append_args=None): 873 if extra_args and append_args: 874 raise ValueError('Only one of extra_args or append_args may be provided') 875 self._path_to_jar = path_to_jar 876 self._extra_args = extra_args 877 self._classpath = classpath or [] 878 self._service_count = 0 879 self._append_args = append_args or [] 880 881 @staticmethod 882 def _expand_jars(jar): 883 if glob.glob(jar): 884 return glob.glob(jar) 885 elif isinstance(jar, str) and (jar.startswith('http://') or 886 jar.startswith('https://')): 887 return [subprocess_server.JavaJarServer.local_jar(jar)] 888 else: 889 # If the input JAR is not a local glob, nor an http/https URL, then 890 # we assume that it's a gradle-style Java artifact in Maven Central, 891 # in the form group:artifact:version, so we attempt to parse that way. 892 try: 893 group_id, artifact_id, version = jar.split(':') 894 except ValueError: 895 # If we are not able to find a JAR, nor a JAR artifact, nor a URL for 896 # a JAR path, we still choose to include it in the path. 897 logging.warning('Unable to parse %s into group:artifact:version.', jar) 898 return [jar] 899 path = subprocess_server.JavaJarServer.local_jar( 900 subprocess_server.JavaJarServer.path_to_maven_jar( 901 artifact_id, group_id, version)) 902 return [path] 903 904 def _default_args(self): 905 """Default arguments to be used by `JavaJarExpansionService`.""" 906 907 to_stage = ','.join([self._path_to_jar] + sum(( 908 JavaJarExpansionService._expand_jars(jar) 909 for jar in self._classpath or []), [])) 910 return ['{{PORT}}', f'--filesToStage={to_stage}'] 911 912 def __enter__(self): 913 if self._service_count == 0: 914 self._path_to_jar = subprocess_server.JavaJarServer.local_jar( 915 self._path_to_jar) 916 if self._extra_args is None: 917 self._extra_args = self._default_args() + self._append_args 918 # Consider memoizing these servers (with some timeout). 919 logging.info( 920 'Starting a JAR-based expansion service from JAR %s ' + ( 921 'and with classpath: %s' % 922 self._classpath if self._classpath else ''), 923 self._path_to_jar) 924 classpath_urls = [ 925 subprocess_server.JavaJarServer.local_jar(path) 926 for jar in self._classpath 927 for path in JavaJarExpansionService._expand_jars(jar) 928 ] 929 self._service_provider = subprocess_server.JavaJarServer( 930 ExpansionAndArtifactRetrievalStub, 931 self._path_to_jar, 932 self._extra_args, 933 classpath=classpath_urls) 934 self._service = self._service_provider.__enter__() 935 self._service_count += 1 936 return self._service 937 938 def __exit__(self, *args): 939 self._service_count -= 1 940 if self._service_count == 0: 941 self._service_provider.__exit__(*args) 942 943 944 class BeamJarExpansionService(JavaJarExpansionService): 945 """An expansion service based on an Beam Java Jar file. 946 947 Attempts to use a locally-built copy of the jar based on the gradle target, 948 if it exists, otherwise attempts to download and cache the released artifact 949 corresponding to this version of Beam from the apache maven repository. 950 951 Args: 952 gradle_target: Beam Gradle target for building an executable jar which will 953 be used to start the expansion service. 954 extra_args: arguments to be provided when starting up the 955 expansion service using the jar file. These arguments will replace the 956 default arguments. 957 gradle_appendix: Gradle appendix of the artifact. 958 classpath: Additional dependencies to be added to the classpath. 959 append_args: arguments to be provided when starting up the 960 expansion service using the jar file. These arguments will be appended to 961 the default arguments. 962 """ 963 def __init__( 964 self, 965 gradle_target, 966 extra_args=None, 967 gradle_appendix=None, 968 classpath=None, 969 append_args=None): 970 path_to_jar = subprocess_server.JavaJarServer.path_to_beam_jar( 971 gradle_target, gradle_appendix) 972 super().__init__( 973 path_to_jar, extra_args, classpath=classpath, append_args=append_args) 974 975 976 def memoize(func): 977 cache = {} 978 979 def wrapper(*args): 980 if args not in cache: 981 cache[args] = func(*args) 982 return cache[args] 983 984 return wrapper