github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/iobase.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Sources and sinks. 19 20 A Source manages record-oriented data input from a particular kind of source 21 (e.g. a set of files, a database table, etc.). The reader() method of a source 22 returns a reader object supporting the iterator protocol; iteration yields 23 raw records of unprocessed, serialized data. 24 25 26 A Sink manages record-oriented data output to a particular kind of sink 27 (e.g. a set of files, a database table, etc.). The writer() method of a sink 28 returns a writer object supporting writing records of serialized data to 29 the sink. 30 """ 31 32 # pytype: skip-file 33 34 import logging 35 import math 36 import random 37 import uuid 38 from collections import namedtuple 39 from typing import Any 40 from typing import Iterator 41 from typing import Optional 42 from typing import Tuple 43 from typing import Union 44 45 from apache_beam import coders 46 from apache_beam import pvalue 47 from apache_beam.coders.coders import _MemoizingPickleCoder 48 from apache_beam.internal import pickler 49 from apache_beam.portability import common_urns 50 from apache_beam.portability import python_urns 51 from apache_beam.portability.api import beam_runner_api_pb2 52 from apache_beam.pvalue import AsIter 53 from apache_beam.pvalue import AsSingleton 54 from apache_beam.transforms import Impulse 55 from apache_beam.transforms import PTransform 56 from apache_beam.transforms import core 57 from apache_beam.transforms import ptransform 58 from apache_beam.transforms import window 59 from apache_beam.transforms.display import DisplayDataItem 60 from apache_beam.transforms.display import HasDisplayData 61 from apache_beam.utils import timestamp 62 from apache_beam.utils import urns 63 from apache_beam.utils.windowed_value import WindowedValue 64 65 __all__ = [ 66 'BoundedSource', 67 'RangeTracker', 68 'Read', 69 'RestrictionProgress', 70 'RestrictionTracker', 71 'WatermarkEstimator', 72 'Sink', 73 'Write', 74 'Writer' 75 ] 76 77 _LOGGER = logging.getLogger(__name__) 78 79 # Encapsulates information about a bundle of a source generated when method 80 # BoundedSource.split() is invoked. 81 # This is a named 4-tuple that has following fields. 82 # * weight - a number that represents the size of the bundle. This value will 83 # be used to compare the relative sizes of bundles generated by the 84 # current source. 85 # The weight returned here could be specified using a unit of your 86 # choice (for example, bundles of sizes 100MB, 200MB, and 700MB may 87 # specify weights 100, 200, 700 or 1, 2, 7) but all bundles of a 88 # source should specify the weight using the same unit. 
89 # * source - a BoundedSource object for the bundle. 90 # * start_position - starting position of the bundle 91 # * stop_position - ending position of the bundle. 92 # 93 # Type for start and stop positions are specific to the bounded source and must 94 # be consistent throughout. 95 SourceBundle = namedtuple( 96 'SourceBundle', 'weight source start_position stop_position') 97 98 99 class SourceBase(HasDisplayData, urns.RunnerApiFn): 100 """Base class for all sources that can be passed to beam.io.Read(...). 101 """ 102 urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE) 103 104 def is_bounded(self): 105 # type: () -> bool 106 raise NotImplementedError 107 108 109 class BoundedSource(SourceBase): 110 """A source that reads a finite amount of input records. 111 112 This class defines following operations which can be used to read the source 113 efficiently. 114 115 * Size estimation - method ``estimate_size()`` may return an accurate 116 estimation in bytes for the size of the source. 117 * Splitting into bundles of a given size - method ``split()`` can be used to 118 split the source into a set of sub-sources (bundles) based on a desired 119 bundle size. 120 * Getting a RangeTracker - method ``get_range_tracker()`` should return a 121 ``RangeTracker`` object for a given position range for the position type 122 of the records returned by the source. 123 * Reading the data - method ``read()`` can be used to read data from the 124 source while respecting the boundaries defined by a given 125 ``RangeTracker``. 126 127 A runner will perform reading the source in two steps. 128 129 (1) Method ``get_range_tracker()`` will be invoked with start and end 130 positions to obtain a ``RangeTracker`` for the range of positions the 131 runner intends to read. Source must define a default initial start and end 132 position range. These positions must be used if the start and/or end 133 positions passed to the method ``get_range_tracker()`` are ``None`` 134 (2) Method read() will be invoked with the ``RangeTracker`` obtained in the 135 previous step. 136 137 **Mutability** 138 139 A ``BoundedSource`` object should not be mutated while 140 its methods (for example, ``read()``) are being invoked by a runner. Runner 141 implementations may invoke methods of ``BoundedSource`` objects through 142 multi-threaded and/or reentrant execution modes. 143 """ 144 def estimate_size(self): 145 # type: () -> Optional[int] 146 147 """Estimates the size of source in bytes. 148 149 An estimate of the total size (in bytes) of the data that would be read 150 from this source. This estimate is in terms of external storage size, 151 before performing decompression or other processing. 152 153 Returns: 154 estimated size of the source if the size can be determined, ``None`` 155 otherwise. 156 """ 157 raise NotImplementedError 158 159 def split(self, 160 desired_bundle_size, # type: int 161 start_position=None, # type: Optional[Any] 162 stop_position=None, # type: Optional[Any] 163 ): 164 # type: (...) -> Iterator[SourceBundle] 165 166 """Splits the source into a set of bundles. 167 168 Bundles should be approximately of size ``desired_bundle_size`` bytes. 169 170 Args: 171 desired_bundle_size: the desired size (in bytes) of the bundles returned. 172 start_position: if specified the given position must be used as the 173 starting position of the first bundle. 174 stop_position: if specified the given position must be used as the ending 175 position of the last bundle. 
176 Returns: 177 an iterator of objects of type 'SourceBundle' that gives information about 178 the generated bundles. 179 """ 180 raise NotImplementedError 181 182 def get_range_tracker(self, 183 start_position, # type: Optional[Any] 184 stop_position, # type: Optional[Any] 185 ): 186 # type: (...) -> RangeTracker 187 188 """Returns a RangeTracker for a given position range. 189 190 Framework may invoke ``read()`` method with the RangeTracker object returned 191 here to read data from the source. 192 193 Args: 194 start_position: starting position of the range. If 'None' default start 195 position of the source must be used. 196 stop_position: ending position of the range. If 'None' default stop 197 position of the source must be used. 198 Returns: 199 a ``RangeTracker`` for the given position range. 200 """ 201 raise NotImplementedError 202 203 def read(self, range_tracker): 204 """Returns an iterator that reads data from the source. 205 206 The returned set of data must respect the boundaries defined by the given 207 ``RangeTracker`` object. For example: 208 209 * Returned set of data must be for the range 210 ``[range_tracker.start_position, range_tracker.stop_position)``. Note 211 that a source may decide to return records that start after 212 ``range_tracker.stop_position``. See documentation in class 213 ``RangeTracker`` for more details. Also, note that framework might 214 invoke ``range_tracker.try_split()`` to perform dynamic split 215 operations. range_tracker.stop_position may be updated 216 dynamically due to successful dynamic split operations. 217 * Method ``range_tracker.try_split()`` must be invoked for every record 218 that starts at a split point. 219 * Method ``range_tracker.record_current_position()`` may be invoked for 220 records that do not start at split points. 221 222 Args: 223 range_tracker: a ``RangeTracker`` whose boundaries must be respected 224 when reading data from the source. A runner that reads this 225 source must pass a ``RangeTracker`` object that is not 226 ``None``. 227 Returns: 228 an iterator of data read by the source. 229 """ 230 raise NotImplementedError 231 232 def default_output_coder(self): 233 """Coder that should be used for the records returned by the source. 234 235 Should be overridden by sources that produce objects that can be encoded 236 more efficiently than pickling. 237 """ 238 return coders.registry.get_coder(object) 239 240 def is_bounded(self): 241 return True 242 243 244 class RangeTracker(object): 245 """A thread safe object used by Dataflow source framework. 246 247 A Dataflow source is defined using a ''BoundedSource'' and a ''RangeTracker'' 248 pair. A ''RangeTracker'' is used by Dataflow source framework to perform 249 dynamic work rebalancing of position-based sources. 250 251 **Position-based sources** 252 253 A position-based source is one where the source can be described by a range 254 of positions of an ordered type and the records returned by the reader can be 255 described by positions of the same type. 256 257 In case a record occupies a range of positions in the source, the most 258 important thing about the record is the position where it starts. 259 260 Defining the semantics of positions for a source is entirely up to the source 261 class, however the chosen definitions have to obey certain properties in order 262 to make it possible to correctly split the source into parts, including 263 dynamic splitting. Two main aspects need to be defined: 264 265 1. How to assign starting positions to records. 266 2.
Which records should be read by a source with a range '[A, B)'. 267 268 Moreover, reading a range must be *efficient*, i.e., the performance of 269 reading a range should not significantly depend on the location of the range. 270 For example, reading the range [A, B) should not require reading all data 271 before 'A'. 272 273 The sections below explain exactly what properties these definitions must 274 satisfy, and how to use a ``RangeTracker`` with a properly defined source. 275 276 **Properties of position-based sources** 277 278 The main requirement for position-based sources is *associativity*: reading 279 records from '[A, B)' and records from '[B, C)' should give the same 280 records as reading from '[A, C)', where 'A <= B <= C'. This property 281 ensures that no matter how a range of positions is split into arbitrarily many 282 sub-ranges, the total set of records described by them stays the same. 283 284 The other important property is how the source's range relates to positions of 285 records in the source. In many sources each record can be identified by a 286 unique starting position. In this case: 287 288 * All records returned by a source '[A, B)' must have starting positions in 289 this range. 290 * All but the last record should end within this range. The last record may or 291 may not extend past the end of the range. 292 * Records should not overlap. 293 294 Such sources should define "read '[A, B)'" as "read from the first record 295 starting at or after 'A', up to but not including the first record starting 296 at or after 'B'". 297 298 Some examples of such sources include reading lines or CSV from a text file, 299 reading keys and values from a BigTable, etc. 300 301 The concept of *split points* allows to extend the definitions for dealing 302 with sources where some records cannot be identified by a unique starting 303 position. 304 305 In all cases, all records returned by a source '[A, B)' must *start* at or 306 after 'A'. 307 308 **Split points** 309 310 Some sources may have records that are not directly addressable. For example, 311 imagine a file format consisting of a sequence of compressed blocks. Each 312 block can be assigned an offset, but records within the block cannot be 313 directly addressed without decompressing the block. Let us refer to this 314 hypothetical format as <i>CBF (Compressed Blocks Format)</i>. 315 316 Many such formats can still satisfy the associativity property. For example, 317 in CBF, reading '[A, B)' can mean "read all the records in all blocks whose 318 starting offset is in '[A, B)'". 319 320 To support such complex formats, we introduce the notion of *split points*. We 321 say that a record is a split point if there exists a position 'A' such that 322 the record is the first one to be returned when reading the range 323 '[A, infinity)'. In CBF, the only split points would be the first records 324 in each block. 325 326 Split points allow us to define the meaning of a record's position and a 327 source's range in all cases: 328 329 * For a record that is at a split point, its position is defined to be the 330 largest 'A' such that reading a source with the range '[A, infinity)' 331 returns this record. 332 * Positions of other records are only required to be non-decreasing. 333 * Reading the source '[A, B)' must return records starting from the first 334 split point at or after 'A', up to but not including the first split point 335 at or after 'B'. 
In particular, this means that the first record returned 336 by a source MUST always be a split point. 337 * Positions of split points must be unique. 338 339 As a result, for any decomposition of the full range of the source into 340 position ranges, the total set of records will be the full set of records in 341 the source, and each record will be read exactly once. 342 343 **Consumed positions** 344 345 As the source is being read, and records read from it are being passed to the 346 downstream transforms in the pipeline, we say that positions in the source are 347 being *consumed*. When a reader has read a record (or promised to a caller 348 that a record will be returned), positions up to and including the record's 349 start position are considered *consumed*. 350 351 Dynamic splitting can happen only at *unconsumed* positions. If the reader 352 just returned a record at offset 42 in a file, dynamic splitting can happen 353 only at offset 43 or beyond, as otherwise that record could be read twice (by 354 the current reader and by a reader of the task starting at 43). 355 """ 356 357 SPLIT_POINTS_UNKNOWN = object() 358 359 def start_position(self): 360 """Returns the starting position of the current range, inclusive.""" 361 raise NotImplementedError(type(self)) 362 363 def stop_position(self): 364 """Returns the ending position of the current range, exclusive.""" 365 raise NotImplementedError(type(self)) 366 367 def try_claim(self, position): # pylint: disable=unused-argument 368 """Atomically determines if a record at a split point is within the range. 369 370 This method should be called **if and only if** the record is at a split 371 point. This method may modify the internal state of the ``RangeTracker`` by 372 updating the last-consumed position to ``position``. 373 374 ** Thread safety ** 375 376 Methods of the class ``RangeTracker`` including this method may get invoked 377 by different threads, hence must be made thread-safe, e.g. by using a single 378 lock object. 379 380 Args: 381 position: starting position of a record being read by a source. 382 383 Returns: 384 ``True``, if the given position falls within the current range, returns 385 ``False`` otherwise. 386 """ 387 raise NotImplementedError 388 389 def set_current_position(self, position): 390 """Updates the last-consumed position to the given position. 391 392 A source may invoke this method for records that do not start at split 393 points. This may modify the internal state of the ``RangeTracker``. If the 394 record starts at a split point, method ``try_claim()`` **must** be invoked 395 instead of this method. 396 397 Args: 398 position: starting position of a record being read by a source. 399 """ 400 raise NotImplementedError 401 402 def position_at_fraction(self, fraction): 403 """Returns the position at the given fraction. 404 405 Given a fraction within the range [0.0, 1.0) this method will return the 406 position at the given fraction compared to the position range 407 [self.start_position, self.stop_position). 408 409 ** Thread safety ** 410 411 Methods of the class ``RangeTracker`` including this method may get invoked 412 by different threads, hence must be made thread-safe, e.g. by using a single 413 lock object. 414 415 Args: 416 fraction: a float value within the range [0.0, 1.0). 417 Returns: 418 a position within the range [self.start_position, self.stop_position). 419 """ 420 raise NotImplementedError 421 422 def try_split(self, position): 423 """Atomically splits the current range. 
424 425 Determines a position to split the current range, split_position, based on 426 the given position. In most cases split_position and position will be the 427 same. 428 429 Splits the current range '[self.start_position, self.stop_position)' 430 into a "primary" part '[self.start_position, split_position)' and a 431 "residual" part '[split_position, self.stop_position)', assuming the 432 current last-consumed position is within 433 '[self.start_position, split_position)' (i.e., split_position has not been 434 consumed yet). 435 436 If successful, updates the current range to be the primary and returns a 437 tuple (split_position, split_fraction). split_fraction should be the 438 fraction of size of range '[self.start_position, split_position)' compared 439 to the original (before split) range 440 '[self.start_position, self.stop_position)'. 441 442 If the split_position has already been consumed, returns ``None``. 443 444 ** Thread safety ** 445 446 Methods of the class ``RangeTracker`` including this method may get invoked 447 by different threads, hence must be made thread-safe, e.g. by using a single 448 lock object. 449 450 Args: 451 position: suggested position where the current range should try to 452 be split at. 453 Returns: 454 a tuple containing the split position and split fraction if split is 455 successful. Returns ``None`` otherwise. 456 """ 457 raise NotImplementedError 458 459 def fraction_consumed(self): 460 """Returns the approximate fraction of consumed positions in the source. 461 462 ** Thread safety ** 463 464 Methods of the class ``RangeTracker`` including this method may get invoked 465 by different threads, hence must be made thread-safe, e.g. by using a single 466 lock object. 467 468 Returns: 469 the approximate fraction of positions that have been consumed by 470 successful 'try_split()' and 'try_claim()' calls, or 471 0.0 if no such calls have happened. 472 """ 473 raise NotImplementedError 474 475 def split_points(self): 476 """Gives the number of split points consumed and remaining. 477 478 For a ``RangeTracker`` used by a ``BoundedSource`` (within a 479 ``BoundedSource.read()`` invocation) this method produces a 2-tuple that 480 gives the number of split points consumed by the ``BoundedSource`` and the 481 number of split points remaining within the range of the ``RangeTracker`` 482 that has not been consumed by the ``BoundedSource``. 483 484 More specifically, given that the position of the current record being read 485 by ``BoundedSource`` is current_position this method produces a tuple that 486 consists of 487 (1) number of split points in the range [self.start_position(), 488 current_position) without including the split point that is currently being 489 consumed. This represents the total amount of parallelism in the consumed 490 part of the source. 491 (2) number of split points within the range 492 [current_position, self.stop_position()) including the split point that is 493 currently being consumed. This represents the total amount of parallelism in 494 the unconsumed part of the source. 495 496 Methods of the class ``RangeTracker`` including this method may get invoked 497 by different threads, hence must be made thread-safe, e.g. by using a single 498 lock object. 499 500 ** General information about consumed and remaining number of split 501 points returned by this method. ** 502 503 * Before a source read (``BoundedSource.read()`` invocation) claims the 504 first split point, number of consumed split points is 0. 
This condition 505 holds independent of whether the input is "splittable". A splittable 506 source is a source that has more than one split point. 507 * Any source read that has only claimed one split point has 0 consumed 508 split points since the first split point is the current split point and 509 is still being processed. This condition holds independent of whether 510 the input is splittable. 511 * For an empty source read which never invokes 512 ``RangeTracker.try_claim()``, the consumed number of split points is 0. 513 This condition holds independent of whether the input is splittable. 514 * For a source read which has invoked ``RangeTracker.try_claim()`` n 515 times, the consumed number of split points is n - 1. 516 * If a ``BoundedSource`` sets a callback through function 517 ``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that 518 callback when determining remaining number of split points. 519 * Remaining split points should include the split point that is currently 520 being consumed by the source read. Hence if the above callback returns 521 an integer value n, remaining number of split points should be (n + 1). 522 * After the last split point is claimed, remaining split points becomes 1, 523 because this unfinished read itself represents an unfinished split 524 point. 525 * After all records of the source have been consumed, remaining number of 526 split points becomes 0 and consumed number of split points becomes equal 527 to the total number of split points within the range being read by the 528 source. This method does not address this condition and will continue to 529 report number of consumed split points as 530 ("total number of split points" - 1) and number of remaining split 531 points as 1. A runner that performs the reading of the source can 532 detect when all records have been consumed and adjust remaining and 533 consumed number of split points accordingly. 534 535 ** Examples ** 536 537 (1) A "perfectly splittable" input which can be read in parallel down to the 538 individual records. 539 540 Consider a perfectly splittable input that consists of 50 split points. 541 542 * Before a source read (``BoundedSource.read()`` invocation) claims the 543 first split point, number of consumed split points is 0 and number of 544 remaining split points is 50. 545 * After claiming first split point, consumed number of split points is 0 546 and remaining number of split points is 50. 547 * After claiming split point #30, consumed number of split points is 29 548 and remaining number of split points is 21. 549 * After claiming all 50 split points, consumed number of split points is 550 49 and remaining number of split points is 1. 551 552 (2) a "block-compressed" file format such as ``avroio``, in which a block of 553 records has to be read as a whole, but different blocks can be read in 554 parallel. 555 556 Consider a block compressed input that consists of 5 blocks. 557 558 * Before a source read (``BoundedSource.read()`` invocation) claims the 559 first split point (first block), number of consumed split points is 0 560 and number of remaining split points is 5. 561 * After claiming first split point, consumed number of split points is 0 562 and remaining number of split points is 5. 563 * After claiming split point #3, consumed number of split points is 2 564 and remaining number of split points is 3. 565 * After claiming all 5 split points, consumed number of split points is 566 4 and remaining number of split points is 1.
567 568 (3) an "unsplittable" input such as a cursor in a database or a gzip 569 compressed file. 570 571 Such an input is considered to have only a single split point. Number of 572 consumed split points is always 0 and number of remaining split points 573 is always 1. 574 575 By default ``RangeTracker` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN`` for 576 both consumed and remaining number of split points, which indicates that the 577 number of split points consumed and remaining is unknown. 578 579 Returns: 580 A pair that gives consumed and remaining number of split points. Consumed 581 number of split points should be an integer larger than or equal to zero 582 or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points 583 should be an integer larger than zero or 584 ``RangeTracker.SPLIT_POINTS_UNKNOWN``. 585 """ 586 return ( 587 RangeTracker.SPLIT_POINTS_UNKNOWN, RangeTracker.SPLIT_POINTS_UNKNOWN) 588 589 def set_split_points_unclaimed_callback(self, callback): 590 """Sets a callback for determining the unclaimed number of split points. 591 592 By invoking this function, a ``BoundedSource`` can set a callback function 593 that may get invoked by the ``RangeTracker`` to determine the number of 594 unclaimed split points. A split point is unclaimed if 595 ``RangeTracker.try_claim()`` method has not been successfully invoked for 596 that particular split point. The callback function accepts a single 597 parameter, a stop position for the BoundedSource (stop_position). If the 598 record currently being consumed by the ``BoundedSource`` is at position 599 current_position, callback should return the number of split points within 600 the range (current_position, stop_position). Note that, this should not 601 include the split point that is currently being consumed by the source. 602 603 This function must be implemented by subclasses before being used. 604 605 Args: 606 callback: a function that takes a single parameter, a stop position, 607 and returns unclaimed number of split points for the source read 608 operation that is calling this function. Value returned from 609 callback should be either an integer larger than or equal to 610 zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. 611 """ 612 raise NotImplementedError 613 614 615 class Sink(HasDisplayData): 616 """This class is deprecated, no backwards-compatibility guarantees. 617 618 A resource that can be written to using the ``beam.io.Write`` transform. 619 620 Here ``beam`` stands for Apache Beam Python code imported in following manner. 621 ``import apache_beam as beam``. 622 623 A parallel write to an ``iobase.Sink`` consists of three phases: 624 625 1. A sequential *initialization* phase (e.g., creating a temporary output 626 directory, etc.) 627 2. A parallel write phase where workers write *bundles* of records 628 3. A sequential *finalization* phase (e.g., committing the writes, merging 629 output files, etc.) 630 631 Implementing a new sink requires extending two classes. 632 633 1. iobase.Sink 634 635 ``iobase.Sink`` is an immutable logical description of the location/resource 636 to write to. Depending on the type of sink, it may contain fields such as the 637 path to an output directory on a filesystem, a database table name, 638 etc. ``iobase.Sink`` provides methods for performing a write operation to the 639 sink described by it. To this end, implementors of an extension of 640 ``iobase.Sink`` must implement three methods: 641 ``initialize_write()``, ``open_writer()``, and ``finalize_write()``. 642 643 2. 
iobase.Writer 644 645 ``iobase.Writer`` is used to write a single bundle of records. An 646 ``iobase.Writer`` defines two methods: ``write()`` which writes a 647 single record from the bundle and ``close()`` which is called once 648 at the end of writing a bundle. 649 650 See also ``apache_beam.io.filebasedsink.FileBasedSink`` which provides a 651 simpler API for writing sinks that produce files. 652 653 **Execution of the Write transform** 654 655 ``initialize_write()``, ``pre_finalize()``, and ``finalize_write()`` are 656 conceptually called once. However, implementors must 657 ensure that these methods are *idempotent*, as they may be called multiple 658 times on different machines in the case of failure/retry. A method may be 659 called more than once concurrently, in which case it's okay to have a 660 transient failure (such as due to a race condition). This failure should not 661 prevent subsequent retries from succeeding. 662 663 ``initialize_write()`` should perform any initialization that needs to be done 664 prior to writing to the sink. ``initialize_write()`` may return a result 665 (let's call this ``init_result``) that contains any parameters it wants to 666 pass on to its writers about the sink. For example, a sink that writes to a 667 file system may return an ``init_result`` that contains a dynamically 668 generated unique directory to which data should be written. 669 670 To perform writing of a bundle of elements, Dataflow execution engine will 671 create an ``iobase.Writer`` using the implementation of 672 ``iobase.Sink.open_writer()``. When invoking ``open_writer()`` execution 673 engine will provide the ``init_result`` returned by ``initialize_write()`` 674 invocation as well as a *bundle id* (let's call this ``bundle_id``) that is 675 unique for each invocation of ``open_writer()``. 676 677 Execution engine will then invoke ``iobase.Writer.write()`` implementation for 678 each element that has to be written. Once all elements of a bundle are 679 written, execution engine will invoke ``iobase.Writer.close()`` implementation 680 which should return a result (let's call this ``write_result``) that contains 681 information that encodes the result of the write and, in most cases, some 682 encoding of the unique bundle id. For example, if each bundle is written to a 683 unique temporary file, ``close()`` method may return an object that contains 684 the temporary file name. After writing of all bundles is complete, execution 685 engine will invoke ``pre_finalize()`` and then ``finalize_write()`` 686 implementation. 687 688 The execution of a write transform can be illustrated using following pseudo 689 code (assume that the outer for loop happens in parallel across many 690 machines):: 691 692 init_result = sink.initialize_write() 693 write_results = [] 694 for bundle in partition(pcoll): 695 writer = sink.open_writer(init_result, generate_bundle_id()) 696 for elem in bundle: 697 writer.write(elem) 698 write_results.append(writer.close()) 699 pre_finalize_result = sink.pre_finalize(init_result, write_results) 700 sink.finalize_write(init_result, write_results, pre_finalize_result) 701 702 703 **init_result** 704 705 Methods of 'iobase.Sink' should agree on the 'init_result' type that will be 706 returned when initializing the sink. This type can be a client-defined object 707 or an existing type. The returned type must be picklable using Dataflow coder 708 ``coders.PickleCoder``. Returning an init_result is optional. 
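  For illustration only, a minimal sink that follows the contract described
  above might look like the sketch below. This is not part of the Beam API:
  the names ``SimpleFileSink`` and ``_SimpleFileWriter`` are hypothetical, and
  the sketch assumes output to a local filesystem::

    import os
    import tempfile

    from apache_beam.io import iobase

    class _SimpleFileWriter(iobase.Writer):
      def __init__(self, init_result, uid):
        # init_result is the temporary directory created by
        # initialize_write(); uid is the unique bundle id.
        self._path = os.path.join(init_result, uid)
        self._file = open(self._path, 'w')

      def write(self, value):
        print(value, file=self._file)

      def close(self):
        self._file.close()
        # The writer result encodes the unique bundle id via the file name.
        return self._path

    class SimpleFileSink(iobase.Sink):
      def __init__(self, final_path):
        self._final_path = final_path

      def initialize_write(self):
        # Sequential initialization: create a unique temporary directory.
        return tempfile.mkdtemp()

      def open_writer(self, init_result, uid):
        return _SimpleFileWriter(init_result, uid)

      def pre_finalize(self, init_result, writer_results):
        return None

      def finalize_write(self, init_result, writer_results, pre_finalize_result):
        # Sequential finalization: concatenate the per-bundle files written
        # by successful bundles into the final output.
        with open(self._final_path, 'w') as out:
          for path in writer_results:
            with open(path) as bundle_file:
              out.write(bundle_file.read())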
709 710 **bundle_id** 711 712 In order to ensure fault-tolerance, a bundle may be executed multiple times 713 (e.g., in the event of failure/retry or for redundancy). However, exactly one 714 of these executions will have its result passed to the 715 ``iobase.Sink.finalize_write()`` method. Each call to 716 ``iobase.Sink.open_writer()`` is passed a unique bundle id when it is called 717 by the ``WriteImpl`` transform, so even redundant or retried bundles will have 718 a unique way of identifying their output. 719 720 The bundle id should be used to guarantee that a bundle's output is unique. 721 This uniqueness guarantee is important; if a bundle is to be output to a file, 722 for example, the name of the file must be unique to avoid conflicts with other 723 writers. The bundle id should be encoded in the writer result returned by the 724 writer and subsequently used by the ``finalize_write()`` method to identify 725 the results of successful writes. 726 727 For example, consider the scenario where a Writer writes files containing 728 serialized records and the ``finalize_write()`` is to merge or rename these 729 output files. In this case, a writer may use its unique id to name its output 730 file (to avoid conflicts) and return the name of the file it wrote as its 731 writer result. The ``finalize_write()`` will then receive an ``Iterable`` of 732 output file names that it can then merge or rename using some bundle naming 733 scheme. 734 735 **write_result** 736 737 ``iobase.Writer.close()`` and ``finalize_write()`` implementations must agree 738 on the type of the ``write_result`` object returned when invoking 739 ``iobase.Writer.close()``. This type can be a client-defined object or 740 an existing type. The returned type must be picklable using Dataflow coder 741 ``coders.PickleCoder``. Returning a ``write_result`` when 742 ``iobase.Writer.close()`` is invoked is optional but if unique 743 ``write_result`` objects are not returned, the sink should guarantee idempotency 744 when the same bundle is written multiple times due to failure/retry or redundancy. 745 746 747 **More information** 748 749 For more information on creating new sinks please refer to the official 750 documentation at 751 ``https://beam.apache.org/documentation/sdks/python-custom-io#creating-sinks`` 752 """ 753 # Whether Beam should skip writing any shards if all are empty. 754 skip_if_empty = False 755 756 def initialize_write(self): 757 """Initializes the sink before writing begins. 758 759 Invoked before any data is written to the sink. 760 761 762 Please see documentation in ``iobase.Sink`` for an example. 763 764 Returns: 765 An object that contains any sink specific state generated by 766 initialization. This object will be passed to open_writer() and 767 finalize_write() methods. 768 """ 769 raise NotImplementedError 770 771 def open_writer(self, init_result, uid): 772 """Opens a writer for writing a bundle of elements to the sink. 773 774 Args: 775 init_result: the result of initialize_write() invocation. 776 uid: a unique identifier generated by the system. 777 Returns: 778 an ``iobase.Writer`` that can be used to write a bundle of records to the 779 current sink. 780 """ 781 raise NotImplementedError 782 783 def pre_finalize(self, init_result, writer_results): 784 """Pre-finalization stage for sink. 785 786 Called after all bundle writes are complete and before finalize_write. 787 Used to set up and verify filesystem and sink states. 788 789 Args: 790 init_result: the result of ``initialize_write()`` invocation.
791 writer_results: an iterable containing results of ``Writer.close()`` 792 invocations. This will only contain results of successful writes, and 793 will only contain the result of a single successful write for a given 794 bundle. 795 796 Returns: 797 An object that contains any sink specific state generated. 798 This object will be passed to finalize_write(). 799 """ 800 raise NotImplementedError 801 802 def finalize_write(self, init_result, writer_results, pre_finalize_result): 803 """Finalizes the sink after all data is written to it. 804 805 Given the result of initialization and an iterable of results from bundle 806 writes, performs finalization after writing and closes the sink. Called 807 after all bundle writes are complete. 808 809 The bundle write results that are passed to finalize are those returned by 810 bundles that completed successfully. Although bundles may have been run 811 multiple times (for fault-tolerance), only one writer result will be passed 812 to finalize for each bundle. An implementation of finalize should perform 813 clean up of any failed and successfully retried bundles. Note that these 814 failed bundles will not have their writer result passed to finalize, so 815 finalize should be capable of locating any temporary/partial output written 816 by failed bundles. 817 818 If all retries of a bundle fail, the whole pipeline will fail *without* 819 finalize_write() being invoked. 820 821 A best practice is to make finalize atomic. If this is impossible given the 822 semantics of the sink, finalize should be idempotent, as it may be called 823 multiple times in the case of failure/retry or for redundancy. 824 825 Note that the iteration order of the writer results is not guaranteed to be 826 consistent if finalize is called multiple times. 827 828 Args: 829 init_result: the result of ``initialize_write()`` invocation. 830 writer_results: an iterable containing results of ``Writer.close()`` 831 invocations. This will only contain results of successful writes, and 832 will only contain the result of a single successful write for a given 833 bundle. 834 pre_finalize_result: the result of ``pre_finalize()`` invocation. 835 """ 836 raise NotImplementedError 837 838 839 class Writer(object): 840 """This class is deprecated, no backwards-compatibility guarantees. 841 842 Writes a bundle of elements from a ``PCollection`` to a sink. 843 844 A Writer's ``iobase.Writer.write()`` method writes elements to the sink, while 845 ``iobase.Writer.close()`` is called after all elements in the bundle have been 846 written. 847 848 See ``iobase.Sink`` for more detailed documentation about the process of 849 writing to a sink. 850 """ 851 def write(self, value): 852 """Writes a value to the sink using the current writer. 853 """ 854 raise NotImplementedError 855 856 def close(self): 857 """Closes the current writer. 858 859 Please see documentation in ``iobase.Sink`` for an example. 860 861 Returns: 862 An object representing the writes that were performed by the current 863 writer. 864 """ 865 raise NotImplementedError 866 867 def at_capacity(self) -> bool: 868 """Returns whether this writer should be considered at capacity 869 and a new one should be created.
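    For example, a hypothetical writer that caps the number of records per
    output file might be implemented as::

      def at_capacity(self):
        # _num_written and _max_records_per_file are illustrative fields
        # maintained by this hypothetical writer, not part of the base class.
        return self._num_written >= self._max_records_per_file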
870 """ 871 return False 872 873 874 class Read(ptransform.PTransform): 875 """A transform that reads from a source and produces a PCollection.""" 876 # Import runners here to prevent circular imports 877 from apache_beam.runners.pipeline_context import PipelineContext 878 879 def __init__(self, source): 880 # type: (SourceBase) -> None 881 882 """Initializes a Read transform. 883 884 Args: 885 source: Data source to read from. 886 """ 887 super().__init__() 888 self.source = source 889 890 @staticmethod 891 def get_desired_chunk_size(total_size): 892 if total_size: 893 # 1MB = 1 shard, 1GB = 32 shards, 1TB = 1000 shards, 1PB = 32k shards 894 chunk_size = max(1 << 20, 1000 * int(math.sqrt(total_size))) 895 else: 896 chunk_size = 64 << 20 # 64MB 897 return chunk_size 898 899 def expand(self, pbegin): 900 if isinstance(self.source, BoundedSource): 901 coders.registry.register_coder(BoundedSource, _MemoizingPickleCoder) 902 display_data = self.source.display_data() or {} 903 display_data['source'] = self.source.__class__ 904 905 return ( 906 pbegin 907 | Impulse() 908 | core.Map(lambda _: self.source).with_output_types(BoundedSource) 909 | SDFBoundedSourceReader(display_data)) 910 elif isinstance(self.source, ptransform.PTransform): 911 # The Read transform can also admit a full PTransform as an input 912 # rather than an actual source. If the input is a PTransform, then 913 # just apply it directly. 914 return pbegin.pipeline | self.source 915 else: 916 # Treat Read itself as a primitive. 917 return pvalue.PCollection( 918 pbegin.pipeline, is_bounded=self.source.is_bounded()) 919 920 def get_windowing(self, unused_inputs): 921 # type: (...) -> core.Windowing 922 return core.Windowing(window.GlobalWindows()) 923 924 def _infer_output_coder(self, input_type=None, input_coder=None): 925 # type: (...)
-> Optional[coders.Coder] 926 from apache_beam.runners.dataflow.native_io import iobase as dataflow_io 927 if isinstance(self.source, BoundedSource): 928 return self.source.default_output_coder() 929 elif isinstance(self.source, dataflow_io.NativeSource): 930 return self.source.coder 931 else: 932 return None 933 934 def display_data(self): 935 return { 936 'source': DisplayDataItem(self.source.__class__, label='Read Source'), 937 'source_dd': self.source 938 } 939 940 def to_runner_api_parameter( 941 self, 942 context: PipelineContext, 943 ) -> Tuple[str, Any]: 944 from apache_beam.runners.dataflow.native_io import iobase as dataflow_io 945 if isinstance(self.source, (BoundedSource, dataflow_io.NativeSource)): 946 from apache_beam.io.gcp.pubsub import _PubSubSource 947 if isinstance(self.source, _PubSubSource): 948 return ( 949 common_urns.composites.PUBSUB_READ.urn, 950 beam_runner_api_pb2.PubSubReadPayload( 951 topic=self.source.full_topic, 952 subscription=self.source.full_subscription, 953 timestamp_attribute=self.source.timestamp_attribute, 954 with_attributes=self.source.with_attributes, 955 id_attribute=self.source.id_label)) 956 return ( 957 common_urns.deprecated_primitives.READ.urn, 958 beam_runner_api_pb2.ReadPayload( 959 source=self.source.to_runner_api(context), 960 is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED 961 if self.source.is_bounded() else 962 beam_runner_api_pb2.IsBounded.UNBOUNDED)) 963 elif isinstance(self.source, ptransform.PTransform): 964 return self.source.to_runner_api_parameter(context) 965 raise NotImplementedError( 966 "to_runner_api_parameter not " 967 "implemented for type") 968 969 @staticmethod 970 def from_runner_api_parameter( 971 transform: beam_runner_api_pb2.PTransform, 972 payload: Union[beam_runner_api_pb2.ReadPayload, 973 beam_runner_api_pb2.PubSubReadPayload], 974 context: PipelineContext, 975 ) -> "Read": 976 if transform.spec.urn == common_urns.composites.PUBSUB_READ.urn: 977 assert isinstance(payload, beam_runner_api_pb2.PubSubReadPayload) 978 # Importing locally to prevent circular dependencies. 
979 from apache_beam.io.gcp.pubsub import _PubSubSource 980 source = _PubSubSource( 981 topic=payload.topic or None, 982 subscription=payload.subscription or None, 983 id_label=payload.id_attribute or None, 984 with_attributes=payload.with_attributes, 985 timestamp_attribute=payload.timestamp_attribute or None) 986 return Read(source) 987 else: 988 assert isinstance(payload, beam_runner_api_pb2.ReadPayload) 989 return Read(SourceBase.from_runner_api(payload.source, context)) 990 991 @staticmethod 992 def _from_runner_api_parameter_read( 993 transform: beam_runner_api_pb2.PTransform, 994 payload: beam_runner_api_pb2.ReadPayload, 995 context: PipelineContext, 996 ) -> "Read": 997 """Method for type proxying when calling register_urn due to limitations 998 in type exprs in Python""" 999 return Read.from_runner_api_parameter(transform, payload, context) 1000 1001 @staticmethod 1002 def _from_runner_api_parameter_pubsub_read( 1003 transform: beam_runner_api_pb2.PTransform, 1004 payload: beam_runner_api_pb2.PubSubReadPayload, 1005 context: PipelineContext, 1006 ) -> "Read": 1007 """Method for type proxying when calling register_urn due to limitations 1008 in type exprs in Python""" 1009 return Read.from_runner_api_parameter(transform, payload, context) 1010 1011 1012 ptransform.PTransform.register_urn( 1013 common_urns.deprecated_primitives.READ.urn, 1014 beam_runner_api_pb2.ReadPayload, 1015 Read._from_runner_api_parameter_read, 1016 ) 1017 1018 ptransform.PTransform.register_urn( 1019 common_urns.composites.PUBSUB_READ.urn, 1020 beam_runner_api_pb2.PubSubReadPayload, 1021 Read._from_runner_api_parameter_pubsub_read, 1022 ) 1023 1024 1025 class Write(ptransform.PTransform): 1026 """A ``PTransform`` that writes to a sink. 1027 1028 A sink should inherit ``iobase.Sink``. Such implementations are 1029 handled using a composite transform that consists of three ``ParDo``s - 1030 (1) a ``ParDo`` performing a global initialization (2) a ``ParDo`` performing 1031 a parallel write and (3) a ``ParDo`` performing a global finalization. In the 1032 case of an empty ``PCollection``, only the global initialization and 1033 finalization will be performed. Currently only batch workflows support custom 1034 sinks. 1035 1036 Example usage:: 1037 1038 pcollection | beam.io.Write(MySink()) 1039 1040 This returns a ``pvalue.PValue`` object that represents the end of the 1041 Pipeline. 1042 1043 The sink argument may also be a full PTransform, in which case it will be 1044 applied directly. This allows composite sink-like transforms (e.g. a sink 1045 with some pre-processing DoFns) to be used the same as all other sinks. 1046 1047 This transform also supports sinks that inherit ``iobase.NativeSink``. These 1048 are sinks that are implemented natively by the Dataflow service and hence 1049 should not be updated by users. These sinks are processed using a Dataflow 1050 native write transform. 1051 """ 1052 # Import runners here to prevent circular imports 1053 from apache_beam.runners.pipeline_context import PipelineContext 1054 1055 def __init__(self, sink): 1056 """Initializes a Write transform. 1057 1058 Args: 1059 sink: Data sink to write to. 
1060 """ 1061 super().__init__() 1062 self.sink = sink 1063 1064 def display_data(self): 1065 return {'sink': self.sink.__class__, 'sink_dd': self.sink} 1066 1067 def expand(self, pcoll): 1068 from apache_beam.runners.dataflow.native_io import iobase as dataflow_io 1069 if isinstance(self.sink, dataflow_io.NativeSink): 1070 # A native sink 1071 return pcoll | 'NativeWrite' >> dataflow_io._NativeWrite(self.sink) 1072 elif isinstance(self.sink, Sink): 1073 # A custom sink 1074 return pcoll | WriteImpl(self.sink) 1075 elif isinstance(self.sink, ptransform.PTransform): 1076 # This allows "composite" sinks to be used like non-composite ones. 1077 return pcoll | self.sink 1078 else: 1079 raise ValueError( 1080 'A sink must inherit iobase.Sink, iobase.NativeSink, ' 1081 'or be a PTransform. Received : %r' % self.sink) 1082 1083 def to_runner_api_parameter( 1084 self, 1085 context: PipelineContext, 1086 ) -> Tuple[str, Any]: 1087 # Importing locally to prevent circular dependencies. 1088 from apache_beam.io.gcp.pubsub import _PubSubSink 1089 if isinstance(self.sink, _PubSubSink): 1090 payload = beam_runner_api_pb2.PubSubWritePayload( 1091 topic=self.sink.full_topic, 1092 id_attribute=self.sink.id_label, 1093 timestamp_attribute=self.sink.timestamp_attribute) 1094 return (common_urns.composites.PUBSUB_WRITE.urn, payload) 1095 else: 1096 return super().to_runner_api_parameter(context) 1097 1098 @staticmethod 1099 @ptransform.PTransform.register_urn( 1100 common_urns.composites.PUBSUB_WRITE.urn, 1101 beam_runner_api_pb2.PubSubWritePayload) 1102 def from_runner_api_parameter( 1103 ptransform: Any, 1104 payload: beam_runner_api_pb2.PubSubWritePayload, 1105 unused_context: PipelineContext, 1106 ) -> "Write": 1107 if ptransform.spec.urn != common_urns.composites.PUBSUB_WRITE.urn: 1108 raise ValueError( 1109 'Write transform cannot be constructed for the given proto %r', 1110 ptransform) 1111 1112 if not payload.topic: 1113 raise NotImplementedError( 1114 "from_runner_api_parameter does not " 1115 "handle empty or None topic") 1116 1117 # Importing locally to prevent circular dependencies. 
1118 from apache_beam.io.gcp.pubsub import _PubSubSink 1119 sink = _PubSubSink( 1120 topic=payload.topic, 1121 id_label=payload.id_attribute or None, 1122 timestamp_attribute=payload.timestamp_attribute or None) 1123 return Write(sink) 1124 1125 1126 class WriteImpl(ptransform.PTransform): 1127 """Implements the writing of custom sinks.""" 1128 def __init__(self, sink): 1129 # type: (Sink) -> None 1130 super().__init__() 1131 self.sink = sink 1132 1133 def expand(self, pcoll): 1134 do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None]) 1135 init_result_coll = do_once | 'InitializeWrite' >> core.Map( 1136 lambda _, sink: sink.initialize_write(), self.sink) 1137 if getattr(self.sink, 'num_shards', 0): 1138 min_shards = self.sink.num_shards 1139 if min_shards == 1: 1140 keyed_pcoll = pcoll | core.Map(lambda x: (None, x)) 1141 else: 1142 keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(), count=min_shards) 1143 write_result_coll = ( 1144 keyed_pcoll 1145 | core.WindowInto(window.GlobalWindows()) 1146 | core.GroupByKey() 1147 | 'WriteBundles' >> core.ParDo( 1148 _WriteKeyedBundleDoFn(self.sink), AsSingleton(init_result_coll))) 1149 else: 1150 min_shards = 1 1151 write_result_coll = ( 1152 pcoll 1153 | core.WindowInto(window.GlobalWindows()) 1154 | 'WriteBundles' >> core.ParDo( 1155 _WriteBundleDoFn(self.sink), AsSingleton(init_result_coll)) 1156 | 'Pair' >> core.Map(lambda x: (None, x)) 1157 | core.GroupByKey() 1158 | 'Extract' >> core.FlatMap(lambda x: x[1])) 1159 # PreFinalize should run before FinalizeWrite, and the two should not be 1160 # fused. 1161 pre_finalize_coll = ( 1162 do_once 1163 | 'PreFinalize' >> core.FlatMap( 1164 _pre_finalize, 1165 self.sink, 1166 AsSingleton(init_result_coll), 1167 AsIter(write_result_coll))) 1168 return do_once | 'FinalizeWrite' >> core.FlatMap( 1169 _finalize_write, 1170 self.sink, 1171 AsSingleton(init_result_coll), 1172 AsIter(write_result_coll), 1173 min_shards, 1174 AsSingleton(pre_finalize_coll)).with_output_types(str) 1175 1176 1177 class _WriteBundleDoFn(core.DoFn): 1178 """A DoFn for writing elements to an iobase.Writer. 1179 Opens a writer at the first element and closes the writer at finish_bundle(). 1180 """ 1181 def __init__(self, sink): 1182 self.sink = sink 1183 1184 def display_data(self): 1185 return {'sink_dd': self.sink} 1186 1187 def start_bundle(self): 1188 self.writer = None 1189 1190 def process(self, element, init_result): 1191 if self.writer is None: 1192 # We ignore UUID collisions here since they are extremely rare. 
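      # The generated UUID also serves as the unique bundle id passed to
      # Sink.open_writer(), so retried or redundant bundles produce
      # distinctly identified outputs.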
1193 self.writer = self.sink.open_writer(init_result, str(uuid.uuid4())) 1194 self.writer.write(element) 1195 if self.writer.at_capacity(): 1196 yield self.writer.close() 1197 self.writer = None 1198 1199 def finish_bundle(self): 1200 if self.writer is not None: 1201 yield WindowedValue( 1202 self.writer.close(), 1203 window.GlobalWindow().max_timestamp(), [window.GlobalWindow()]) 1204 1205 1206 class _WriteKeyedBundleDoFn(core.DoFn): 1207 def __init__(self, sink): 1208 self.sink = sink 1209 1210 def display_data(self): 1211 return {'sink_dd': self.sink} 1212 1213 def process(self, element, init_result): 1214 bundle = element 1215 writer = self.sink.open_writer(init_result, str(uuid.uuid4())) 1216 for e in bundle[1]: # values 1217 writer.write(e) 1218 return [window.TimestampedValue(writer.close(), timestamp.MAX_TIMESTAMP)] 1219 1220 1221 def _pre_finalize(unused_element, sink, init_result, write_results): 1222 return sink.pre_finalize(init_result, write_results) 1223 1224 1225 def _finalize_write( 1226 unused_element, 1227 sink, 1228 init_result, 1229 write_results, 1230 min_shards, 1231 pre_finalize_results): 1232 write_results = list(write_results) 1233 extra_shards = [] 1234 if len(write_results) < min_shards: 1235 if write_results or not sink.skip_if_empty: 1236 _LOGGER.debug( 1237 'Creating %s empty shard(s).', min_shards - len(write_results)) 1238 for _ in range(min_shards - len(write_results)): 1239 writer = sink.open_writer(init_result, str(uuid.uuid4())) 1240 extra_shards.append(writer.close()) 1241 outputs = sink.finalize_write( 1242 init_result, write_results + extra_shards, pre_finalize_results) 1243 if outputs: 1244 return ( 1245 window.TimestampedValue(v, timestamp.MAX_TIMESTAMP) for v in outputs) 1246 1247 1248 class _RoundRobinKeyFn(core.DoFn): 1249 def start_bundle(self): 1250 self.counter = None 1251 1252 def process(self, element, count): 1253 if self.counter is None: 1254 self.counter = random.randrange(0, count) 1255 self.counter = (1 + self.counter) % count 1256 yield self.counter, element 1257 1258 1259 class RestrictionTracker(object): 1260 """Manages access to a restriction. 1261 1262 Keeps track of the claimed part of a restriction for a Splittable DoFn. 1263 1264 The restriction may be modified by different threads, however the system will 1265 ensure sufficient locking such that no methods on the restriction tracker 1266 will be called concurrently. 1267 1268 See the following documents for more details. 1269 * https://s.apache.org/splittable-do-fn 1270 * https://s.apache.org/splittable-do-fn-python-sdk 1271 """ 1272 def current_restriction(self): 1273 """Returns the current restriction. 1274 1275 Returns a restriction accurately describing the full range of work the 1276 current ``DoFn.process()`` call will do, including already completed work. 1277 1278 The current restriction returned by this method may be updated dynamically 1279 due to concurrent invocation of other methods of the 1280 ``RestrictionTracker``, for example, ``split()``. 1281 1282 This API is required to be implemented. 1283 1284 Returns: a restriction object. 1285 """ 1286 raise NotImplementedError 1287 1288 def current_progress(self): 1289 # type: () -> RestrictionProgress 1290 1291 """Returns a RestrictionProgress object representing the current progress. 1292 1293 This API is recommended to be implemented. The runner can do a better job 1294 at parallel processing with better progress signals.
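    For example, a hypothetical tracker over an offset range could report
    progress as::

      def current_progress(self):
        # _start, _stop and _current_offset are illustrative fields of this
        # hypothetical offset-range tracker.
        return RestrictionProgress(
            completed=self._current_offset - self._start,
            remaining=self._stop - self._current_offset)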
1295 """ 1296 raise NotImplementedError 1297 1298 def check_done(self): 1299 """Checks whether the restriction has been fully processed. 1300 1301 Called by the SDK harness after iterator returned by ``DoFn.process()`` 1302 has been fully read. 1303 1304 This method must raise a `ValueError` if there is still any unclaimed work 1305 remaining in the restriction when this method is invoked. Exception raised 1306 must have an informative error message. 1307 1308 This API is required to be implemented in order to make sure no data loss 1309 during SDK processing. 1310 1311 Returns: ``True`` if current restriction has been fully processed. 1312 Raises: 1313 ValueError: if there is still any unclaimed work remaining. 1314 """ 1315 raise NotImplementedError 1316 1317 def try_split(self, fraction_of_remainder): 1318 """Splits current restriction based on fraction_of_remainder. 1319 1320 If splitting the current restriction is possible, the current restriction is 1321 split into a primary and residual restriction pair. This invocation updates 1322 the ``current_restriction()`` to be the primary restriction effectively 1323 having the current ``DoFn.process()`` execution responsible for performing 1324 the work that the primary restriction represents. The residual restriction 1325 will be executed in a separate ``DoFn.process()`` invocation (likely in a 1326 different process). The work performed by executing the primary and residual 1327 restrictions as separate ``DoFn.process()`` invocations MUST be equivalent 1328 to the work performed as if this split never occurred. 1329 1330 The ``fraction_of_remainder`` should be used in a best effort manner to 1331 choose a primary and residual restriction based upon the fraction of the 1332 remaining work that the current ``DoFn.process()`` invocation is responsible 1333 for. For example, if a ``DoFn.process()`` was reading a file with a 1334 restriction representing the offset range [100, 200) and has processed up to 1335 offset 130 with a fraction_of_remainder of 0.7, the primary and residual 1336 restrictions returned would be [100, 179), [179, 200) (note: current_offset 1337 + fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179). 1338 1339 ``fraction_of_remainder`` = 0 means a checkpoint is required. 1340 1341 The API is recommended to be implemented for batch pipeline given that it is 1342 very important for pipeline scaling and end to end pipeline execution. 1343 1344 The API is required to be implemented for a streaming pipeline. 1345 1346 Args: 1347 fraction_of_remainder: A hint as to the fraction of work the primary 1348 restriction should represent based upon the current known remaining 1349 amount of work. 1350 1351 Returns: 1352 (primary_restriction, residual_restriction) if a split was possible, 1353 otherwise returns ``None``. 1354 """ 1355 raise NotImplementedError 1356 1357 def try_claim(self, position): 1358 """Attempts to claim the block of work in the current restriction 1359 identified by the given position. Each claimed position MUST be a valid 1360 split point. 1361 1362 If this succeeds, the DoFn MUST execute the entire block of work. If it 1363 fails, the ``DoFn.process()`` MUST return ``None`` without performing any 1364 additional work or emitting output (note that emitting output or performing 1365 work from ``DoFn.process()`` is also not allowed before the first call of 1366 this method). 1367 1368 The API is required to be implemented. 1369 1370 Args: 1371 position: current position that wants to be claimed. 
class WatermarkEstimator(object):
  """A WatermarkEstimator estimates the output_watermark based on the
  timestamps of output records or on manual modifications. Please refer to
  ``watermark_estimators`` for commonly used watermark estimators.

  The base class provides common APIs that are called by the framework, which
  are also accessible inside a DoFn.process() body. A derived watermark
  estimator should implement all of the APIs listed below. Additional methods
  can be implemented and will be available when invoked within a DoFn.

  Internal state must not be updated asynchronously.
  """
  def get_estimator_state(self):
    """Get the current state of the WatermarkEstimator instance, which can be
    used to recreate the WatermarkEstimator when processing the restriction.
    See WatermarkEstimatorProvider.create_watermark_estimator.
    """
    raise NotImplementedError(type(self))

  def current_watermark(self):
    # type: () -> timestamp.Timestamp

    """Return the estimated output_watermark. This function must return
    monotonically increasing watermarks."""
    raise NotImplementedError(type(self))

  def observe_timestamp(self, timestamp):
    # type: (timestamp.Timestamp) -> None

    """Update the tracking watermark with the latest output timestamp.

    Args:
      timestamp: the `timestamp.Timestamp` of the current output element.

    This is called with the timestamp of every element output from the DoFn.
    """
    raise NotImplementedError(type(self))
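

# Illustrative sketch (not a class provided by this module): a minimal
# estimator that advances the watermark to the largest timestamp it has
# observed, assuming output timestamps arrive roughly in order.
class _ExampleMonotonicWatermarkEstimator(WatermarkEstimator):
  def __init__(self, state=None):
    # ``state`` is the value returned by a previous get_estimator_state()
    # call, or None when starting fresh.
    self._watermark = timestamp.MIN_TIMESTAMP if state is None else state

  def get_estimator_state(self):
    return self._watermark

  def current_watermark(self):
    return self._watermark

  def observe_timestamp(self, ts):
    # Only move forward; reported watermarks must be monotonically increasing.
    self._watermark = max(self._watermark, ts)

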
class RestrictionProgress(object):
  """Used to record the progress of a restriction."""
  def __init__(self, **kwargs):
    # Only accept keyword arguments.
    self._fraction = kwargs.pop('fraction', None)
    self._completed = kwargs.pop('completed', None)
    self._remaining = kwargs.pop('remaining', None)
    assert not kwargs

  def __repr__(self):
    return 'RestrictionProgress(fraction=%s, completed=%s, remaining=%s)' % (
        self._fraction, self._completed, self._remaining)

  @property
  def completed_work(self):
    # type: () -> float
    if self._completed is not None:
      return self._completed
    elif self._remaining is not None and self._fraction is not None:
      return self._remaining * self._fraction / (1 - self._fraction)
    else:
      return self._fraction

  @property
  def remaining_work(self):
    # type: () -> float
    if self._remaining is not None:
      return self._remaining
    elif self._completed is not None and self._fraction:
      return self._completed * (1 - self._fraction) / self._fraction
    else:
      return 1 - self._fraction

  @property
  def total_work(self):
    # type: () -> float
    return self.completed_work + self.remaining_work

  @property
  def fraction_completed(self):
    # type: () -> float
    if self._fraction is not None:
      return self._fraction
    else:
      return float(self._completed) / self.total_work

  @property
  def fraction_remaining(self):
    # type: () -> float
    if self._fraction is not None:
      return 1 - self._fraction
    else:
      return float(self._remaining) / self.total_work

  def with_completed(self, completed):
    # type: (int) -> RestrictionProgress
    return RestrictionProgress(
        fraction=self._fraction, remaining=self._remaining, completed=completed)
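

# Illustrative sketch of how the progress arithmetic above composes. The
# values remaining=70 and fraction=0.65 are assumed example inputs.
def _example_restriction_progress():
  progress = RestrictionProgress(remaining=70, fraction=0.65)
  # completed_work is inferred as remaining * fraction / (1 - fraction) ~= 130,
  # so total_work ~= 200 and fraction_completed == 0.65.
  print(progress.completed_work, progress.total_work, progress.fraction_completed)
  # with_completed() keeps fraction/remaining but overrides the completed
  # amount, e.g. once more work has actually been claimed.
  print(progress.with_completed(150))

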
class _SDFBoundedSourceRestriction(object):
  """A restriction that wraps a SourceBundle and a RangeTracker."""
  def __init__(self, source_bundle, range_tracker=None):
    self._source_bundle = source_bundle
    self._range_tracker = range_tracker

  def __reduce__(self):
    # The instance of RangeTracker shouldn't be serialized.
    return (self.__class__, (self._source_bundle, ))

  def range_tracker(self):
    if not self._range_tracker:
      self._range_tracker = self._source_bundle.source.get_range_tracker(
          self._source_bundle.start_position,
          self._source_bundle.stop_position)
    return self._range_tracker

  def weight(self):
    return self._source_bundle.weight

  def source(self):
    return self._source_bundle.source

  def try_split(self, fraction_of_remainder):
    try:
      consumed_fraction = self.range_tracker().fraction_consumed()
      fraction = (
          consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
      position = self.range_tracker().position_at_fraction(fraction)
      # Need to stash the current stop_pos before splitting, since
      # range_tracker.try_split will update its stop_pos if the split
      # succeeds.
      stop_pos = self._source_bundle.stop_position
      split_result = self.range_tracker().try_split(position)
      if split_result:
        split_pos, split_fraction = split_result
        primary_weight = self._source_bundle.weight * split_fraction
        residual_weight = self._source_bundle.weight - primary_weight
        # Update self to the primary weight and end position.
        self._source_bundle = SourceBundle(
            primary_weight,
            self._source_bundle.source,
            self._source_bundle.start_position,
            split_pos)
        return (
            self,
            _SDFBoundedSourceRestriction(
                SourceBundle(
                    residual_weight,
                    self._source_bundle.source,
                    split_pos,
                    stop_pos)))
    except Exception:
      # For any exception from the underlying try_split call, the wrapper
      # assumes that the source refuses to split at this point. In this case,
      # no split happens at the wrapper level.
      return None


class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
  """An `iobase.RestrictionTracker` implementation for wrapping BoundedSource
  with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
  wraps a SourceBundle and a RangeTracker.

  The delegated RangeTracker guarantees synchronization safety.
  """
  def __init__(self, restriction):
    if not isinstance(restriction, _SDFBoundedSourceRestriction):
      raise ValueError(
          'Initializing SDFBoundedSourceRestrictionTracker'
          ' requires a _SDFBoundedSourceRestriction. Got %s instead.' %
          restriction)
    self.restriction = restriction

  def current_progress(self):
    # type: () -> RestrictionProgress
    return RestrictionProgress(
        fraction=self.restriction.range_tracker().fraction_consumed())

  def current_restriction(self):
    self.restriction.range_tracker()
    return self.restriction

  def start_pos(self):
    return self.restriction.range_tracker().start_position()

  def stop_pos(self):
    return self.restriction.range_tracker().stop_position()

  def try_claim(self, position):
    return self.restriction.range_tracker().try_claim(position)

  def try_split(self, fraction_of_remainder):
    return self.restriction.try_split(fraction_of_remainder)

  def check_done(self):
    return self.restriction.range_tracker().fraction_consumed() >= 1.0

  def is_bounded(self):
    return True


class _SDFBoundedSourceWrapperRestrictionCoder(coders.Coder):
  def decode(self, value):
    return _SDFBoundedSourceRestriction(SourceBundle(*pickler.loads(value)))

  def encode(self, restriction):
    return pickler.dumps((
        restriction._source_bundle.weight,
        restriction._source_bundle.source,
        restriction._source_bundle.start_position,
        restriction._source_bundle.stop_position))
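

# Illustrative sketch: the coder above simply pickles the SourceBundle fields
# and rebuilds the restriction on decode; the RangeTracker is intentionally
# not serialized and is recreated lazily on first use. ``my_bounded_source``
# is a placeholder for any picklable ``BoundedSource`` instance, and the
# positions 0 and 100 are assumed example values.
def _example_restriction_coder_round_trip(my_bounded_source):
  coder = _SDFBoundedSourceWrapperRestrictionCoder()
  restriction = _SDFBoundedSourceRestriction(
      SourceBundle(1.0, my_bounded_source, 0, 100))
  decoded = coder.decode(coder.encode(restriction))
  # The decoded restriction carries the same bundle fields as the original.
  assert decoded._source_bundle.start_position == 0
  assert decoded._source_bundle.stop_position == 100

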
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
  """
  A `RestrictionProvider` that is used by SDF for `BoundedSource`.

  This restriction provider initializes the restriction from an input
  element that is expected to be a BoundedSource.
  """
  def __init__(self, desired_chunk_size=None, restriction_coder=None):
    self._desired_chunk_size = desired_chunk_size
    self._restriction_coder = (
        restriction_coder or _SDFBoundedSourceWrapperRestrictionCoder())

  def _check_source(self, src):
    if not isinstance(src, BoundedSource):
      raise RuntimeError(
          'SDFBoundedSourceRestrictionProvider can only utilize BoundedSource')

  def initial_restriction(self, element_source: BoundedSource):
    self._check_source(element_source)
    range_tracker = element_source.get_range_tracker(None, None)
    return _SDFBoundedSourceRestriction(
        SourceBundle(
            None,
            element_source,
            range_tracker.start_position(),
            range_tracker.stop_position()))

  def create_tracker(self, restriction):
    return _SDFBoundedSourceRestrictionTracker(restriction)

  def split(self, element, restriction):
    if self._desired_chunk_size is None:
      try:
        estimated_size = restriction.source().estimate_size()
      except NotImplementedError:
        estimated_size = None
      self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size)

    # Invoke source.split to get the initial splitting results.
    source_bundles = restriction.source().split(self._desired_chunk_size)
    for source_bundle in source_bundles:
      yield _SDFBoundedSourceRestriction(source_bundle)

  def restriction_size(self, element, restriction):
    return restriction.weight()

  def restriction_coder(self):
    return self._restriction_coder


class SDFBoundedSourceReader(PTransform):
  """A ``PTransform`` that uses SDF to read from each ``BoundedSource`` in a
  PCollection.

  NOTE: This transform can only be used with beam_fn_api enabled.
  """
  def __init__(self, data_to_display=None):
    self._data_to_display = data_to_display or {}
    super().__init__()

  def _create_sdf_bounded_source_dofn(self):
    class SDFBoundedSourceDoFn(core.DoFn):
      def __init__(self, dd):
        self._dd = dd

      def display_data(self):
        return self._dd

      def process(
          self,
          unused_element,
          restriction_tracker=core.DoFn.RestrictionParam(
              _SDFBoundedSourceRestrictionProvider())):
        current_restriction = restriction_tracker.current_restriction()
        assert isinstance(current_restriction, _SDFBoundedSourceRestriction)

        return current_restriction.source().read(
            current_restriction.range_tracker())

    return SDFBoundedSourceDoFn(self._data_to_display)

  def expand(self, pvalue):
    return pvalue | core.ParDo(self._create_sdf_bounded_source_dofn())

  def get_windowing(self, unused_inputs):
    return core.Windowing(window.GlobalWindows())

  def display_data(self):
    return self._data_to_display
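

# Illustrative sketch (hypothetical usage, not part of this module): reading
# from a PCollection of BoundedSource objects with SDFBoundedSourceReader.
# ``make_source`` is a placeholder for any function returning picklable
# BoundedSource instances (for example, one source per input file), and the
# pipeline assumes a runner with beam_fn_api enabled.
def _example_sdf_bounded_source_reader(make_source):
  import apache_beam as beam

  with beam.Pipeline() as p:
    _ = (
        p
        # Each element of this PCollection must be a BoundedSource.
        | beam.Create([make_source('shard-0'), make_source('shard-1')])
        # The SDF wrapper splits each source into restrictions and reads the
        # records within each claimed restriction.
        | SDFBoundedSourceReader()
        | beam.Map(print))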