github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/yaml/yaml_transform.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import collections
import json
import logging
import pprint
import re
import uuid
from typing import Iterable
from typing import Mapping

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider

__all__ = ["YamlTransform"]

_LOGGER = logging.getLogger(__name__)
yaml_provider.fix_pycallable()


def memoize_method(func):
  def wrapper(self, *args):
    if not hasattr(self, '_cache'):
      self._cache = {}
    key = func.__name__, args
    if key not in self._cache:
      self._cache[key] = func(self, *args)
    return self._cache[key]

  return wrapper


def only_element(xs):
  x, = xs
  return x


class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string class to which we can attach metadata.

    This is primarily used to trace a string's origin back to its place in a
    yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    value = super().construct_scalar(node)
    if isinstance(value, str):
      value = SafeLineLoader.TaggedString(value)
      value._line_ = node.start_mark.line + 1
    return value

  def construct_mapping(self, node, deep=False):
    mapping = super().construct_mapping(node, deep=deep)
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = self.create_uuid()
    return mapping

  @classmethod
  def create_uuid(cls):
    return str(uuid.uuid4())

  @classmethod
  def strip_metadata(cls, spec, tagged_str=True):
    if isinstance(spec, Mapping):
      return {
          key: cls.strip_metadata(value, tagged_str)
          for (key, value) in spec.items()
          if key not in ('__line__', '__uuid__')
      }
    elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
      return [cls.strip_metadata(value, tagged_str) for value in spec]
    elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str:
      return str(spec)
    else:
      return spec

  @staticmethod
  def get_line(obj):
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    else:
      return getattr(obj, '_line_', 'unknown')
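# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original module).
# It shows how SafeLineLoader tags parsed mappings and strings with their
# source positions, and how strip_metadata removes those tags again. Only
# PyYAML and the classes above are assumed; 'PyMap' is placeholder data.


def _demo_safe_line_loader():
  spec = yaml.load(
      '''
      type: PyMap
      name: MyMap
      ''',
      Loader=SafeLineLoader)
  # Mappings gain '__line__' and '__uuid__' keys; strings remember their line.
  print(spec['__line__'])  # 2
  print(SafeLineLoader.get_line(spec['type']))  # 2
  # strip_metadata returns a plain dict of plain strings.
  print(SafeLineLoader.strip_metadata(spec))  # {'type': 'PyMap', 'name': 'MyMap'}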
class LightweightScope(object):
  def __init__(self, transforms):
    self._transforms = transforms
    self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms}
    self._uuid_by_name = collections.defaultdict(list)
    for spec in self._transforms:
      if 'name' in spec:
        self._uuid_by_name[spec['name']].append(spec['__uuid__'])
      if 'type' in spec:
        self._uuid_by_name[spec['type']].append(spec['__uuid__'])

  def get_transform_id_and_output_name(self, name):
    if '.' in name:
      transform_name, output = name.rsplit('.', 1)
    else:
      transform_name, output = name, None
    return self.get_transform_id(transform_name), output

  def get_transform_id(self, transform_name):
    if transform_name in self._transforms_by_uuid:
      return transform_name
    else:
      candidates = self._uuid_by_name[transform_name]
      if not candidates:
        raise ValueError(
            f'Unknown transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      elif len(candidates) > 1:
        raise ValueError(
            f'Ambiguous transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      else:
        return only_element(candidates)
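# Illustrative sketch (added for exposition; not part of the original module):
# LightweightScope resolves a transform reference, optionally of the form
# "name.output", to the transform's uuid, raising on unknown or ambiguous
# names. The uuids and type names below are made up.


def _demo_lightweight_scope():
  scope = LightweightScope([
      {'type': 'Create', 'name': 'MySource', '__uuid__': 'uuid-1'},
      {'type': 'PyMap', '__uuid__': 'uuid-2'},
  ])
  print(scope.get_transform_id('MySource'))  # 'uuid-1'
  print(scope.get_transform_id_and_output_name('PyMap.out'))  # ('uuid-2', 'out')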
class Scope(LightweightScope):
  """To look up PCollections (typically outputs of prior transforms) by name."""
  def __init__(self, root, inputs, transforms, providers):
    super().__init__(transforms)
    self.root = root
    self._inputs = inputs
    self.providers = providers
    self._seen_names = set()

  def compute_all(self):
    for transform_id in self._transforms_by_uuid.keys():
      self.compute_outputs(transform_id)

  def get_pcollection(self, name):
    if name in self._inputs:
      return self._inputs[name]
    elif '.' in name:
      transform, output = name.rsplit('.', 1)
      outputs = self.get_outputs(transform)
      if output in outputs:
        return outputs[output]
      else:
        raise ValueError(
            f'Unknown output {repr(output)} '
            f'at line {SafeLineLoader.get_line(name)}: '
            f'{transform} only has outputs {list(outputs.keys())}')
    else:
      outputs = self.get_outputs(name)
      if len(outputs) == 1:
        return only_element(outputs.values())
      else:
        raise ValueError(
            f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
            f'{name} has outputs {list(outputs.keys())}')

  def get_outputs(self, transform_name):
    return self.compute_outputs(self.get_transform_id(transform_name))

  @memoize_method
  def compute_outputs(self, transform_id):
    return expand_transform(self._transforms_by_uuid[transform_id], self)

  # A method on scope as providers may be scoped...
  def create_ptransform(self, spec):
    if 'type' not in spec:
      raise ValueError(f'Missing transform type: {identify_object(spec)}')

    if spec['type'] not in self.providers:
      raise ValueError(
          'Unknown transform type %r at %s' %
          (spec['type'], identify_object(spec)))

    for provider in self.providers.get(spec['type']):
      if provider.available():
        break
    else:
      raise ValueError(
          'No available provider for type %r at %s' %
          (spec['type'], identify_object(spec)))

    if 'args' in spec:
      args = spec['args']
      if not isinstance(args, dict):
        raise ValueError(
            'Arguments for transform at %s must be a mapping.' %
            identify_object(spec))
    else:
      args = {
          key: value
          for (key, value) in spec.items()
          if key not in ('type', 'name', 'input', 'output')
      }
    real_args = SafeLineLoader.strip_metadata(args)
    try:
      # pylint: disable=undefined-loop-variable
      ptransform = provider.create_transform(spec['type'], real_args)
      # TODO(robertwb): Should we have a better API for adding annotations
      # than this?
      annotations = dict(
          yaml_type=spec['type'],
          yaml_args=json.dumps(real_args),
          yaml_provider=json.dumps(provider.to_json()),
          **ptransform.annotations())
      ptransform.annotations = lambda: annotations
      return ptransform
    except Exception as exn:
      if isinstance(exn, TypeError):
        # Create a slightly more generic error message for argument errors.
        msg = str(exn).replace('positional', '').replace('keyword', '')
        msg = re.sub(r'\S+lambda\S+', '', msg)
        msg = re.sub(' +', ' ', msg).strip()
      else:
        msg = str(exn)
      raise ValueError(
          f'Invalid transform specification at {identify_object(spec)}: {msg}'
      ) from exn

  def unique_name(self, spec, ptransform, strictness=0):
    if 'name' in spec:
      name = spec['name']
      strictness += 1
    else:
      name = ptransform.label
    if name in self._seen_names:
      if strictness >= 2:
        raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')
      else:
        name = f'{name}@{SafeLineLoader.get_line(spec)}'
    self._seen_names.add(name)
    return name
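# Illustrative sketch (added for exposition; not part of the original module):
# unique_name prefers an explicit 'name' over the transform's label and
# disambiguates a repeated name by appending the spec's line number. The spec
# dict below is hand-built rather than loaded from yaml.


def _demo_unique_name():
  scope = Scope(None, {}, transforms=[], providers={})
  spec = {'type': 'PyMap', '__line__': 7, '__uuid__': 'u1'}
  ptransform = beam.Map(lambda x: x)
  print(scope.unique_name(spec, ptransform))  # the transform's label
  print(scope.unique_name(spec, ptransform))  # same label with '@7' appended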
def expand_transform(spec, scope):
  if 'type' not in spec:
    raise TypeError(
        f'Missing type parameter for transform at {identify_object(spec)}')
  type = spec['type']
  if type == 'composite':
    return expand_composite_transform(spec, scope)
  else:
    return expand_leaf_transform(spec, scope)


def expand_leaf_transform(spec, scope):
  spec = normalize_inputs_outputs(spec)
  inputs_dict = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  input_type = spec.get('input_type', 'default')
  if input_type == 'list':
    inputs = tuple(inputs_dict.values())
  elif input_type == 'map':
    inputs = inputs_dict
  else:
    if len(inputs_dict) == 0:
      inputs = scope.root
    elif len(inputs_dict) == 1:
      inputs = next(iter(inputs_dict.values()))
    else:
      inputs = inputs_dict
  _LOGGER.info("Expanding %s ", identify_object(spec))
  ptransform = scope.create_ptransform(spec)
  try:
    # TODO: Move validation to construction?
    with FullyQualifiedNamedTransform.with_filter('*'):
      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
  except Exception as exn:
    raise ValueError(
        f"Error applying transform {identify_object(spec)}: {exn}") from exn
  if isinstance(outputs, dict):
    # TODO: Handle (or at least reject) nested case.
    return outputs
  elif isinstance(outputs, (tuple, list)):
    return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
  elif isinstance(outputs, beam.PCollection):
    return {'out': outputs}
  else:
    raise ValueError(
        f'Transform {identify_object(spec)} returned an unexpected type '
        f'{type(outputs)}')


def expand_composite_transform(spec, scope):
  spec = normalize_inputs_outputs(normalize_source_sink(spec))

  inner_scope = Scope(
      scope.root,
      {
          key: scope.get_pcollection(value)
          for (key, value) in spec['input'].items()
      },
      spec['transforms'],
      yaml_provider.merge_providers(
          yaml_provider.parse_providers(spec.get('providers', [])),
          scope.providers))

  class CompositePTransform(beam.PTransform):
    @staticmethod
    def expand(inputs):
      inner_scope.compute_all()
      return {
          key: inner_scope.get_pcollection(value)
          for (key, value) in spec['output'].items()
      }

  if 'name' not in spec:
    spec['name'] = 'Composite'
  if spec['name'] is None:  # top-level pipeline, don't nest
    return CompositePTransform.expand(None)
  else:
    _LOGGER.info("Expanding %s ", identify_object(spec))
    return ({
        key: scope.get_pcollection(value)
        for (key, value) in spec['input'].items()
    } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()


def expand_chain_transform(spec, scope):
  return expand_composite_transform(chain_as_composite(spec), scope)
def chain_as_composite(spec):
  # A chain is simply a composite transform where all inputs and outputs
  # are implicit.
  spec = normalize_source_sink(spec)
  if 'transforms' not in spec:
    raise TypeError(
        f"Chain at {identify_object(spec)} missing transforms property.")
  has_explicit_outputs = 'output' in spec
  composite_spec = normalize_inputs_outputs(spec)
  new_transforms = []
  for ix, transform in enumerate(composite_spec['transforms']):
    if any(io in transform for io in ('input', 'output')):
      raise ValueError(
          f'Transform {identify_object(transform)} is part of a chain and '
          'must have implicit inputs and outputs.')
    if ix == 0:
      transform['input'] = {key: key for key in composite_spec['input'].keys()}
    else:
      transform['input'] = new_transforms[-1]['__uuid__']
    new_transforms.append(transform)
  composite_spec['transforms'] = new_transforms

  last_transform = new_transforms[-1]['__uuid__']
  if has_explicit_outputs:
    composite_spec['output'] = {
        key: f'{last_transform}.{value}'
        for (key, value) in composite_spec['output'].items()
    }
  else:
    composite_spec['output'] = last_transform
  if 'name' not in composite_spec:
    composite_spec['name'] = 'Chain'
  composite_spec['type'] = 'composite'
  return composite_spec


def preprocess_chain(spec):
  if spec['type'] == 'chain':
    return chain_as_composite(spec)
  else:
    return spec


def pipeline_as_composite(spec):
  if isinstance(spec, list):
    return {
        'type': 'composite',
        'name': None,
        'transforms': spec,
        '__line__': spec[0]['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    }
  else:
    return dict(spec, name=None, type=spec.get('type', 'composite'))


def normalize_source_sink(spec):
  if 'source' not in spec and 'sink' not in spec:
    return spec
  spec = dict(spec)
  spec['transforms'] = list(spec.get('transforms', []))
  if 'source' in spec:
    spec['transforms'].insert(0, spec.pop('source'))
  if 'sink' in spec:
    spec['transforms'].append(spec.pop('sink'))
  return spec


def preprocess_source_sink(spec):
  if spec['type'] in ('chain', 'composite'):
    return normalize_source_sink(spec)
  else:
    return spec


def normalize_inputs_outputs(spec):
  spec = dict(spec)

  def normalize_io(tag):
    io = spec.get(tag, {})
    if isinstance(io, (str, list)):
      return {tag: io}
    else:
      return SafeLineLoader.strip_metadata(io, tagged_str=False)

  return dict(spec, input=normalize_io('input'), output=normalize_io('output'))
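# Illustrative sketch (added for exposition; not part of the original module):
# chain_as_composite threads implicit inputs and outputs through a chain,
# turning it into an ordinary composite. Loading with SafeLineLoader supplies
# the '__line__'/'__uuid__' metadata the rewrite relies on; the transform
# types here are placeholders.


def _demo_chain_as_composite():
  spec = yaml.load(
      '''
      type: chain
      transforms:
        - type: ReadFromText
        - type: PyMap
      ''',
      Loader=SafeLineLoader)
  composite = chain_as_composite(spec)
  assert composite['type'] == 'composite'
  assert composite['name'] == 'Chain'
  # The second transform consumes the first, and the composite's implicit
  # output is the last transform in the chain.
  assert composite['transforms'][1]['input'] == (
      composite['transforms'][0]['__uuid__'])
  assert composite['output'] == composite['transforms'][1]['__uuid__']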
def identify_object(spec):
  line = SafeLineLoader.get_line(spec)
  name = extract_name(spec)
  if name:
    return f'"{name}" at line {line}'
  else:
    return f'at line {line}'


def extract_name(spec):
  if 'name' in spec:
    return spec['name']
  elif 'id' in spec:
    return spec['id']
  elif 'type' in spec:
    return spec['type']
  elif len(spec) == 1:
    return extract_name(next(iter(spec.values())))
  else:
    return ''


def push_windowing_to_roots(spec):
  scope = LightweightScope(spec['transforms'])
  consumed_outputs_by_transform = collections.defaultdict(set)
  for transform in spec['transforms']:
    for _, input_ref in transform['input'].items():
      try:
        transform_id, output = scope.get_transform_id_and_output_name(input_ref)
        consumed_outputs_by_transform[transform_id].add(output)
      except ValueError:
        # Could be an input or an ambiguity we'll raise later.
        pass

  for transform in spec['transforms']:
    if not transform['input'] and 'windowing' not in transform:
      transform['windowing'] = spec['windowing']
      transform['__consumed_outputs'] = consumed_outputs_by_transform[
          transform['__uuid__']]

  return spec


def preprocess_windowing(spec):
  if spec['type'] == 'WindowInto':
    # This is the transform where it is actually applied.
    return spec
  elif 'windowing' not in spec:
    # Nothing to do.
    return spec

  if spec['type'] == 'composite':
    # Apply the windowing to any reads, creates, etc. in this transform.
    # TODO(robertwb): Better handle the case where a read is followed by a
    # setting of the timestamps. We should be careful of sliding windows
    # in particular.
    spec = push_windowing_to_roots(spec)

  windowing = spec.pop('windowing')
  if spec['input']:
    # Apply the windowing to all inputs by wrapping it in a transform that
    # first applies windowing and then applies the original transform.
    original_inputs = spec['input']
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{key}]',
        'windowing': windowing,
        'input': key,
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for key in original_inputs.keys()]
    windowed_inputs = {
        key: t['__uuid__']
        for (key, t) in zip(original_inputs.keys(), windowing_transforms)
    }
    modified_spec = dict(
        spec, input=windowed_inputs, __uuid__=SafeLineLoader.create_uuid())
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'input': spec['input'],
        'output': modified_spec['__uuid__'],
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }

  elif spec['type'] == 'composite':
    # Pushing the windowing down was sufficient.
    return spec

  else:
    # No inputs; apply the windowing to all outputs.
    consumed_outputs = list(spec.pop('__consumed_outputs', {None}))
    modified_spec = dict(spec, __uuid__=SafeLineLoader.create_uuid())
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{out}]',
        'windowing': windowing,
        'input': modified_spec['__uuid__'] + ('.' + out if out else ''),
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for out in consumed_outputs]
    if consumed_outputs == [None]:
      windowed_outputs = only_element(windowing_transforms)['__uuid__']
    else:
      windowed_outputs = {
          out: t['__uuid__']
          for (out, t) in zip(consumed_outputs, windowing_transforms)
      }
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'output': windowed_outputs,
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }
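# Illustrative sketch (added for exposition; not part of the original module):
# a 'windowing' clause on a transform with no inputs is peeled off into a
# trailing WindowInto applied to its output, leaving a composite behind.
# 'Create' is a placeholder type name; the spec just needs loader metadata.


def _demo_preprocess_windowing():
  spec = yaml.load(
      '''
      type: Create
      windowing:
        type: fixed
        size: 60
      ''',
      Loader=SafeLineLoader)
  result = preprocess_windowing(normalize_inputs_outputs(spec))
  print(result['type'])  # 'composite'
  print([t['type'] for t in result['transforms']])  # ['Create', 'WindowInto']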
def preprocess_flattened_inputs(spec):
  if spec['type'] != 'composite':
    return spec

  # Prefer to add the flattens as sibling operations rather than nesting
  # to keep graph shape consistent when the number of inputs goes from
  # one to multiple.
  new_transforms = []
  for t in spec['transforms']:
    if t['type'] == 'Flatten':
      # Don't flatten before an explicit flatten,
      # but we do have to expand list inputs into singleton inputs.
      def all_inputs(t):
        for key, values in t.get('input', {}).items():
          if isinstance(values, list):
            for ix, value in enumerate(values):
              yield f'{key}{ix}', value
          else:
            yield key, values

      inputs_dict = {}
      for key, value in all_inputs(t):
        while key in inputs_dict:
          key += '_'
        inputs_dict[key] = value
      t = dict(t, input=inputs_dict)
    else:
      replaced_inputs = {}
      for key, values in t.get('input', {}).items():
        if isinstance(values, list):
          flatten_id = SafeLineLoader.create_uuid()
          new_transforms.append({
              'type': 'Flatten',
              'name': '%s-Flatten[%s]' % (t.get('name', t['type']), key),
              'input': {
                  f'input{ix}': value
                  for (ix, value) in enumerate(values)
              },
              '__line__': spec['__line__'],
              '__uuid__': flatten_id,
          })
          replaced_inputs[key] = flatten_id
      if replaced_inputs:
        t = dict(t, input={**t['input'], **replaced_inputs})
    new_transforms.append(t)
  return dict(spec, transforms=new_transforms)


def preprocess(spec, verbose=False):
  if verbose:
    pprint.pprint(spec)

  def apply(phase, spec):
    spec = phase(spec)
    if spec['type'] in {'composite', 'chain'}:
      spec = dict(
          spec, transforms=[apply(phase, t) for t in spec['transforms']])
    return spec

  for phase in [preprocess_source_sink,
                preprocess_chain,
                normalize_inputs_outputs,
                preprocess_flattened_inputs,
                preprocess_windowing]:
    spec = apply(phase, spec)
    if verbose:
      print('=' * 20, phase, '=' * 20)
      pprint.pprint(spec)
  return spec
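# Illustrative sketch (added for exposition; not part of the original module):
# preprocess runs all the phases above. Here a list-valued input is rewritten
# into an explicit sibling Flatten feeding the original transform. The type
# and input names are placeholders.


def _demo_preprocess():
  spec = yaml.load(
      '''
      type: composite
      transforms:
        - type: SomeAggregation
          name: Merge
          input: [left, right]
      ''',
      Loader=SafeLineLoader)
  result = preprocess(spec)
  print([t['type'] for t in result['transforms']])
  # ['Flatten', 'SomeAggregation'], with Merge now reading the Flatten's output.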
class YamlTransform(beam.PTransform):
  def __init__(self, spec, providers={}):  # pylint: disable=dangerous-default-value
    if isinstance(spec, str):
      spec = yaml.load(spec, Loader=SafeLineLoader)
    self._spec = preprocess(spec)
    self._providers = yaml_provider.merge_providers(
        {
            key: yaml_provider.as_provider_list(key, value)
            for (key, value) in providers.items()
        },
        yaml_provider.standard_providers())

  def expand(self, pcolls):
    if isinstance(pcolls, beam.pvalue.PBegin):
      root = pcolls
      pcolls = {}
    elif isinstance(pcolls, beam.PCollection):
      root = pcolls.pipeline
      pcolls = {'input': pcolls}
    else:
      root = next(iter(pcolls.values())).pipeline
    result = expand_transform(
        self._spec,
        Scope(root, pcolls, transforms=[], providers=self._providers))
    if len(result) == 1:
      return only_element(result.values())
    else:
      return result


def expand_pipeline(pipeline, pipeline_spec, providers=None):
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
  # Calling expand directly to avoid an outer layer of nesting.
  return YamlTransform(
      pipeline_as_composite(pipeline_spec['pipeline']),
      {
          **yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
          **{
              key: yaml_provider.as_provider_list(key, value)
              for (key, value) in (providers or {}).items()
          }
      }).expand(beam.pvalue.PBegin(pipeline))
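# Illustrative sketch (added for exposition; not part of the original module):
# embedding a yaml pipeline in a Python pipeline. Whether the 'Create' and
# 'PyMap' types (and their 'elements'/'fn' arguments) resolve depends on the
# providers registered in this Beam version; treat them as placeholders.


def _demo_yaml_transform():
  pipeline = beam.Pipeline()
  _ = pipeline | YamlTransform(
      '''
      type: chain
      transforms:
        - type: Create
          elements: [1, 2, 3]
        - type: PyMap
          fn: "lambda x: x * x"
      ''')
  # pipeline.run() would then execute it on the default runner.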