github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/util.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Simple utility PTransforms.
"""

# pytype: skip-file

import collections
import contextlib
import logging
import random
import re
import threading
import time
import uuid
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import List
from typing import Tuple
from typing import TypeVar
from typing import Union

from apache_beam import coders
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsSideInput
from apache_beam.transforms import window
from apache_beam.transforms.combiners import CountCombineFn
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import Create
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import FlatMap
from apache_beam.transforms.core import Flatten
from apache_beam.transforms.core import GroupByKey
from apache_beam.transforms.core import Map
from apache_beam.transforms.core import MapTuple
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import Always
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.userstate import TimerSpec
from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey

if TYPE_CHECKING:
  from apache_beam import pvalue
  from apache_beam.runners.pipeline_context import PipelineContext

__all__ = [
    'BatchElements',
    'CoGroupByKey',
    'Distinct',
    'Keys',
    'KvSwap',
    'LogElements',
    'Regex',
    'Reify',
    'RemoveDuplicates',
    'Reshuffle',
    'ToString',
    'Values',
    'WithKeys',
    'GroupIntoBatches'
]

K = TypeVar('K')
V = TypeVar('V')
T = TypeVar('T')


class CoGroupByKey(PTransform):
  """Groups results across several PCollections by key.

  Given an input dict of serializable keys (called "tags") to 0 or more
  PCollections of (key, value) tuples, it creates a single output PCollection
  of (key, value) tuples whose keys are the unique input keys from all inputs,
  and whose values are dicts mapping each tag to an iterable of whatever
  values were under the key in the corresponding PCollection, in this manner::

      ('some key', {'tag1': ['value 1 under "some key" in pcoll1',
                             'value 2 under "some key" in pcoll1',
                             ...],
                    'tag2': ... ,
                    ... })

  where `[]` refers to an iterable, not a list.

  For example, given::

      {'tag1': pc1, 'tag2': pc2, 333: pc3}

  where::

      pc1 = beam.Create([(k1, v1)])
      pc2 = beam.Create([])
      pc3 = beam.Create([(k1, v31), (k1, v32), (k2, v33)])

  the output PCollection would consist of items::

      [(k1, {'tag1': [v1], 'tag2': [], 333: [v31, v32]}),
       (k2, {'tag1': [], 'tag2': [], 333: [v33]})]

  where `[]` refers to an iterable, not a list.

  CoGroupByKey also works for tuples, lists, or other flat iterables of
  PCollections, in which case the values of the resulting PCollections
  will be tuples whose nth value is the iterable of values from the nth
  PCollection---conceptually, the "tags" are the indices into the input.
  Thus, for this input::

      (pc1, pc2, pc3)

  the output would be::

      [(k1, ([v1], [], [v31, v32])),
       (k2, ([], [], [v33]))]

  where, again, `[]` refers to an iterable, not a list.

  Attributes:
    **kwargs: Accepts a single named argument "pipeline", which specifies the
      pipeline that "owns" this PTransform. Ordinarily CoGroupByKey can obtain
      this information from one of the input PCollections, but if there are
      none (or if there's a chance there may be none), this argument is the
      only way to provide pipeline information, and should be considered
      mandatory.
  """
  def __init__(self, *, pipeline=None):
    self.pipeline = pipeline

  def _extract_input_pvalues(self, pvalueish):
    try:
      # If this works, it's a dict.
      return pvalueish, tuple(pvalueish.values())
    except AttributeError:
      # Cast iterables to a tuple so we can do re-iteration.
      pcolls = tuple(pvalueish)
      return pcolls, pcolls

  def expand(self, pcolls):
    if not pcolls:
      pcolls = (self.pipeline | Create([]), )
    if isinstance(pcolls, dict):
      tags = list(pcolls.keys())
      if all(isinstance(tag, str) and len(tag) < 10 for tag in tags):
        # Small, string tags.  Pass them as data.
        pcolls_dict = pcolls
        restore_tags = None
      else:
        # Pass the tags in the restore_tags closure.
        tags = list(pcolls.keys())
        pcolls_dict = {str(ix): pcolls[tag] for (ix, tag) in enumerate(tags)}
        restore_tags = lambda vs: {
            tag: vs[str(ix)]
            for (ix, tag) in enumerate(tags)
        }
    else:
      # Tags are tuple indices.
      tags = [str(ix) for ix in range(len(pcolls))]
      pcolls_dict = dict(zip(tags, pcolls))
      restore_tags = lambda vs: tuple(vs[tag] for tag in tags)

    input_key_types = []
    input_value_types = []
    for pcoll in pcolls_dict.values():
      key_type, value_type = typehints.trivial_inference.key_value_types(
          pcoll.element_type)
      input_key_types.append(key_type)
      input_value_types.append(value_type)
    output_key_type = typehints.Union[tuple(input_key_types)]
    iterable_input_value_types = tuple(
        typehints.Iterable[t] for t in input_value_types)

    output_value_type = typehints.Dict[
        str, typehints.Union[iterable_input_value_types or [typehints.Any]]]
    result = (
        pcolls_dict
        | 'CoGroupByKeyImpl' >>
        _CoGBKImpl(pipeline=self.pipeline).with_output_types(
            typehints.Tuple[output_key_type, output_value_type]))

    if restore_tags:
      if isinstance(pcolls, dict):
        dict_key_type = typehints.Union[tuple(
            trivial_inference.instance_to_type(tag) for tag in tags)]
        output_value_type = typehints.Dict[
            dict_key_type, typehints.Union[iterable_input_value_types]]
      else:
        output_value_type = typehints.Tuple[iterable_input_value_types]
      result |= 'RestoreTags' >> MapTuple(
          lambda k, vs: (k, restore_tags(vs))).with_output_types(
              typehints.Tuple[output_key_type, output_value_type])

    return result


class _CoGBKImpl(PTransform):
  def __init__(self, *, pipeline=None):
    self.pipeline = pipeline

  def expand(self, pcolls):
    # Check input PCollections for PCollection-ness, and that they all belong
    # to the same pipeline.
    for pcoll in pcolls.values():
      self._check_pcollection(pcoll)
      if self.pipeline:
        assert pcoll.pipeline == self.pipeline, (
            'All input PCollections must belong to the same pipeline.')

    tags = list(pcolls.keys())

    def add_tag(tag):
      return lambda k, v: (k, (tag, v))

    def collect_values(key, tagged_values):
      grouped_values = {tag: [] for tag in tags}
      for tag, value in tagged_values:
        grouped_values[tag].append(value)
      return key, grouped_values

    return ([
        pcoll
        | 'Tag[%s]' % tag >> MapTuple(add_tag(tag))
        for (tag, pcoll) in pcolls.items()
    ]
            | Flatten(pipeline=self.pipeline)
            | GroupByKey()
            | MapTuple(collect_values))


@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(K)
def Keys(pcoll, label='Keys'):  # pylint: disable=invalid-name
  """Produces a PCollection of first elements of 2-tuples in a PCollection."""
  return pcoll | label >> MapTuple(lambda k, _: k)


@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(V)
def Values(pcoll, label='Values'):  # pylint: disable=invalid-name
  """Produces a PCollection of second elements of 2-tuples in a PCollection."""
  return pcoll | label >> MapTuple(lambda _, v: v)


@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[V, K])
def KvSwap(pcoll, label='KvSwap'):  # pylint: disable=invalid-name
  """Produces a PCollection reversing 2-tuples in a PCollection."""
  return pcoll | label >> MapTuple(lambda k, v: (v, k))


@ptransform_fn
@typehints.with_input_types(T)
@typehints.with_output_types(T)
def Distinct(pcoll):  # pylint: disable=invalid-name
  """Produces a PCollection containing distinct elements of a PCollection."""
  return (
      pcoll
      | 'ToPairs' >> Map(lambda v: (v, None))
      | 'Group' >> CombinePerKey(lambda vs: None)
      | 'Distinct' >> Keys())


@deprecated(since='2.12', current='Distinct')
@ptransform_fn
@typehints.with_input_types(T)
@typehints.with_output_types(T)
def RemoveDuplicates(pcoll):
  """Produces a PCollection containing distinct elements of a PCollection."""
  return pcoll | 'RemoveDuplicates' >> Distinct()


class _BatchSizeEstimator(object):
  """Estimates the best size for batches given historical timing.
  """

  _MAX_DATA_POINTS = 100
  _MAX_GROWTH_FACTOR = 2

  def __init__(
      self,
      min_batch_size=1,
      max_batch_size=10000,
      target_batch_overhead=.05,
      target_batch_duration_secs=10,
      target_batch_duration_secs_including_fixed_cost=None,
      variance=0.25,
      clock=time.time,
      ignore_first_n_seen_per_batch_size=0,
      record_metrics=True):
    if min_batch_size > max_batch_size:
      raise ValueError(
          "Minimum (%s) must not be greater than maximum (%s)" %
          (min_batch_size, max_batch_size))
    if target_batch_overhead and not 0 < target_batch_overhead <= 1:
      raise ValueError(
          "target_batch_overhead (%s) must be between 0 and 1" %
          (target_batch_overhead))
    if target_batch_duration_secs and target_batch_duration_secs <= 0:
      raise ValueError(
          "target_batch_duration_secs (%s) must be positive" %
          (target_batch_duration_secs))
    if (target_batch_duration_secs_including_fixed_cost and
        target_batch_duration_secs_including_fixed_cost <= 0):
      raise ValueError(
          "target_batch_duration_secs_including_fixed_cost "
          "(%s) must be positive" %
          (target_batch_duration_secs_including_fixed_cost))
    if not (target_batch_overhead or target_batch_duration_secs or
            target_batch_duration_secs_including_fixed_cost):
      raise ValueError(
          "At least one of target_batch_overhead or "
          "target_batch_duration_secs or "
          "target_batch_duration_secs_including_fixed_cost must be positive.")
    if ignore_first_n_seen_per_batch_size < 0:
      raise ValueError(
          'ignore_first_n_seen_per_batch_size (%s) must be non '
          'negative' % (ignore_first_n_seen_per_batch_size))
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._target_batch_overhead = target_batch_overhead
    self._target_batch_duration_secs = target_batch_duration_secs
    self._target_batch_duration_secs_including_fixed_cost = (
        target_batch_duration_secs_including_fixed_cost)
    self._variance = variance
    self._clock = clock
    self._data = []
    self._ignore_next_timing = False
    self._ignore_first_n_seen_per_batch_size = (
        ignore_first_n_seen_per_batch_size)
    self._batch_size_num_seen = {}
    self._replay_last_batch_size = None
    self._record_metrics = record_metrics
    self._element_count = 0
    self._batch_count = 0

    if record_metrics:
      self._size_distribution = Metrics.distribution(
          'BatchElements', 'batch_size')
      self._time_distribution = Metrics.distribution(
          'BatchElements', 'msec_per_batch')
    else:
      self._size_distribution = self._time_distribution = None
    # Beam distributions only accept integer values, so we use this to
    # accumulate under-reported values until they add up to whole milliseconds.
    # (Milliseconds are chosen because that's conventionally used elsewhere in
    # profiling-style counters.)
    self._remainder_msecs = 0

  def ignore_next_timing(self):
    """Call to indicate the next timing should be ignored.

    For example, the first emit of a ParDo operation is known to be anomalous
    due to setup that may occur.
    """
    self._ignore_next_timing = True

  @contextlib.contextmanager
  def record_time(self, batch_size):
    start = self._clock()
    yield
    elapsed = self._clock() - start
    elapsed_msec = 1e3 * elapsed + self._remainder_msecs
    if self._record_metrics:
      self._size_distribution.update(batch_size)
      self._time_distribution.update(int(elapsed_msec))
    self._element_count += batch_size
    self._batch_count += 1
    self._remainder_msecs = elapsed_msec - int(elapsed_msec)
    # If we ignore the next timing, replay the batch size to get accurate
    # timing.
    if self._ignore_next_timing:
      self._ignore_next_timing = False
      self._replay_last_batch_size = min(batch_size, self._max_batch_size)
    else:
      self._data.append((batch_size, elapsed))
      if len(self._data) >= self._MAX_DATA_POINTS:
        self._thin_data()

  def _thin_data(self):
    # Make sure we don't change the parity of len(self._data),
    # as it's used below to alternate jitter.
    self._data.pop(random.randrange(len(self._data) // 4))
    self._data.pop(random.randrange(len(self._data) // 2))

  @staticmethod
  def linear_regression_no_numpy(xs, ys):
    # Least squares fit for y = a + bx over all points.
    n = float(len(xs))
    xbar = sum(xs) / n
    ybar = sum(ys) / n
    if xbar == 0:
      return ybar, 0
    if all(xs[0] == x for x in xs):
      # Simply use the mean if all values in xs are the same.
      return 0, ybar / xbar
    b = (
        sum([(x - xbar) * (y - ybar)
             for x, y in zip(xs, ys)]) / sum([(x - xbar)**2 for x in xs]))
    a = ybar - b * xbar
    return a, b

  @staticmethod
  def linear_regression_numpy(xs, ys):
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    from numpy import sum
    n = len(xs)
    if all(xs[0] == x for x in xs):
      # If all values in xs are the same, fall back to
      # linear_regression_no_numpy.
      return _BatchSizeEstimator.linear_regression_no_numpy(xs, ys)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # First do a simple least squares fit for y = a + bx over all points.
    b, a = np.polyfit(xs, ys, 1)

    if n < 10:
      return a, b
    else:
      # Refine this by throwing out outliers, according to Cook's distance.
      # https://en.wikipedia.org/wiki/Cook%27s_distance
      sum_x = sum(xs)
      sum_x2 = sum(xs**2)
      errs = a + b * xs - ys
      s2 = sum(errs**2) / (n - 2)
      if s2 == 0:
        # It's an exact fit!
        return a, b
      h = (sum_x2 - 2 * sum_x * xs + n * xs**2) / (n * sum_x2 - sum_x**2)
      cook_ds = 0.5 / s2 * errs**2 * (h / (1 - h)**2)

      # Re-compute the regression, excluding those points with Cook's distance
      # greater than 0.5, and weighting by the inverse of x to give a more
      # stable y-intercept (as small batches have relatively more information
      # about the fixed overhead).
      weight = (cook_ds <= 0.5) / xs
      b, a = np.polyfit(xs, ys, 1, w=weight)
      return a, b

  try:
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    linear_regression = linear_regression_numpy
  except ImportError:
    linear_regression = linear_regression_no_numpy

  def _calculate_next_batch_size(self):
    if self._min_batch_size == self._max_batch_size:
      return self._min_batch_size
    elif len(self._data) < 1:
      return self._min_batch_size
    elif len(self._data) < 2:
      # Force some variety so we have distinct batch sizes on which to do
      # linear regression below.
      return int(
          max(
              min(
                  self._max_batch_size,
                  self._min_batch_size * self._MAX_GROWTH_FACTOR),
              self._min_batch_size + 1))

    # There tends to be a lot of noise in the top quantile, which also
    # has outsized influence in the regression.  If we have enough data,
    # simply declare the top 20% to be outliers.
    trimmed_data = sorted(self._data)[:max(20, len(self._data) * 4 // 5)]

    # Linear regression for y = a + bx, where x is batch size and y is time.
    xs, ys = zip(*trimmed_data)
    a, b = self.linear_regression(xs, ys)

    # Avoid nonsensical or division-by-zero errors below due to noise.
    a = max(a, 1e-10)
    b = max(b, 1e-20)

    last_batch_size = self._data[-1][0]
    cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)

    target = self._max_batch_size

    if self._target_batch_duration_secs_including_fixed_cost:
      # Solution to
      # a + b*x = self._target_batch_duration_secs_including_fixed_cost.
      target = min(
          target,
          (self._target_batch_duration_secs_including_fixed_cost - a) / b)

    if self._target_batch_duration_secs:
      # Solution to b*x = self._target_batch_duration_secs.
      # We ignore the fixed cost in this computation as it has negligible
      # impact when it is small and unhelpfully forces the minimum batch size
      # when it is large.
      target = min(target, self._target_batch_duration_secs / b)

    if self._target_batch_overhead:
      # Solution to a / (a + b*x) = self._target_batch_overhead.
      target = min(target, (a / b) * (1 / self._target_batch_overhead - 1))

    # Avoid getting stuck at a single batch size (especially the minimal
    # batch size) which would not allow us to extrapolate to other batch
    # sizes.
    # Jitter alternates between 0 and 1.
    jitter = len(self._data) % 2
    # Smear our samples across a range centered at the target.
    if len(self._data) > 10:
      target += int(target * self._variance * 2 * (random.random() - .5))

    return int(max(self._min_batch_size + jitter, min(target, cap)))

  def next_batch_size(self):
    # Check if we should replay a previous batch size due to it not being
    # recorded.
    if self._replay_last_batch_size:
      result = self._replay_last_batch_size
      self._replay_last_batch_size = None
    else:
      result = self._calculate_next_batch_size()

    seen_count = self._batch_size_num_seen.get(result, 0) + 1
    if seen_count <= self._ignore_first_n_seen_per_batch_size:
      self.ignore_next_timing()
    self._batch_size_num_seen[result] = seen_count
    return result

  def stats(self):
    return "element_count=%s batch_count=%s next_batch_size=%s timings=%s" % (
        self._element_count,
        self._batch_count,
        self._calculate_next_batch_size(),
        self._data)


class _GlobalWindowsBatchingDoFn(DoFn):
  def __init__(self, batch_size_estimator, element_size_fn):
    self._batch_size_estimator = batch_size_estimator
    self._element_size_fn = element_size_fn

  def start_bundle(self):
    self._batch = []
    self._running_batch_size = 0
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element):
    self._batch.append(element)
    self._running_batch_size += self._element_size_fn(element)
    if self._running_batch_size >= self._target_batch_size:
      with self._batch_size_estimator.record_time(self._running_batch_size):
        yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
      self._batch = []
      self._running_batch_size = 0
      self._target_batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    if self._batch:
      with self._batch_size_estimator.record_time(self._running_batch_size):
        yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
      self._batch = None
      self._running_batch_size = 0
      self._target_batch_size = self._batch_size_estimator.next_batch_size()
    logging.info(
        "BatchElements statistics: " + self._batch_size_estimator.stats())


class _SizedBatch():
  def __init__(self):
    self.elements = []
    self.size = 0


class _WindowAwareBatchingDoFn(DoFn):

  _MAX_LIVE_WINDOWS = 10

  def __init__(self, batch_size_estimator, element_size_fn):
    self._batch_size_estimator = batch_size_estimator
    self._element_size_fn = element_size_fn

  def start_bundle(self):
    self._batches = collections.defaultdict(_SizedBatch)
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element, window=DoFn.WindowParam):
    batch = self._batches[window]
    batch.elements.append(element)
    batch.size += self._element_size_fn(element)
    if batch.size >= self._target_batch_size:
      with self._batch_size_estimator.record_time(batch.size):
        yield windowed_value.WindowedValue(
            batch.elements, window.max_timestamp(), (window, ))
      del self._batches[window]
      self._target_batch_size = self._batch_size_estimator.next_batch_size()
    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
      window, batch = max(
          self._batches.items(),
          key=lambda window_batch: window_batch[1].size)
      with self._batch_size_estimator.record_time(batch.size):
        yield windowed_value.WindowedValue(
            batch.elements, window.max_timestamp(), (window, ))
      del self._batches[window]
      self._target_batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    for window, batch in self._batches.items():
      if batch:
        with self._batch_size_estimator.record_time(batch.size):
          yield windowed_value.WindowedValue(
              batch.elements, window.max_timestamp(), (window, ))
    self._batches = None
    self._target_batch_size = self._batch_size_estimator.next_batch_size()


@typehints.with_input_types(T)
@typehints.with_output_types(List[T])
class BatchElements(PTransform):
  """A Transform that batches elements for amortized processing.

  This transform is designed to precede operations whose processing cost
  is of the form

      time = fixed_cost + num_elements * per_element_cost

  where the per element cost is (often significantly) smaller than the fixed
  cost and could be amortized over multiple elements.  It consumes a
  PCollection of element type T and produces a PCollection of element type
  List[T].

  This transform attempts to find the best batch size between the minimum
  and maximum parameters by profiling the time taken by (fused) downstream
  operations.  For a fixed batch size, set the min and max to be equal.

  Elements are batched per-window and batches are emitted in the window
  corresponding to their contents.

  Args:
    min_batch_size: (optional) the smallest size of a batch
    max_batch_size: (optional) the largest size of a batch
    target_batch_overhead: (optional) a target for fixed_cost / time,
        as used in the formula above
    target_batch_duration_secs: (optional) a target for total time per bundle,
        in seconds, excluding fixed cost
    target_batch_duration_secs_including_fixed_cost: (optional) a target for
        total time per bundle, in seconds, including fixed cost
    element_size_fn: (optional) A mapping of an element to its contribution to
        batch size, defaulting to every element having size 1.  When provided,
        attempts to provide batches of optimal total size, which may consist
        of a varying number of elements.
    variance: (optional) the permitted (relative) amount of deviation from the
        (estimated) ideal batch size used to produce a wider base for
        linear interpolation
    clock: (optional) an alternative to time.time for measuring the cost of
        downstream operations (mostly for testing)
    record_metrics: (optional) whether or not to record beam metrics on
        distributions of the batch size.  Defaults to True.
  """
  def __init__(
      self,
      min_batch_size=1,
      max_batch_size=10000,
      target_batch_overhead=.05,
      target_batch_duration_secs=10,
      target_batch_duration_secs_including_fixed_cost=None,
      *,
      element_size_fn=lambda x: 1,
      variance=0.25,
      clock=time.time,
      record_metrics=True):
    self._batch_size_estimator = _BatchSizeEstimator(
        min_batch_size=min_batch_size,
        max_batch_size=max_batch_size,
        target_batch_overhead=target_batch_overhead,
        target_batch_duration_secs=target_batch_duration_secs,
        target_batch_duration_secs_including_fixed_cost=(
            target_batch_duration_secs_including_fixed_cost),
        variance=variance,
        clock=clock,
        record_metrics=record_metrics)
    self._element_size_fn = element_size_fn

  def expand(self, pcoll):
    if getattr(pcoll.pipeline.runner, 'is_streaming', False):
      raise NotImplementedError("Requires stateful processing (BEAM-2687)")
    elif pcoll.windowing.is_default():
      # This is the same logic as _GlobalWindowsBatchingDoFn, but optimized
      # for that simpler case.
      return pcoll | ParDo(
          _GlobalWindowsBatchingDoFn(
              self._batch_size_estimator, self._element_size_fn))
    else:
      return pcoll | ParDo(
          _WindowAwareBatchingDoFn(
              self._batch_size_estimator, self._element_size_fn))


class _IdentityWindowFn(NonMergingWindowFn):
  """Windowing function that preserves existing windows.

  To be used internally with the Reshuffle transform.
  Will raise an exception when used after DoFns that return TimestampedValue
  elements.
  """
  def __init__(self, window_coder):
    """Create a new WindowFn with compatible coder.
    To be applied to PCollections with windows that are compatible with the
    given coder.

    Arguments:
      window_coder: coders.Coder object to be used on windows.
    """
    super().__init__()
    if window_coder is None:
      raise ValueError('window_coder should not be None')
    self._window_coder = window_coder

  def assign(self, assign_context):
    if assign_context.window is None:
      raise ValueError(
          'assign_context.window should not be None. '
          'This might be due to a DoFn returning a TimestampedValue.')
    return [assign_context.window]

  def get_window_coder(self):
    return self._window_coder


@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, V])
class ReshufflePerKey(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular checkpointing, and preventing fusion of the surrounding
  transforms.
  """
  def expand(self, pcoll):
    windowing_saved = pcoll.windowing
    if windowing_saved.is_default():
      # In this (common) case we can use a trivial trigger driver
      # and avoid the (expensive) window param.
      globally_windowed = window.GlobalWindows.windowed_value(None)
      MIN_TIMESTAMP = window.MIN_TIMESTAMP

      def reify_timestamps(element, timestamp=DoFn.TimestampParam):
        key, value = element
        if timestamp == MIN_TIMESTAMP:
          timestamp = None
        return key, (value, timestamp)

      def restore_timestamps(element):
        key, values = element
        return [
            globally_windowed.with_value((key, value)) if timestamp is None else
            window.GlobalWindows.windowed_value((key, value), timestamp)
            for (value, timestamp) in values
        ]
    else:

      # typing: All conditional function variants must have identical signatures
      def reify_timestamps(  # type: ignore[misc]
          element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
        key, value = element
        # Transport the window as part of the value and restore it later.
        return key, windowed_value.WindowedValue(value, timestamp, [window])

      def restore_timestamps(element):
        key, windowed_values = element
        return [wv.with_value((key, wv.value)) for wv in windowed_values]

    ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)

    # TODO(https://github.com/apache/beam/issues/19785) Using global window as
    # one of the standard windows.  This is to mitigate the Dataflow Java
    # Runner Harness limitation of accepting only standard coders.
    ungrouped._windowing = Windowing(
        window.GlobalWindows(),
        triggerfn=Always(),
        accumulation_mode=AccumulationMode.DISCARDING,
        timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
    result = (
        ungrouped
        | GroupByKey()
        | FlatMap(restore_timestamps).with_output_types(Any))
    result._windowing = windowing_saved
    return result


@typehints.with_input_types(T)
@typehints.with_output_types(T)
class Reshuffle(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular checkpointing, and preventing fusion of the surrounding
  transforms.

  Reshuffle adds a temporary random key to each element, performs a
  ReshufflePerKey, and finally removes the temporary key.
  """

  # We use a 32-bit integer as the default number of buckets.
  _DEFAULT_NUM_BUCKETS = 1 << 32

  def __init__(self, num_buckets=None):
    """
    :param num_buckets: If set, specifies the maximum number of random keys
      that will be generated.
    """
    self.num_buckets = num_buckets if num_buckets else self._DEFAULT_NUM_BUCKETS

    valid_buckets = isinstance(num_buckets, int) and num_buckets > 0
    if not (num_buckets is None or valid_buckets):
      raise ValueError(
          'If `num_buckets` is set, it has to be an '
          'integer greater than 0, got %s' % num_buckets)

  def expand(self, pcoll):
    # type: (pvalue.PValue) -> pvalue.PCollection
    return (
        pcoll | 'AddRandomKeys' >>
        Map(lambda t: (random.randrange(0, self.num_buckets), t)
            ).with_input_types(T).with_output_types(Tuple[int, T])
        | ReshufflePerKey()
        | 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
            Tuple[int, T]).with_output_types(T))

  def to_runner_api_parameter(self, unused_context):
    # type: (PipelineContext) -> Tuple[str, None]
    return common_urns.composites.RESHUFFLE.urn, None

  @staticmethod
  @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return Reshuffle()


def fn_takes_side_inputs(fn):
  fn = getattr(fn, '_argspec_fn', fn)
  try:
    signature = get_signature(fn)
  except TypeError:
    # We can't tell; maybe it does.
    return True

  return (
      len(signature.parameters) > 1 or any(
          p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD
          for p in signature.parameters.values()))


@ptransform_fn
def WithKeys(pcoll, k, *args, **kwargs):
  """PTransform that takes a PCollection, and either a constant key or a
  callable, and returns a PCollection of (K, V), where each of the values in
  the input PCollection has been paired with either the constant key or a key
  computed from the value.  The callable may optionally accept positional or
  keyword arguments, which should be passed to WithKeys directly.  These may
  be either SideInputs or static (non-PCollection) values, such as ints.
  """
  if callable(k):
    if fn_takes_side_inputs(k):
      if all(isinstance(arg, AsSideInput)
             for arg in args) and all(isinstance(kwarg, AsSideInput)
                                      for kwarg in kwargs.values()):
        return pcoll | Map(
            lambda v,
            *args,
            **kwargs: (k(v, *args, **kwargs), v),
            *args,
            **kwargs)
      return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
    return pcoll | Map(lambda v: (k(v), v))
  return pcoll | Map(lambda v: (k, v))


@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, Iterable[V]])
class GroupIntoBatches(PTransform):
  """PTransform that batches the input into batches of the desired size.
  Elements are buffered until batch_size elements have been seen for a key,
  at which point they are emitted to the output PCollection.

  Windows are preserved (batches will contain elements from the same window).
  """
  def __init__(
      self, batch_size, max_buffering_duration_secs=None, clock=time.time):
    """Create a new GroupIntoBatches.

    Arguments:
      batch_size: (required) How many elements should be in a batch
      max_buffering_duration_secs: (optional) How long, in seconds, at most an
        incomplete batch of elements is allowed to be buffered in the states.
        The duration must be a positive second duration and should be given as
        an int or float.  Setting this parameter to zero effectively means no
        buffering limit.
      clock: (optional) an alternative to time.time (mostly for testing)
    """
    self.params = _GroupIntoBatchesParams(
        batch_size, max_buffering_duration_secs)
    self.clock = clock

  def expand(self, pcoll):
    input_coder = coders.registry.get_coder(pcoll)
    return pcoll | ParDo(
        _pardo_group_into_batches(
            input_coder,
            self.params.batch_size,
            self.params.max_buffering_duration_secs,
            self.clock))

  def to_runner_api_parameter(
      self,
      unused_context  # type: PipelineContext
  ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
    return (
        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
        self.params.get_payload())

  @staticmethod
  @PTransform.register_urn(
      common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
      beam_runner_api_pb2.GroupIntoBatchesPayload)
  def from_runner_api_parameter(unused_ptransform, proto, unused_context):
    return GroupIntoBatches(*_GroupIntoBatchesParams.parse_payload(proto))

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(
      typehints.Tuple[
          ShardedKeyType[typehints.TypeVariable(K)],  # type: ignore[misc]
          typehints.Iterable[typehints.TypeVariable(V)]])
  class WithShardedKey(PTransform):
    """A GroupIntoBatches transform that outputs batched elements associated
    with sharded input keys.

    By default, keys are sharded so that input elements with the same key are
    spread across all available threads executing the transform.  Runners may
    override the default sharding to do better load balancing during
    execution.
    """
    def __init__(
        self, batch_size, max_buffering_duration_secs=None, clock=time.time):
      """Create a new GroupIntoBatches with sharded output.
      See ``GroupIntoBatches`` transform for a description of input parameters.
      """
      self.params = _GroupIntoBatchesParams(
          batch_size, max_buffering_duration_secs)
      self.clock = clock

    _shard_id_prefix = uuid.uuid4().bytes

    def expand(self, pcoll):
      key_type, value_type = pcoll.element_type.tuple_types
      sharded_pcoll = pcoll | Map(
          lambda key_value: (
              ShardedKey(
                  key_value[0],
                  # Use [uuid, thread id] as the shard id.
                  GroupIntoBatches.WithShardedKey._shard_id_prefix + bytes(
                      threading.get_ident().to_bytes(8, 'big'))),
              key_value[1])).with_output_types(
                  typehints.Tuple[
                      ShardedKeyType[key_type],  # type: ignore[misc]
                      value_type])
      return (
          sharded_pcoll
          | GroupIntoBatches(
              self.params.batch_size,
              self.params.max_buffering_duration_secs,
              self.clock))

    def to_runner_api_parameter(
        self,
        unused_context  # type: PipelineContext
    ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
      return (
          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
          self.params.get_payload())

    @staticmethod
    @PTransform.register_urn(
        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
        beam_runner_api_pb2.GroupIntoBatchesPayload)
    def from_runner_api_parameter(unused_ptransform, proto, unused_context):
      return GroupIntoBatches.WithShardedKey(
          *_GroupIntoBatchesParams.parse_payload(proto))


class _GroupIntoBatchesParams:
  """This class represents the parameters for the
  :class:`apache_beam.utils.GroupIntoBatches` transform, used to define how
  elements should be batched.
  """
  def __init__(self, batch_size, max_buffering_duration_secs):
    self.batch_size = batch_size
    self.max_buffering_duration_secs = (
        0
        if max_buffering_duration_secs is None else max_buffering_duration_secs)
    self._validate()

  def __eq__(self, other):
    if other is None or not isinstance(other, _GroupIntoBatchesParams):
      return False
    return (
        self.batch_size == other.batch_size and
        self.max_buffering_duration_secs == other.max_buffering_duration_secs)

  def _validate(self):
    assert self.batch_size is not None and self.batch_size > 0, (
        'batch_size must be a positive value')
    assert (
        self.max_buffering_duration_secs is not None and
        self.max_buffering_duration_secs >= 0), (
            'max_buffering_duration must be a non-negative value')

  def get_payload(self):
    return beam_runner_api_pb2.GroupIntoBatchesPayload(
        batch_size=self.batch_size,
        max_buffering_duration_millis=int(
            self.max_buffering_duration_secs * 1000))

  @staticmethod
  def parse_payload(
      proto  # type: beam_runner_api_pb2.GroupIntoBatchesPayload
  ):
    return proto.batch_size, proto.max_buffering_duration_millis / 1000


def _pardo_group_into_batches(
    input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
  ELEMENT_STATE = BagStateSpec('values', input_coder)
  COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
  WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
  BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)

  class _GroupIntoBatchesDoFn(DoFn):
    def process(
        self,
        element,
        window=DoFn.WindowParam,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        window_timer=DoFn.TimerParam(WINDOW_TIMER),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      window_timer.set(window.end)
      element_state.add(element)
      count_state.add(1)
      count = count_state.read()
      if count == 1 and max_buffering_duration_secs > 0:
        # This is the first element in the batch.  Start counting buffering
        # time if a limit was set.
        # pylint: disable=deprecated-method
        buffering_timer.set(clock() + max_buffering_duration_secs)
      if count >= batch_size:
        return self.flush_batch(element_state, count_state, buffering_timer)

    @on_timer(WINDOW_TIMER)
    def on_window_timer(
        self,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      return self.flush_batch(element_state, count_state, buffering_timer)

    @on_timer(BUFFERING_TIMER)
    def on_buffering_timer(
        self,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      return self.flush_batch(element_state, count_state, buffering_timer)

    def flush_batch(self, element_state, count_state, buffering_timer):
      batch = [element for element in element_state.read()]
      if not batch:
        return
      key, _ = batch[0]
      batch_values = [v for (k, v) in batch]
      element_state.clear()
      count_state.clear()
      buffering_timer.clear()
      yield key, batch_values

  return _GroupIntoBatchesDoFn()


class ToString(object):
  """
  PTransforms for converting a PCollection element, KV, or iterable
  to a string.
  """

  # pylint: disable=invalid-name
  @staticmethod
  def Element():
    """
    Transforms each element of the PCollection to a string.
    """
    return 'ElementToString' >> Map(str)

  @staticmethod
  def Iterables(delimiter=None):
    """
    Transforms each iterable element of the input PCollection to a string,
    joining its items with the given delimiter (default ',').  There is no
    trailing delimiter.
    """
    if delimiter is None:
      delimiter = ','
    return (
        'IterablesToString' >>
        Map(lambda xs: delimiter.join(str(x) for x in xs)).with_input_types(
            Iterable[Any]).with_output_types(str))

  # An alias for Iterables.
  Kvs = Iterables


@typehints.with_input_types(T)
@typehints.with_output_types(T)
class LogElements(PTransform):
  """
  PTransform for printing the elements of a PCollection.
  """
  class _LoggingFn(DoFn):
    def __init__(self, prefix='', with_timestamp=False, with_window=False):
      super().__init__()
      self.prefix = prefix
      self.with_timestamp = with_timestamp
      self.with_window = with_window

    def process(
        self,
        element,
        timestamp=DoFn.TimestampParam,
        window=DoFn.WindowParam,
        **kwargs):
      log_line = self.prefix + str(element)

      if self.with_timestamp:
        log_line += ', timestamp=' + repr(timestamp.to_rfc3339())

      if self.with_window:
        log_line += ', window(start=' + window.start.to_rfc3339()
        log_line += ', end=' + window.end.to_rfc3339() + ')'

      print(log_line)
      yield element

  def __init__(
      self, label=None, prefix='', with_timestamp=False, with_window=False):
    super().__init__(label)
    self.prefix = prefix
    self.with_timestamp = with_timestamp
    self.with_window = with_window

  def expand(self, input):
    return input | ParDo(
        self._LoggingFn(self.prefix, self.with_timestamp, self.with_window))


class Reify(object):
  """PTransforms for converting between explicit and implicit forms of various
  Beam values."""
  @typehints.with_input_types(T)
  @typehints.with_output_types(T)
  class Timestamp(PTransform):
    """PTransform to wrap a value in a TimestampedValue with its
    associated timestamp."""
    @staticmethod
    def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
      yield TimestampedValue(element, timestamp)

    def expand(self, pcoll):
      return pcoll | ParDo(self.add_timestamp_info)

  @typehints.with_input_types(T)
  @typehints.with_output_types(T)
  class Window(PTransform):
    """PTransform to convert an element in a PCollection into a tuple of
    (element, timestamp, window), wrapped in a TimestampedValue with its
    associated timestamp."""
    @staticmethod
    def add_window_info(
        element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
      yield TimestampedValue((element, timestamp, window), timestamp)

    def expand(self, pcoll):
      return pcoll | ParDo(self.add_window_info)

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(Tuple[K, V])
  class TimestampInValue(PTransform):
    """PTransform to wrap the Value in a KV pair in a TimestampedValue with
    the element's associated timestamp."""
    @staticmethod
    def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
      key, value = element
      yield (key, TimestampedValue(value, timestamp))

    def expand(self, pcoll):
      return pcoll | ParDo(self.add_timestamp_info)

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(Tuple[K, V])
  class WindowInValue(PTransform):
    """PTransform to convert the Value in a KV pair into a tuple of
    (value, timestamp, window), with the whole element being wrapped inside a
    TimestampedValue."""
    @staticmethod
    def add_window_info(
        element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
      key, value = element
      yield TimestampedValue((key, (value, timestamp, window)), timestamp)

    def expand(self, pcoll):
      return pcoll | ParDo(self.add_window_info)


class Regex(object):
  """
  PTransforms that use regular expressions to process the elements in a
  PCollection.
  """

  ALL = "__regex_all_groups"

  @staticmethod
  def _regex_compile(regex):
    """Return re.compile if the regex has a string value"""
    if isinstance(regex, str):
      regex = re.compile(regex)
    return regex

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def matches(pcoll, regex, group=0):
    """
    Returns the matches (group 0 by default) if zero or more characters at the
    beginning of the string match the regular expression.  To match the entire
    string, add a "$" sign at the end of the regex expression.

    Group can be an integer value or a string value.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name/number of the group; it can be an integer or a
        string value.  Defaults to 0, meaning the entire matched string will
        be returned.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      m = regex.match(element)
      if m:
        yield m.group(group)

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(List[str])
  @ptransform_fn
  def all_matches(pcoll, regex):
    """
    Returns all matches (groups) if zero or more characters at the beginning
    of the string match the regular expression.

    Args:
      regex: the regular expression string or (re.compile) pattern.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      m = regex.match(element)
      if m:
        yield [m.group(ix) for ix in range(m.lastindex + 1)]

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Tuple[str, str])
  @ptransform_fn
  def matches_kv(pcoll, regex, keyGroup, valueGroup=0):
    """
    Returns KV pairs if the string matches the regular expression, deriving
    the key and value from the specified groups of the regular expression.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      keyGroup: The Regex group to use as the key.  Can be int or str.
      valueGroup: (optional) Regex group to use as the value.  Can be int or
        str.  The default value "0" returns the entire matched string.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      match = regex.match(element)
      if match:
        yield (match.group(keyGroup), match.group(valueGroup))

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def find(pcoll, regex, group=0):
    """
    Returns the matches if a portion of the line matches the Regex.  Returns
    the entire group (group 0 by default).  Group can be an integer value or
    a string value.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name of the group; it can be an integer or a string
        value.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      r = regex.search(element)
      if r:
        yield r.group(group)

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Union[List[str], List[Tuple[str, str]]])
  @ptransform_fn
  def find_all(pcoll, regex, group=0, outputEmpty=True):
    """
    Returns the matches if a portion of the line matches the Regex.  By
    default, a list of group 0 matches is returned, including empty items.
    To get all groups, pass the `Regex.ALL` flag in the `group` parameter,
    which returns all the groups in tuple format.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name of the group; it can be an integer or a string
        value.
      outputEmpty: (optional) Should empty matches be output.  True to output
        empties and False if not.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      matches = regex.finditer(element)
      if group == Regex.ALL:
        yield [(m.group(), m.groups()[0]) for m in matches
               if outputEmpty or m.groups()[0]]
      else:
        yield [m.group(group) for m in matches if outputEmpty or m.group(group)]

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Tuple[str, str])
  @ptransform_fn
  def find_kv(pcoll, regex, keyGroup, valueGroup=0):
    """
    Returns the matches if a portion of the line matches the Regex.  Returns
    the specified groups as the key and value pair.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      keyGroup: The Regex group to use as the key.  Can be int or str.
      valueGroup: (optional) Regex group to use as the value.  Can be int or
        str.  The default value "0" returns the entire matched string.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      matches = regex.finditer(element)
      if matches:
        for match in matches:
          yield (match.group(keyGroup), match.group(valueGroup))

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def replace_all(pcoll, regex, replacement):
    """
    Returns the element if a portion of the line matches the regex,
    replacing all matches with the replacement string.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      replacement: the string to be substituted for each match.
    """
    regex = Regex._regex_compile(regex)
    return pcoll | Map(lambda elem: regex.sub(replacement, elem))

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def replace_first(pcoll, regex, replacement):
    """
    Returns the element if a portion of the line matches the regex,
    replacing the first match with the replacement string.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      replacement: the string to be substituted for each match.
    """
    regex = Regex._regex_compile(regex)
    return pcoll | Map(lambda elem: regex.sub(replacement, elem, 1))

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(List[str])
  @ptransform_fn
  def split(pcoll, regex, outputEmpty=False):
    """
    Returns the list of strings produced by splitting each element on the
    regular expression.  Empty items are not output by default.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      outputEmpty: (optional) Should empty items be output.  True to output
        empties and False if not.
    """
    regex = Regex._regex_compile(regex)
    outputEmpty = bool(outputEmpty)

    def _process(element):
      r = regex.split(element)
      if r and not outputEmpty:
        r = list(filter(None, r))
      yield r

    return pcoll | FlatMap(_process)
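
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of
# how several of the utility transforms defined above are typically wired
# into a pipeline via the public `apache_beam` API.  The transform labels and
# sample data below are hypothetical and chosen only for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
  import apache_beam as beam

  with beam.Pipeline() as p:
    emails = p | 'CreateEmails' >> beam.Create([('amy', 'amy@example.com'),
                                                ('carl', 'carl@example.com')])
    phones = p | 'CreatePhones' >> beam.Create([('amy', '111-222-3333'),
                                                ('carl', '444-555-6666'),
                                                ('carl', '777-888-9999')])

    # CoGroupByKey joins both PCollections on their keys; each output value
    # is a dict mapping the tag ('emails'/'phones') to an iterable of values.
    joined = ({'emails': emails, 'phones': phones}
              | 'Join' >> beam.CoGroupByKey())

    # Regex.find_kv extracts (key, value) pairs from raw strings, and
    # Distinct removes the duplicate pair produced by the second line.
    pairs = (p
             | 'CreateLines' >> beam.Create(['a=1, b=2', 'a=1'])
             | 'Parse' >> beam.Regex.find_kv(r'(\w+)=(\d+)', 1, 2)
             | 'Dedup' >> beam.Distinct())

    # BatchElements groups elements into lists of an adaptively chosen size;
    # GroupIntoBatches does the same per key, with an explicit batch size.
    batched = pairs | 'Batch' >> beam.BatchElements(min_batch_size=2,
                                                    max_batch_size=100)
    per_key = pairs | 'BatchPerKey' >> beam.GroupIntoBatches(2)

    # LogElements prints each element as it flows through, which is handy
    # for inspecting small pipelines on the DirectRunner.
    _ = joined | 'LogJoined' >> beam.LogElements(prefix='joined: ')
    _ = batched | 'LogBatched' >> beam.LogElements(prefix='batched: ')
    _ = per_key | 'LogPerKey' >> beam.LogElements(prefix='per_key: ')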