github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/iobase.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Sources and sinks.
    19  
    20  A Source manages record-oriented data input from a particular kind of source
    21  (e.g. a set of files, a database table, etc.). The reader() method of a source
    22  returns a reader object supporting the iterator protocol; iteration yields
    23  raw records of unprocessed, serialized data.
    24  
    25  
    26  A Sink manages record-oriented data output to a particular kind of sink
    27  (e.g. a set of files, a database table, etc.). The writer() method of a sink
    28  returns a writer object supporting writing records of serialized data to
    29  the sink.
    30  """
    31  
    32  # pytype: skip-file
    33  
    34  import logging
    35  import math
    36  import random
    37  import uuid
    38  from collections import namedtuple
    39  from typing import Any
    40  from typing import Iterator
    41  from typing import Optional
    42  from typing import Tuple
    43  from typing import Union
    44  
    45  from apache_beam import coders
    46  from apache_beam import pvalue
    47  from apache_beam.coders.coders import _MemoizingPickleCoder
    48  from apache_beam.internal import pickler
    49  from apache_beam.portability import common_urns
    50  from apache_beam.portability import python_urns
    51  from apache_beam.portability.api import beam_runner_api_pb2
    52  from apache_beam.pvalue import AsIter
    53  from apache_beam.pvalue import AsSingleton
    54  from apache_beam.transforms import Impulse
    55  from apache_beam.transforms import PTransform
    56  from apache_beam.transforms import core
    57  from apache_beam.transforms import ptransform
    58  from apache_beam.transforms import window
    59  from apache_beam.transforms.display import DisplayDataItem
    60  from apache_beam.transforms.display import HasDisplayData
    61  from apache_beam.utils import timestamp
    62  from apache_beam.utils import urns
    63  from apache_beam.utils.windowed_value import WindowedValue
    64  
    65  __all__ = [
    66      'BoundedSource',
    67      'RangeTracker',
    68      'Read',
    69      'RestrictionProgress',
    70      'RestrictionTracker',
    71      'WatermarkEstimator',
    72      'Sink',
    73      'Write',
    74      'Writer'
    75  ]
    76  
    77  _LOGGER = logging.getLogger(__name__)
    78  
    79  # Encapsulates information about a bundle of a source generated when method
    80  # BoundedSource.split() is invoked.
    81  # This is a named 4-tuple that has following fields.
    82  # * weight - a number that represents the size of the bundle. This value will
    83  #            be used to compare the relative sizes of bundles generated by the
    84  #            current source.
    85  #            The weight returned here could be specified using a unit of your
    86  #            choice (for example, bundles of sizes 100MB, 200MB, and 700MB may
    87  #            specify weights 100, 200, 700 or 1, 2, 7) but all bundles of a
    88  #            source should specify the weight using the same unit.
    89  # * source - a BoundedSource object for the  bundle.
    90  # * start_position - starting position of the bundle
    91  # * stop_position - ending position of the bundle.
    92  #
    93  # Type for start and stop positions are specific to the bounded source and must
    94  # be consistent throughout.
    95  SourceBundle = namedtuple(
    96      'SourceBundle', 'weight source start_position stop_position')
    97  
    98  
    99  class SourceBase(HasDisplayData, urns.RunnerApiFn):
   100    """Base class for all sources that can be passed to beam.io.Read(...).
   101    """
   102    urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE)
   103  
   104    def is_bounded(self):
   105      # type: () -> bool
   106      raise NotImplementedError
   107  
   108  
   109  class BoundedSource(SourceBase):
   110    """A source that reads a finite amount of input records.
   111  
   112    This class defines following operations which can be used to read the source
   113    efficiently.
   114  
   115    * Size estimation - method ``estimate_size()`` may return an accurate
   116      estimation in bytes for the size of the source.
   117    * Splitting into bundles of a given size - method ``split()`` can be used to
   118      split the source into a set of sub-sources (bundles) based on a desired
   119      bundle size.
   120    * Getting a RangeTracker - method ``get_range_tracker()`` should return a
   121      ``RangeTracker`` object for a given position range for the position type
   122      of the records returned by the source.
   123    * Reading the data - method ``read()`` can be used to read data from the
   124      source while respecting the boundaries defined by a given
   125      ``RangeTracker``.
   126  
   127    A runner will perform reading the source in two steps.
   128  
   129    (1) Method ``get_range_tracker()`` will be invoked with start and end
   130        positions to obtain a ``RangeTracker`` for the range of positions the
   131        runner intends to read. Source must define a default initial start and end
   132        position range. These positions must be used if the start and/or end
   133        positions passed to the method ``get_range_tracker()`` are ``None``
   134    (2) Method read() will be invoked with the ``RangeTracker`` obtained in the
   135        previous step.
   136  
   137    **Mutability**
   138  
   139    A ``BoundedSource`` object should not be mutated while
   140    its methods (for example, ``read()``) are being invoked by a runner. Runner
   141    implementations may invoke methods of ``BoundedSource`` objects through
   142    multi-threaded and/or reentrant execution modes.
   143    """
   144    def estimate_size(self):
   145      # type: () -> Optional[int]
   146  
   147      """Estimates the size of source in bytes.
   148  
   149      An estimate of the total size (in bytes) of the data that would be read
   150      from this source. This estimate is in terms of external storage size,
   151      before performing decompression or other processing.
   152  
   153      Returns:
   154        estimated size of the source if the size can be determined, ``None``
   155        otherwise.
   156      """
   157      raise NotImplementedError
   158  
   159    def split(self,
   160              desired_bundle_size,  # type: int
   161              start_position=None,  # type: Optional[Any]
   162              stop_position=None,  # type: Optional[Any]
   163             ):
   164      # type: (...) -> Iterator[SourceBundle]
   165  
   166      """Splits the source into a set of bundles.
   167  
   168      Bundles should be approximately of size ``desired_bundle_size`` bytes.
   169  
   170      Args:
   171        desired_bundle_size: the desired size (in bytes) of the bundles returned.
   172        start_position: if specified the given position must be used as the
   173                        starting position of the first bundle.
   174        stop_position: if specified the given position must be used as the ending
   175                       position of the last bundle.
   176      Returns:
   177        an iterator of objects of type 'SourceBundle' that gives information about
   178        the generated bundles.
   179      """
   180      raise NotImplementedError
   181  
   182    def get_range_tracker(self,
   183                          start_position,  # type: Optional[Any]
   184                          stop_position,  # type: Optional[Any]
   185                         ):
   186      # type: (...) -> RangeTracker
   187  
   188      """Returns a RangeTracker for a given position range.
   189  
   190      Framework may invoke ``read()`` method with the RangeTracker object returned
   191      here to read data from the source.
   192  
   193      Args:
   194        start_position: starting position of the range. If 'None' default start
   195                        position of the source must be used.
   196        stop_position:  ending position of the range. If 'None' default stop
   197                        position of the source must be used.
   198      Returns:
   199        a ``RangeTracker`` for the given position range.
   200      """
   201      raise NotImplementedError
   202  
   203    def read(self, range_tracker):
   204      """Returns an iterator that reads data from the source.
   205  
   206      The returned set of data must respect the boundaries defined by the given
   207      ``RangeTracker`` object. For example:
   208  
   209        * Returned set of data must be for the range
   210          ``[range_tracker.start_position, range_tracker.stop_position)``. Note
   211          that a source may decide to return records that start after
   212          ``range_tracker.stop_position``. See documentation in class
   213          ``RangeTracker`` for more details. Also, note that framework might
   214          invoke ``range_tracker.try_split()`` to perform dynamic split
   215          operations. range_tracker.stop_position may be updated
   216          dynamically due to successful dynamic split operations.
   217        * Method ``range_tracker.try_split()`` must be invoked for every record
   218          that starts at a split point.
   219        * Method ``range_tracker.record_current_position()`` may be invoked for
   220          records that do not start at split points.
   221  
   222      Args:
   223        range_tracker: a ``RangeTracker`` whose boundaries must be respected
   224                       when reading data from the source. A runner that reads this
   225                       source muss pass a ``RangeTracker`` object that is not
   226                       ``None``.
   227      Returns:
   228        an iterator of data read by the source.
   229      """
   230      raise NotImplementedError
   231  
   232    def default_output_coder(self):
   233      """Coder that should be used for the records returned by the source.
   234  
   235      Should be overridden by sources that produce objects that can be encoded
   236      more efficiently than pickling.
   237      """
   238      return coders.registry.get_coder(object)
   239  
   240    def is_bounded(self):
   241      return True
   242  
   243  
   244  class RangeTracker(object):
   245    """A thread safe object used by Dataflow source framework.
   246  
   247    A Dataflow source is defined using a ''BoundedSource'' and a ''RangeTracker''
   248    pair. A ''RangeTracker'' is used by Dataflow source framework to perform
   249    dynamic work rebalancing of position-based sources.
   250  
   251    **Position-based sources**
   252  
   253    A position-based source is one where the source can be described by a range
   254    of positions of an ordered type and the records returned by the reader can be
   255    described by positions of the same type.
   256  
   257    In case a record occupies a range of positions in the source, the most
   258    important thing about the record is the position where it starts.
   259  
   260    Defining the semantics of positions for a source is entirely up to the source
   261    class, however the chosen definitions have to obey certain properties in order
   262    to make it possible to correctly split the source into parts, including
   263    dynamic splitting. Two main aspects need to be defined:
   264  
   265    1. How to assign starting positions to records.
   266    2. Which records should be read by a source with a range '[A, B)'.
   267  
   268    Moreover, reading a range must be *efficient*, i.e., the performance of
   269    reading a range should not significantly depend on the location of the range.
   270    For example, reading the range [A, B) should not require reading all data
   271    before 'A'.
   272  
   273    The sections below explain exactly what properties these definitions must
   274    satisfy, and how to use a ``RangeTracker`` with a properly defined source.
   275  
   276    **Properties of position-based sources**
   277  
   278    The main requirement for position-based sources is *associativity*: reading
   279    records from '[A, B)' and records from '[B, C)' should give the same
   280    records as reading from '[A, C)', where 'A <= B <= C'. This property
   281    ensures that no matter how a range of positions is split into arbitrarily many
   282    sub-ranges, the total set of records described by them stays the same.
   283  
   284    The other important property is how the source's range relates to positions of
   285    records in the source. In many sources each record can be identified by a
   286    unique starting position. In this case:
   287  
   288    * All records returned by a source '[A, B)' must have starting positions in
   289      this range.
   290    * All but the last record should end within this range. The last record may or
   291      may not extend past the end of the range.
   292    * Records should not overlap.
   293  
   294    Such sources should define "read '[A, B)'" as "read from the first record
   295    starting at or after 'A', up to but not including the first record starting
   296    at or after 'B'".
   297  
   298    Some examples of such sources include reading lines or CSV from a text file,
   299    reading keys and values from a BigTable, etc.
   300  
   301    The concept of *split points* allows to extend the definitions for dealing
   302    with sources where some records cannot be identified by a unique starting
   303    position.
   304  
   305    In all cases, all records returned by a source '[A, B)' must *start* at or
   306    after 'A'.
   307  
   308    **Split points**
   309  
   310    Some sources may have records that are not directly addressable. For example,
   311    imagine a file format consisting of a sequence of compressed blocks. Each
   312    block can be assigned an offset, but records within the block cannot be
   313    directly addressed without decompressing the block. Let us refer to this
   314    hypothetical format as <i>CBF (Compressed Blocks Format)</i>.
   315  
   316    Many such formats can still satisfy the associativity property. For example,
   317    in CBF, reading '[A, B)' can mean "read all the records in all blocks whose
   318    starting offset is in '[A, B)'".
   319  
   320    To support such complex formats, we introduce the notion of *split points*. We
   321    say that a record is a split point if there exists a position 'A' such that
   322    the record is the first one to be returned when reading the range
   323    '[A, infinity)'. In CBF, the only split points would be the first records
   324    in each block.
   325  
   326    Split points allow us to define the meaning of a record's position and a
   327    source's range in all cases:
   328  
   329    * For a record that is at a split point, its position is defined to be the
   330      largest 'A' such that reading a source with the range '[A, infinity)'
   331      returns this record.
   332    * Positions of other records are only required to be non-decreasing.
   333    * Reading the source '[A, B)' must return records starting from the first
   334      split point at or after 'A', up to but not including the first split point
   335      at or after 'B'. In particular, this means that the first record returned
   336      by a source MUST always be a split point.
   337    * Positions of split points must be unique.
   338  
   339    As a result, for any decomposition of the full range of the source into
   340    position ranges, the total set of records will be the full set of records in
   341    the source, and each record will be read exactly once.
   342  
   343    **Consumed positions**
   344  
   345    As the source is being read, and records read from it are being passed to the
   346    downstream transforms in the pipeline, we say that positions in the source are
   347    being *consumed*. When a reader has read a record (or promised to a caller
   348    that a record will be returned), positions up to and including the record's
   349    start position are considered *consumed*.
   350  
   351    Dynamic splitting can happen only at *unconsumed* positions. If the reader
   352    just returned a record at offset 42 in a file, dynamic splitting can happen
   353    only at offset 43 or beyond, as otherwise that record could be read twice (by
   354    the current reader and by a reader of the task starting at 43).
   355    """
   356  
   357    SPLIT_POINTS_UNKNOWN = object()
   358  
   359    def start_position(self):
   360      """Returns the starting position of the current range, inclusive."""
   361      raise NotImplementedError(type(self))
   362  
   363    def stop_position(self):
   364      """Returns the ending position of the current range, exclusive."""
   365      raise NotImplementedError(type(self))
   366  
   367    def try_claim(self, position):  # pylint: disable=unused-argument
   368      """Atomically determines if a record at a split point is within the range.
   369  
   370      This method should be called **if and only if** the record is at a split
   371      point. This method may modify the internal state of the ``RangeTracker`` by
   372      updating the last-consumed position to ``position``.
   373  
   374      ** Thread safety **
   375  
   376      Methods of the class ``RangeTracker`` including this method may get invoked
   377      by different threads, hence must be made thread-safe, e.g. by using a single
   378      lock object.
   379  
   380      Args:
   381        position: starting position of a record being read by a source.
   382  
   383      Returns:
   384        ``True``, if the given position falls within the current range, returns
   385        ``False`` otherwise.
   386      """
   387      raise NotImplementedError
   388  
   389    def set_current_position(self, position):
   390      """Updates the last-consumed position to the given position.
   391  
   392      A source may invoke this method for records that do not start at split
   393      points. This may modify the internal state of the ``RangeTracker``. If the
   394      record starts at a split point, method ``try_claim()`` **must** be invoked
   395      instead of this method.
   396  
   397      Args:
   398        position: starting position of a record being read by a source.
   399      """
   400      raise NotImplementedError
   401  
   402    def position_at_fraction(self, fraction):
   403      """Returns the position at the given fraction.
   404  
   405      Given a fraction within the range [0.0, 1.0) this method will return the
   406      position at the given fraction compared to the position range
   407      [self.start_position, self.stop_position).
   408  
   409      ** Thread safety **
   410  
   411      Methods of the class ``RangeTracker`` including this method may get invoked
   412      by different threads, hence must be made thread-safe, e.g. by using a single
   413      lock object.
   414  
   415      Args:
   416        fraction: a float value within the range [0.0, 1.0).
   417      Returns:
   418        a position within the range [self.start_position, self.stop_position).
   419      """
   420      raise NotImplementedError
   421  
   422    def try_split(self, position):
   423      """Atomically splits the current range.
   424  
   425      Determines a position to split the current range, split_position, based on
   426      the given position. In most cases split_position and position will be the
   427      same.
   428  
   429      Splits the current range '[self.start_position, self.stop_position)'
   430      into a "primary" part '[self.start_position, split_position)' and a
   431      "residual" part '[split_position, self.stop_position)', assuming the
   432      current last-consumed position is within
   433      '[self.start_position, split_position)' (i.e., split_position has not been
   434      consumed yet).
   435  
   436      If successful, updates the current range to be the primary and returns a
   437      tuple (split_position, split_fraction). split_fraction should be the
   438      fraction of size of range '[self.start_position, split_position)' compared
   439      to the original (before split) range
   440      '[self.start_position, self.stop_position)'.
   441  
   442      If the split_position has already been consumed, returns ``None``.
   443  
   444      ** Thread safety **
   445  
   446      Methods of the class ``RangeTracker`` including this method may get invoked
   447      by different threads, hence must be made thread-safe, e.g. by using a single
   448      lock object.
   449  
   450      Args:
   451        position: suggested position where the current range should try to
   452                  be split at.
   453      Returns:
   454        a tuple containing the split position and split fraction if split is
   455        successful. Returns ``None`` otherwise.
   456      """
   457      raise NotImplementedError
   458  
   459    def fraction_consumed(self):
   460      """Returns the approximate fraction of consumed positions in the source.
   461  
   462      ** Thread safety **
   463  
   464      Methods of the class ``RangeTracker`` including this method may get invoked
   465      by different threads, hence must be made thread-safe, e.g. by using a single
   466      lock object.
   467  
   468      Returns:
   469        the approximate fraction of positions that have been consumed by
   470        successful 'try_split()' and  'try_claim()'  calls, or
   471        0.0 if no such calls have happened.
   472      """
   473      raise NotImplementedError
   474  
   475    def split_points(self):
   476      """Gives the number of split points consumed and remaining.
   477  
   478      For a ``RangeTracker`` used by a ``BoundedSource`` (within a
   479      ``BoundedSource.read()`` invocation) this method produces a 2-tuple that
   480      gives the number of split points consumed by the ``BoundedSource`` and the
   481      number of split points remaining within the range of the ``RangeTracker``
   482      that has not been consumed by the ``BoundedSource``.
   483  
   484      More specifically, given that the position of the current record being read
   485      by ``BoundedSource`` is current_position this method produces a tuple that
   486      consists of
   487      (1) number of split points in the range [self.start_position(),
   488      current_position) without including the split point that is currently being
   489      consumed. This represents the total amount of parallelism in the consumed
   490      part of the source.
   491      (2) number of split points within the range
   492      [current_position, self.stop_position()) including the split point that is
   493      currently being consumed. This represents the total amount of parallelism in
   494      the unconsumed part of the source.
   495  
   496      Methods of the class ``RangeTracker`` including this method may get invoked
   497      by different threads, hence must be made thread-safe, e.g. by using a single
   498      lock object.
   499  
   500      ** General information about consumed and remaining number of split
   501         points returned by this method. **
   502  
   503        * Before a source read (``BoundedSource.read()`` invocation) claims the
   504          first split point, number of consumed split points is 0. This condition
   505          holds independent of whether the input is "splittable". A splittable
   506          source is a source that has more than one split point.
   507        * Any source read that has only claimed one split point has 0 consumed
   508          split points since the first split point is the current split point and
   509          is still being processed. This condition holds independent of whether
   510          the input is splittable.
   511        * For an empty source read which never invokes
   512          ``RangeTracker.try_claim()``, the consumed number of split points is 0.
   513          This condition holds independent of whether the input is splittable.
   514        * For a source read which has invoked ``RangeTracker.try_claim()`` n
   515          times, the consumed number of split points is  n -1.
   516        * If a ``BoundedSource`` sets a callback through function
   517          ``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that
   518          callback when determining remaining number of split points.
   519        * Remaining split points should include the split point that is currently
   520          being consumed by the source read. Hence if the above callback returns
   521          an integer value n, remaining number of split points should be (n + 1).
   522        * After last split point is claimed remaining split points becomes 1,
   523          because this unfinished read itself represents an  unfinished split
   524          point.
   525        * After all records of the source has been consumed, remaining number of
   526          split points becomes 0 and consumed number of split points becomes equal
   527          to the total number of split points within the range being read by the
   528          source. This method does not address this condition and will continue to
   529          report number of consumed split points as
   530          ("total number of split points" - 1) and number of remaining split
   531          points as 1. A runner that performs the reading of the source can
   532          detect when all records have been consumed and adjust remaining and
   533          consumed number of split points accordingly.
   534  
   535      ** Examples **
   536  
   537      (1) A "perfectly splittable" input which can be read in parallel down to the
   538          individual records.
   539  
   540          Consider a perfectly splittable input that consists of 50 split points.
   541  
   542        * Before a source read (``BoundedSource.read()`` invocation) claims the
   543          first split point, number of consumed split points is 0 number of
   544          remaining split points is 50.
   545        * After claiming first split point, consumed number of split points is 0
   546          and remaining number of split is 50.
   547        * After claiming split point #30, consumed number of split points is 29
   548          and remaining number of split points is 21.
   549        * After claiming all 50 split points, consumed number of split points is
   550          49 and remaining number of split points is 1.
   551  
   552      (2) a "block-compressed" file format such as ``avroio``, in which a block of
   553          records has to be read as a whole, but different blocks can be read in
   554          parallel.
   555  
   556          Consider a block compressed input that consists of 5 blocks.
   557  
   558        * Before a source read (``BoundedSource.read()`` invocation) claims the
   559          first split point (first block), number of consumed split points is 0
   560          number of remaining split points is 5.
   561        * After claiming first split point, consumed number of split points is 0
   562          and remaining number of split is 5.
   563        * After claiming split point #3, consumed number of split points is 2
   564          and remaining number of split points is 3.
   565        * After claiming all 5 split points, consumed number of split points is
   566          4 and remaining number of split points is 1.
   567  
   568      (3) an "unsplittable" input such as a cursor in a database or a gzip
   569          compressed file.
   570  
   571          Such an input is considered to have only a single split point. Number of
   572          consumed split points is always 0 and number of remaining split points
   573          is always 1.
   574  
   575      By default ``RangeTracker` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN`` for
   576      both consumed and remaining number of split points, which indicates that the
   577      number of split points consumed and remaining is unknown.
   578  
   579      Returns:
   580        A pair that gives consumed and remaining number of split points. Consumed
   581        number of split points should be an integer larger than or equal to zero
   582        or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points
   583        should be an integer larger than zero or
   584        ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
   585      """
   586      return (
   587          RangeTracker.SPLIT_POINTS_UNKNOWN, RangeTracker.SPLIT_POINTS_UNKNOWN)
   588  
   589    def set_split_points_unclaimed_callback(self, callback):
   590      """Sets a callback for determining the unclaimed number of split points.
   591  
   592      By invoking this function, a ``BoundedSource`` can set a callback function
   593      that may get invoked by the ``RangeTracker`` to determine the number of
   594      unclaimed split points. A split point is unclaimed if
   595      ``RangeTracker.try_claim()`` method has not been successfully invoked for
   596      that particular split point. The callback function accepts a single
   597      parameter, a stop position for the BoundedSource (stop_position). If the
   598      record currently being consumed by the ``BoundedSource`` is at position
   599      current_position, callback should return the number of split points within
   600      the range (current_position, stop_position). Note that, this should not
   601      include the split point that is currently being consumed by the source.
   602  
   603      This function must be implemented by subclasses before being used.
   604  
   605      Args:
   606        callback: a function that takes a single parameter, a stop position,
   607                  and returns unclaimed number of split points for the source read
   608                  operation that is calling this function. Value returned from
   609                  callback should be either an integer larger than or equal to
   610                  zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
   611      """
   612      raise NotImplementedError
   613  
   614  
   615  class Sink(HasDisplayData):
   616    """This class is deprecated, no backwards-compatibility guarantees.
   617  
   618    A resource that can be written to using the ``beam.io.Write`` transform.
   619  
   620    Here ``beam`` stands for Apache Beam Python code imported in following manner.
   621    ``import apache_beam as beam``.
   622  
   623    A parallel write to an ``iobase.Sink`` consists of three phases:
   624  
   625    1. A sequential *initialization* phase (e.g., creating a temporary output
   626       directory, etc.)
   627    2. A parallel write phase where workers write *bundles* of records
   628    3. A sequential *finalization* phase (e.g., committing the writes, merging
   629       output files, etc.)
   630  
   631    Implementing a new sink requires extending two classes.
   632  
   633    1. iobase.Sink
   634  
   635    ``iobase.Sink`` is an immutable logical description of the location/resource
   636    to write to. Depending on the type of sink, it may contain fields such as the
   637    path to an output directory on a filesystem, a database table name,
   638    etc. ``iobase.Sink`` provides methods for performing a write operation to the
   639    sink described by it. To this end, implementors of an extension of
   640    ``iobase.Sink`` must implement three methods:
   641    ``initialize_write()``, ``open_writer()``, and ``finalize_write()``.
   642  
   643    2. iobase.Writer
   644  
   645    ``iobase.Writer`` is used to write a single bundle of records. An
   646    ``iobase.Writer`` defines two methods: ``write()`` which writes a
   647    single record from the bundle and ``close()`` which is called once
   648    at the end of writing a bundle.
   649  
   650    See also ``apache_beam.io.filebasedsink.FileBasedSink`` which provides a
   651    simpler API for writing sinks that produce files.
   652  
   653    **Execution of the Write transform**
   654  
   655    ``initialize_write()``, ``pre_finalize()``, and ``finalize_write()`` are
   656    conceptually called once. However, implementors must
   657    ensure that these methods are *idempotent*, as they may be called multiple
   658    times on different machines in the case of failure/retry. A method may be
   659    called more than once concurrently, in which case it's okay to have a
   660    transient failure (such as due to a race condition). This failure should not
   661    prevent subsequent retries from succeeding.
   662  
   663    ``initialize_write()`` should perform any initialization that needs to be done
   664    prior to writing to the sink. ``initialize_write()`` may return a result
   665    (let's call this ``init_result``) that contains any parameters it wants to
   666    pass on to its writers about the sink. For example, a sink that writes to a
   667    file system may return an ``init_result`` that contains a dynamically
   668    generated unique directory to which data should be written.
   669  
   670    To perform writing of a bundle of elements, Dataflow execution engine will
   671    create an ``iobase.Writer`` using the implementation of
   672    ``iobase.Sink.open_writer()``. When invoking ``open_writer()`` execution
   673    engine will provide the ``init_result`` returned by ``initialize_write()``
   674    invocation as well as a *bundle id* (let's call this ``bundle_id``) that is
   675    unique for each invocation of ``open_writer()``.
   676  
   677    Execution engine will then invoke ``iobase.Writer.write()`` implementation for
   678    each element that has to be written. Once all elements of a bundle are
   679    written, execution engine will invoke ``iobase.Writer.close()`` implementation
   680    which should return a result (let's call this ``write_result``) that contains
   681    information that encodes the result of the write and, in most cases, some
   682    encoding of the unique bundle id. For example, if each bundle is written to a
   683    unique temporary file, ``close()`` method may return an object that contains
   684    the temporary file name. After writing of all bundles is complete, execution
   685    engine will invoke ``pre_finalize()`` and then ``finalize_write()``
   686    implementation.
   687  
   688    The execution of a write transform can be illustrated using following pseudo
   689    code (assume that the outer for loop happens in parallel across many
   690    machines)::
   691  
   692      init_result = sink.initialize_write()
   693      write_results = []
   694      for bundle in partition(pcoll):
   695        writer = sink.open_writer(init_result, generate_bundle_id())
   696        for elem in bundle:
   697          writer.write(elem)
   698        write_results.append(writer.close())
   699      pre_finalize_result = sink.pre_finalize(init_result, write_results)
   700      sink.finalize_write(init_result, write_results, pre_finalize_result)
   701  
   702  
   703    **init_result**
   704  
   705    Methods of 'iobase.Sink' should agree on the 'init_result' type that will be
   706    returned when initializing the sink. This type can be a client-defined object
   707    or an existing type. The returned type must be picklable using Dataflow coder
   708    ``coders.PickleCoder``. Returning an init_result is optional.
   709  
   710    **bundle_id**
   711  
   712    In order to ensure fault-tolerance, a bundle may be executed multiple times
   713    (e.g., in the event of failure/retry or for redundancy). However, exactly one
   714    of these executions will have its result passed to the
   715    ``iobase.Sink.finalize_write()`` method. Each call to
   716    ``iobase.Sink.open_writer()`` is passed a unique bundle id when it is called
   717    by the ``WriteImpl`` transform, so even redundant or retried bundles will have
   718    a unique way of identifying their output.
   719  
   720    The bundle id should be used to guarantee that a bundle's output is unique.
   721    This uniqueness guarantee is important; if a bundle is to be output to a file,
   722    for example, the name of the file must be unique to avoid conflicts with other
   723    writers. The bundle id should be encoded in the writer result returned by the
   724    writer and subsequently used by the ``finalize_write()`` method to identify
   725    the results of successful writes.
   726  
   727    For example, consider the scenario where a Writer writes files containing
   728    serialized records and the ``finalize_write()`` is to merge or rename these
   729    output files. In this case, a writer may use its unique id to name its output
   730    file (to avoid conflicts) and return the name of the file it wrote as its
   731    writer result. The ``finalize_write()`` will then receive an ``Iterable`` of
   732    output file names that it can then merge or rename using some bundle naming
   733    scheme.
   734  
   735    **write_result**
   736  
   737    ``iobase.Writer.close()`` and ``finalize_write()`` implementations must agree
   738    on type of the ``write_result`` object returned when invoking
   739    ``iobase.Writer.close()``. This type can be a client-defined object or
   740    an existing type. The returned type must be picklable using Dataflow coder
   741    ``coders.PickleCoder``. Returning a ``write_result`` when
   742    ``iobase.Writer.close()`` is invoked is optional but if unique
   743    ``write_result`` objects are not returned, sink should, guarantee idempotency
   744    when same bundle is written multiple times due to failure/retry or redundancy.
   745  
   746  
   747    **More information**
   748  
   749    For more information on creating new sinks please refer to the official
   750    documentation at
   751    ``https://beam.apache.org/documentation/sdks/python-custom-io#creating-sinks``
   752    """
   753    # Whether Beam should skip writing any shards if all are empty.
   754    skip_if_empty = False
   755  
   756    def initialize_write(self):
   757      """Initializes the sink before writing begins.
   758  
   759      Invoked before any data is written to the sink.
   760  
   761  
   762      Please see documentation in ``iobase.Sink`` for an example.
   763  
   764      Returns:
   765        An object that contains any sink specific state generated by
   766        initialization. This object will be passed to open_writer() and
   767        finalize_write() methods.
   768      """
   769      raise NotImplementedError
   770  
   771    def open_writer(self, init_result, uid):
   772      """Opens a writer for writing a bundle of elements to the sink.
   773  
   774      Args:
   775        init_result: the result of initialize_write() invocation.
   776        uid: a unique identifier generated by the system.
   777      Returns:
   778        an ``iobase.Writer`` that can be used to write a bundle of records to the
   779        current sink.
   780      """
   781      raise NotImplementedError
   782  
   783    def pre_finalize(self, init_result, writer_results):
   784      """Pre-finalization stage for sink.
   785  
   786      Called after all bundle writes are complete and before finalize_write.
   787      Used to setup and verify filesystem and sink states.
   788  
   789      Args:
   790        init_result: the result of ``initialize_write()`` invocation.
   791        writer_results: an iterable containing results of ``Writer.close()``
   792          invocations. This will only contain results of successful writes, and
   793          will only contain the result of a single successful write for a given
   794          bundle.
   795  
   796      Returns:
   797        An object that contains any sink specific state generated.
   798        This object will be passed to finalize_write().
   799      """
   800      raise NotImplementedError
   801  
   802    def finalize_write(self, init_result, writer_results, pre_finalize_result):
   803      """Finalizes the sink after all data is written to it.
   804  
   805      Given the result of initialization and an iterable of results from bundle
   806      writes, performs finalization after writing and closes the sink. Called
   807      after all bundle writes are complete.
   808  
   809      The bundle write results that are passed to finalize are those returned by
   810      bundles that completed successfully. Although bundles may have been run
   811      multiple times (for fault-tolerance), only one writer result will be passed
   812      to finalize for each bundle. An implementation of finalize should perform
   813      clean up of any failed and successfully retried bundles.  Note that these
   814      failed bundles will not have their writer result passed to finalize, so
   815      finalize should be capable of locating any temporary/partial output written
   816      by failed bundles.
   817  
   818      If all retries of a bundle fails, the whole pipeline will fail *without*
   819      finalize_write() being invoked.
   820  
   821      A best practice is to make finalize atomic. If this is impossible given the
   822      semantics of the sink, finalize should be idempotent, as it may be called
   823      multiple times in the case of failure/retry or for redundancy.
   824  
   825      Note that the iteration order of the writer results is not guaranteed to be
   826      consistent if finalize is called multiple times.
   827  
   828      Args:
   829        init_result: the result of ``initialize_write()`` invocation.
   830        writer_results: an iterable containing results of ``Writer.close()``
   831          invocations. This will only contain results of successful writes, and
   832          will only contain the result of a single successful write for a given
   833          bundle.
   834        pre_finalize_result: the result of ``pre_finalize()`` invocation.
   835      """
   836      raise NotImplementedError
   837  
   838  
   839  class Writer(object):
   840    """This class is deprecated, no backwards-compatibility guarantees.
   841  
   842    Writes a bundle of elements from a ``PCollection`` to a sink.
   843  
   844    A Writer  ``iobase.Writer.write()`` writes and elements to the sink while
   845    ``iobase.Writer.close()`` is called after all elements in the bundle have been
   846    written.
   847  
   848    See ``iobase.Sink`` for more detailed documentation about the process of
   849    writing to a sink.
   850    """
   851    def write(self, value):
   852      """Writes a value to the sink using the current writer.
   853      """
   854      raise NotImplementedError
   855  
   856    def close(self):
   857      """Closes the current writer.
   858  
   859      Please see documentation in ``iobase.Sink`` for an example.
   860  
   861      Returns:
   862        An object representing the writes that were performed by the current
   863        writer.
   864      """
   865      raise NotImplementedError
   866  
   867    def at_capacity(self) -> bool:
   868      """Returns whether this writer should be considered at capacity
   869      and a new one should be created.
   870      """
   871      return False
   872  
   873  
   874  class Read(ptransform.PTransform):
   875    """A transform that reads a PCollection."""
   876    # Import runners here to prevent circular imports
   877    from apache_beam.runners.pipeline_context import PipelineContext
   878  
   879    def __init__(self, source):
   880      # type: (SourceBase) -> None
   881  
   882      """Initializes a Read transform.
   883  
   884      Args:
   885        source: Data source to read from.
   886      """
   887      super().__init__()
   888      self.source = source
   889  
   890    @staticmethod
   891    def get_desired_chunk_size(total_size):
   892      if total_size:
   893        # 1MB = 1 shard, 1GB = 32 shards, 1TB = 1000 shards, 1PB = 32k shards
   894        chunk_size = max(1 << 20, 1000 * int(math.sqrt(total_size)))
   895      else:
   896        chunk_size = 64 << 20  # 64mb
   897      return chunk_size
   898  
   899    def expand(self, pbegin):
   900      if isinstance(self.source, BoundedSource):
   901        coders.registry.register_coder(BoundedSource, _MemoizingPickleCoder)
   902        display_data = self.source.display_data() or {}
   903        display_data['source'] = self.source.__class__
   904  
   905        return (
   906            pbegin
   907            | Impulse()
   908            | core.Map(lambda _: self.source).with_output_types(BoundedSource)
   909            | SDFBoundedSourceReader(display_data))
   910      elif isinstance(self.source, ptransform.PTransform):
   911        # The Read transform can also admit a full PTransform as an input
   912        # rather than an anctual source. If the input is a PTransform, then
   913        # just apply it directly.
   914        return pbegin.pipeline | self.source
   915      else:
   916        # Treat Read itself as a primitive.
   917        return pvalue.PCollection(
   918            pbegin.pipeline, is_bounded=self.source.is_bounded())
   919  
   920    def get_windowing(self, unused_inputs):
   921      # type: (...) -> core.Windowing
   922      return core.Windowing(window.GlobalWindows())
   923  
   924    def _infer_output_coder(self, input_type=None, input_coder=None):
   925      # type: (...) -> Optional[coders.Coder]
   926      from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
   927      if isinstance(self.source, BoundedSource):
   928        return self.source.default_output_coder()
   929      elif isinstance(self.source, dataflow_io.NativeSource):
   930        return self.source.coder
   931      else:
   932        return None
   933  
   934    def display_data(self):
   935      return {
   936          'source': DisplayDataItem(self.source.__class__, label='Read Source'),
   937          'source_dd': self.source
   938      }
   939  
   940    def to_runner_api_parameter(
   941        self,
   942        context: PipelineContext,
   943    ) -> Tuple[str, Any]:
   944      from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
   945      if isinstance(self.source, (BoundedSource, dataflow_io.NativeSource)):
   946        from apache_beam.io.gcp.pubsub import _PubSubSource
   947        if isinstance(self.source, _PubSubSource):
   948          return (
   949              common_urns.composites.PUBSUB_READ.urn,
   950              beam_runner_api_pb2.PubSubReadPayload(
   951                  topic=self.source.full_topic,
   952                  subscription=self.source.full_subscription,
   953                  timestamp_attribute=self.source.timestamp_attribute,
   954                  with_attributes=self.source.with_attributes,
   955                  id_attribute=self.source.id_label))
   956        return (
   957            common_urns.deprecated_primitives.READ.urn,
   958            beam_runner_api_pb2.ReadPayload(
   959                source=self.source.to_runner_api(context),
   960                is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
   961                if self.source.is_bounded() else
   962                beam_runner_api_pb2.IsBounded.UNBOUNDED))
   963      elif isinstance(self.source, ptransform.PTransform):
   964        return self.source.to_runner_api_parameter(context)
   965      raise NotImplementedError(
   966          "to_runner_api_parameter not "
   967          "implemented for type")
   968  
   969    @staticmethod
   970    def from_runner_api_parameter(
   971        transform: beam_runner_api_pb2.PTransform,
   972        payload: Union[beam_runner_api_pb2.ReadPayload,
   973                       beam_runner_api_pb2.PubSubReadPayload],
   974        context: PipelineContext,
   975    ) -> "Read":
   976      if transform.spec.urn == common_urns.composites.PUBSUB_READ.urn:
   977        assert isinstance(payload, beam_runner_api_pb2.PubSubReadPayload)
   978        # Importing locally to prevent circular dependencies.
   979        from apache_beam.io.gcp.pubsub import _PubSubSource
   980        source = _PubSubSource(
   981            topic=payload.topic or None,
   982            subscription=payload.subscription or None,
   983            id_label=payload.id_attribute or None,
   984            with_attributes=payload.with_attributes,
   985            timestamp_attribute=payload.timestamp_attribute or None)
   986        return Read(source)
   987      else:
   988        assert isinstance(payload, beam_runner_api_pb2.ReadPayload)
   989        return Read(SourceBase.from_runner_api(payload.source, context))
   990  
   991    @staticmethod
   992    def _from_runner_api_parameter_read(
   993        transform: beam_runner_api_pb2.PTransform,
   994        payload: beam_runner_api_pb2.ReadPayload,
   995        context: PipelineContext,
   996    ) -> "Read":
   997      """Method for type proxying when calling register_urn due to limitations
   998       in type exprs in Python"""
   999      return Read.from_runner_api_parameter(transform, payload, context)
  1000  
  1001    @staticmethod
  1002    def _from_runner_api_parameter_pubsub_read(
  1003        transform: beam_runner_api_pb2.PTransform,
  1004        payload: beam_runner_api_pb2.PubSubReadPayload,
  1005        context: PipelineContext,
  1006    ) -> "Read":
  1007      """Method for type proxying when calling register_urn due to limitations
  1008       in type exprs in Python"""
  1009      return Read.from_runner_api_parameter(transform, payload, context)
  1010  
  1011  
  1012  ptransform.PTransform.register_urn(
  1013      common_urns.deprecated_primitives.READ.urn,
  1014      beam_runner_api_pb2.ReadPayload,
  1015      Read._from_runner_api_parameter_read,
  1016  )
  1017  
  1018  ptransform.PTransform.register_urn(
  1019      common_urns.composites.PUBSUB_READ.urn,
  1020      beam_runner_api_pb2.PubSubReadPayload,
  1021      Read._from_runner_api_parameter_pubsub_read,
  1022  )
  1023  
  1024  
  1025  class Write(ptransform.PTransform):
  1026    """A ``PTransform`` that writes to a sink.
  1027  
  1028    A sink should inherit ``iobase.Sink``. Such implementations are
  1029    handled using a composite transform that consists of three ``ParDo``s -
  1030    (1) a ``ParDo`` performing a global initialization (2) a ``ParDo`` performing
  1031    a parallel write and (3) a ``ParDo`` performing a global finalization. In the
  1032    case of an empty ``PCollection``, only the global initialization and
  1033    finalization will be performed. Currently only batch workflows support custom
  1034    sinks.
  1035  
  1036    Example usage::
  1037  
  1038        pcollection | beam.io.Write(MySink())
  1039  
  1040    This returns a ``pvalue.PValue`` object that represents the end of the
  1041    Pipeline.
  1042  
  1043    The sink argument may also be a full PTransform, in which case it will be
  1044    applied directly.  This allows composite sink-like transforms (e.g. a sink
  1045    with some pre-processing DoFns) to be used the same as all other sinks.
  1046  
  1047    This transform also supports sinks that inherit ``iobase.NativeSink``. These
  1048    are sinks that are implemented natively by the Dataflow service and hence
  1049    should not be updated by users. These sinks are processed using a Dataflow
  1050    native write transform.
  1051    """
  1052    # Import runners here to prevent circular imports
  1053    from apache_beam.runners.pipeline_context import PipelineContext
  1054  
  1055    def __init__(self, sink):
  1056      """Initializes a Write transform.
  1057  
  1058      Args:
  1059        sink: Data sink to write to.
  1060      """
  1061      super().__init__()
  1062      self.sink = sink
  1063  
  1064    def display_data(self):
  1065      return {'sink': self.sink.__class__, 'sink_dd': self.sink}
  1066  
  1067    def expand(self, pcoll):
  1068      from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
  1069      if isinstance(self.sink, dataflow_io.NativeSink):
  1070        # A native sink
  1071        return pcoll | 'NativeWrite' >> dataflow_io._NativeWrite(self.sink)
  1072      elif isinstance(self.sink, Sink):
  1073        # A custom sink
  1074        return pcoll | WriteImpl(self.sink)
  1075      elif isinstance(self.sink, ptransform.PTransform):
  1076        # This allows "composite" sinks to be used like non-composite ones.
  1077        return pcoll | self.sink
  1078      else:
  1079        raise ValueError(
  1080            'A sink must inherit iobase.Sink, iobase.NativeSink, '
  1081            'or be a PTransform. Received : %r' % self.sink)
  1082  
  1083    def to_runner_api_parameter(
  1084        self,
  1085        context: PipelineContext,
  1086    ) -> Tuple[str, Any]:
  1087      # Importing locally to prevent circular dependencies.
  1088      from apache_beam.io.gcp.pubsub import _PubSubSink
  1089      if isinstance(self.sink, _PubSubSink):
  1090        payload = beam_runner_api_pb2.PubSubWritePayload(
  1091            topic=self.sink.full_topic,
  1092            id_attribute=self.sink.id_label,
  1093            timestamp_attribute=self.sink.timestamp_attribute)
  1094        return (common_urns.composites.PUBSUB_WRITE.urn, payload)
  1095      else:
  1096        return super().to_runner_api_parameter(context)
  1097  
  1098    @staticmethod
  1099    @ptransform.PTransform.register_urn(
  1100        common_urns.composites.PUBSUB_WRITE.urn,
  1101        beam_runner_api_pb2.PubSubWritePayload)
  1102    def from_runner_api_parameter(
  1103        ptransform: Any,
  1104        payload: beam_runner_api_pb2.PubSubWritePayload,
  1105        unused_context: PipelineContext,
  1106    ) -> "Write":
  1107      if ptransform.spec.urn != common_urns.composites.PUBSUB_WRITE.urn:
  1108        raise ValueError(
  1109            'Write transform cannot be constructed for the given proto %r',
  1110            ptransform)
  1111  
  1112      if not payload.topic:
  1113        raise NotImplementedError(
  1114            "from_runner_api_parameter does not "
  1115            "handle empty or None topic")
  1116  
  1117      # Importing locally to prevent circular dependencies.
  1118      from apache_beam.io.gcp.pubsub import _PubSubSink
  1119      sink = _PubSubSink(
  1120          topic=payload.topic,
  1121          id_label=payload.id_attribute or None,
  1122          timestamp_attribute=payload.timestamp_attribute or None)
  1123      return Write(sink)
  1124  
  1125  
  1126  class WriteImpl(ptransform.PTransform):
  1127    """Implements the writing of custom sinks."""
  1128    def __init__(self, sink):
  1129      # type: (Sink) -> None
  1130      super().__init__()
  1131      self.sink = sink
  1132  
  1133    def expand(self, pcoll):
  1134      do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None])
  1135      init_result_coll = do_once | 'InitializeWrite' >> core.Map(
  1136          lambda _, sink: sink.initialize_write(), self.sink)
  1137      if getattr(self.sink, 'num_shards', 0):
  1138        min_shards = self.sink.num_shards
  1139        if min_shards == 1:
  1140          keyed_pcoll = pcoll | core.Map(lambda x: (None, x))
  1141        else:
  1142          keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(), count=min_shards)
  1143        write_result_coll = (
  1144            keyed_pcoll
  1145            | core.WindowInto(window.GlobalWindows())
  1146            | core.GroupByKey()
  1147            | 'WriteBundles' >> core.ParDo(
  1148                _WriteKeyedBundleDoFn(self.sink), AsSingleton(init_result_coll)))
  1149      else:
  1150        min_shards = 1
  1151        write_result_coll = (
  1152            pcoll
  1153            | core.WindowInto(window.GlobalWindows())
  1154            | 'WriteBundles' >> core.ParDo(
  1155                _WriteBundleDoFn(self.sink), AsSingleton(init_result_coll))
  1156            | 'Pair' >> core.Map(lambda x: (None, x))
  1157            | core.GroupByKey()
  1158            | 'Extract' >> core.FlatMap(lambda x: x[1]))
  1159      # PreFinalize should run before FinalizeWrite, and the two should not be
  1160      # fused.
  1161      pre_finalize_coll = (
  1162          do_once
  1163          | 'PreFinalize' >> core.FlatMap(
  1164              _pre_finalize,
  1165              self.sink,
  1166              AsSingleton(init_result_coll),
  1167              AsIter(write_result_coll)))
  1168      return do_once | 'FinalizeWrite' >> core.FlatMap(
  1169          _finalize_write,
  1170          self.sink,
  1171          AsSingleton(init_result_coll),
  1172          AsIter(write_result_coll),
  1173          min_shards,
  1174          AsSingleton(pre_finalize_coll)).with_output_types(str)
  1175  
  1176  
  1177  class _WriteBundleDoFn(core.DoFn):
  1178    """A DoFn for writing elements to an iobase.Writer.
  1179    Opens a writer at the first element and closes the writer at finish_bundle().
  1180    """
  1181    def __init__(self, sink):
  1182      self.sink = sink
  1183  
  1184    def display_data(self):
  1185      return {'sink_dd': self.sink}
  1186  
  1187    def start_bundle(self):
  1188      self.writer = None
  1189  
  1190    def process(self, element, init_result):
  1191      if self.writer is None:
  1192        # We ignore UUID collisions here since they are extremely rare.
  1193        self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
  1194      self.writer.write(element)
  1195      if self.writer.at_capacity():
  1196        yield self.writer.close()
  1197        self.writer = None
  1198  
  1199    def finish_bundle(self):
  1200      if self.writer is not None:
  1201        yield WindowedValue(
  1202            self.writer.close(),
  1203            window.GlobalWindow().max_timestamp(), [window.GlobalWindow()])
  1204  
  1205  
  1206  class _WriteKeyedBundleDoFn(core.DoFn):
  1207    def __init__(self, sink):
  1208      self.sink = sink
  1209  
  1210    def display_data(self):
  1211      return {'sink_dd': self.sink}
  1212  
  1213    def process(self, element, init_result):
  1214      bundle = element
  1215      writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
  1216      for e in bundle[1]:  # values
  1217        writer.write(e)
  1218      return [window.TimestampedValue(writer.close(), timestamp.MAX_TIMESTAMP)]
  1219  
  1220  
  1221  def _pre_finalize(unused_element, sink, init_result, write_results):
  1222    return sink.pre_finalize(init_result, write_results)
  1223  
  1224  
  1225  def _finalize_write(
  1226      unused_element,
  1227      sink,
  1228      init_result,
  1229      write_results,
  1230      min_shards,
  1231      pre_finalize_results):
  1232    write_results = list(write_results)
  1233    extra_shards = []
  1234    if len(write_results) < min_shards:
  1235      if write_results or not sink.skip_if_empty:
  1236        _LOGGER.debug(
  1237            'Creating %s empty shard(s).', min_shards - len(write_results))
  1238        for _ in range(min_shards - len(write_results)):
  1239          writer = sink.open_writer(init_result, str(uuid.uuid4()))
  1240          extra_shards.append(writer.close())
  1241    outputs = sink.finalize_write(
  1242        init_result, write_results + extra_shards, pre_finalize_results)
  1243    if outputs:
  1244      return (
  1245          window.TimestampedValue(v, timestamp.MAX_TIMESTAMP) for v in outputs)
  1246  
  1247  
  1248  class _RoundRobinKeyFn(core.DoFn):
  1249    def start_bundle(self):
  1250      self.counter = None
  1251  
  1252    def process(self, element, count):
  1253      if self.counter is None:
  1254        self.counter = random.randrange(0, count)
  1255      self.counter = (1 + self.counter) % count
  1256      yield self.counter, element
  1257  
  1258  
  1259  class RestrictionTracker(object):
  1260    """Manages access to a restriction.
  1261  
  1262    Keeps track of the restrictions claimed part for a Splittable DoFn.
  1263  
  1264    The restriction may be modified by different threads, however the system will
  1265    ensure sufficient locking such that no methods on the restriction tracker
  1266    will be called concurrently.
  1267  
  1268    See following documents for more details.
  1269    * https://s.apache.org/splittable-do-fn
  1270    * https://s.apache.org/splittable-do-fn-python-sdk
  1271    """
  1272    def current_restriction(self):
  1273      """Returns the current restriction.
  1274  
  1275      Returns a restriction accurately describing the full range of work the
  1276      current ``DoFn.process()`` call will do, including already completed work.
  1277  
  1278      The current restriction returned by method may be updated dynamically due
  1279      to due to concurrent invocation of other methods of the
  1280      ``RestrictionTracker``, For example, ``split()``.
  1281  
  1282      This API is required to be implemented.
  1283  
  1284      Returns: a restriction object.
  1285      """
  1286      raise NotImplementedError
  1287  
  1288    def current_progress(self):
  1289      # type: () -> RestrictionProgress
  1290  
  1291      """Returns a RestrictionProgress object representing the current progress.
  1292  
  1293      This API is recommended to be implemented. The runner can do a better job
  1294      at parallel processing with better progress signals.
  1295      """
  1296      raise NotImplementedError
  1297  
  1298    def check_done(self):
  1299      """Checks whether the restriction has been fully processed.
  1300  
  1301      Called by the SDK harness after iterator returned by ``DoFn.process()``
  1302      has been fully read.
  1303  
  1304      This method must raise a `ValueError` if there is still any unclaimed work
  1305      remaining in the restriction when this method is invoked. Exception raised
  1306      must have an informative error message.
  1307  
  1308      This API is required to be implemented in order to make sure no data loss
  1309      during SDK processing.
  1310  
  1311      Returns: ``True`` if current restriction has been fully processed.
  1312      Raises:
  1313        ValueError: if there is still any unclaimed work remaining.
  1314      """
  1315      raise NotImplementedError
  1316  
  1317    def try_split(self, fraction_of_remainder):
  1318      """Splits current restriction based on fraction_of_remainder.
  1319  
  1320      If splitting the current restriction is possible, the current restriction is
  1321      split into a primary and residual restriction pair. This invocation updates
  1322      the ``current_restriction()`` to be the primary restriction effectively
  1323      having the current ``DoFn.process()`` execution responsible for performing
  1324      the work that the primary restriction represents. The residual restriction
  1325      will be executed in a separate ``DoFn.process()`` invocation (likely in a
  1326      different process). The work performed by executing the primary and residual
  1327      restrictions as separate ``DoFn.process()`` invocations MUST be equivalent
  1328      to the work performed as if this split never occurred.
  1329  
  1330      The ``fraction_of_remainder`` should be used in a best effort manner to
  1331      choose a primary and residual restriction based upon the fraction of the
  1332      remaining work that the current ``DoFn.process()`` invocation is responsible
  1333      for. For example, if a ``DoFn.process()`` was reading a file with a
  1334      restriction representing the offset range [100, 200) and has processed up to
  1335      offset 130 with a fraction_of_remainder of 0.7, the primary and residual
  1336      restrictions returned would be [100, 179), [179, 200) (note: current_offset
  1337      + fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
  1338  
  1339      ``fraction_of_remainder`` = 0 means a checkpoint is required.
  1340  
  1341      The API is recommended to be implemented for batch pipeline given that it is
  1342      very important for pipeline scaling and end to end pipeline execution.
  1343  
  1344      The API is required to be implemented for a streaming pipeline.
  1345  
  1346      Args:
  1347        fraction_of_remainder: A hint as to the fraction of work the primary
  1348          restriction should represent based upon the current known remaining
  1349          amount of work.
  1350  
  1351      Returns:
  1352        (primary_restriction, residual_restriction) if a split was possible,
  1353        otherwise returns ``None``.
  1354      """
  1355      raise NotImplementedError
  1356  
  1357    def try_claim(self, position):
  1358      """Attempts to claim the block of work in the current restriction
  1359      identified by the given position. Each claimed position MUST be a valid
  1360      split point.
  1361  
  1362      If this succeeds, the DoFn MUST execute the entire block of work. If it
  1363      fails, the ``DoFn.process()`` MUST return ``None`` without performing any
  1364      additional work or emitting output (note that emitting output or performing
  1365      work from ``DoFn.process()`` is also not allowed before the first call of
  1366      this method).
  1367  
  1368      The API is required to be implemented.
  1369  
  1370      Args:
  1371        position: current position that wants to be claimed.
  1372  
  1373      Returns: ``True`` if the position can be claimed as current_position.
  1374      Otherwise, returns ``False``.
  1375      """
  1376      raise NotImplementedError
  1377  
  1378    def is_bounded(self):
  1379      """Returns whether the amount of work represented by the current restriction
  1380      is bounded.
  1381  
  1382      The boundedness of the restriction is used to determine the default behavior
  1383      of how to truncate restrictions when a pipeline is being
  1384      `drained <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_.  # pylint: disable=line-too-long
  1385      If the restriction is bounded, then the entire restriction will be processed
  1386      otherwise the restriction will be processed till a checkpoint is possible.
  1387  
  1388      The API is required to be implemented.
  1389  
  1390      Returns: ``True`` if the restriction represents a finite amount of work.
  1391      Otherwise, returns ``False``.
  1392      """
  1393      raise NotImplementedError
  1394  
  1395  
  1396  class WatermarkEstimator(object):
  1397    """A WatermarkEstimator which is used for estimating output_watermark based on
  1398    the timestamp of output records or manual modifications. Please refer to
  1399    ``watermark_estiamtors`` for commonly used watermark estimators.
  1400  
  1401    The base class provides common APIs that are called by the framework, which
  1402    are also accessible inside a DoFn.process() body. Derived watermark estimator
  1403    should implement all APIs listed below. Additional methods can be implemented
  1404    and will be available when invoked within a DoFn.
  1405  
  1406    Internal state must not be updated asynchronously.
  1407    """
  1408    def get_estimator_state(self):
  1409      """Get current state of the WatermarkEstimator instance, which can be used
  1410      to recreate the WatermarkEstimator when processing the restriction. See
  1411      WatermarkEstimatorProvider.create_watermark_estimator.
  1412      """
  1413      raise NotImplementedError(type(self))
  1414  
  1415    def current_watermark(self):
  1416      # type: () -> timestamp.Timestamp
  1417  
  1418      """Return estimated output_watermark. This function must return
  1419      monotonically increasing watermarks."""
  1420      raise NotImplementedError(type(self))
  1421  
  1422    def observe_timestamp(self, timestamp):
  1423      # type: (timestamp.Timestamp) -> None
  1424  
  1425      """Update tracking  watermark with latest output timestamp.
  1426  
  1427      Args:
  1428        timestamp: the `timestamp.Timestamp` of current output element.
  1429  
  1430      This is called with the timestamp of every element output from the DoFn.
  1431      """
  1432      raise NotImplementedError(type(self))
  1433  
  1434  
  1435  class RestrictionProgress(object):
  1436    """Used to record the progress of a restriction."""
  1437    def __init__(self, **kwargs):
  1438      # Only accept keyword arguments.
  1439      self._fraction = kwargs.pop('fraction', None)
  1440      self._completed = kwargs.pop('completed', None)
  1441      self._remaining = kwargs.pop('remaining', None)
  1442      assert not kwargs
  1443  
  1444    def __repr__(self):
  1445      return 'RestrictionProgress(fraction=%s, completed=%s, remaining=%s)' % (
  1446          self._fraction, self._completed, self._remaining)
  1447  
  1448    @property
  1449    def completed_work(self):
  1450      # type: () -> float
  1451      if self._completed is not None:
  1452        return self._completed
  1453      elif self._remaining is not None and self._fraction is not None:
  1454        return self._remaining * self._fraction / (1 - self._fraction)
  1455      else:
  1456        return self._fraction
  1457  
  1458    @property
  1459    def remaining_work(self):
  1460      # type: () -> float
  1461      if self._remaining is not None:
  1462        return self._remaining
  1463      elif self._completed is not None and self._fraction:
  1464        return self._completed * (1 - self._fraction) / self._fraction
  1465      else:
  1466        return 1 - self._fraction
  1467  
  1468    @property
  1469    def total_work(self):
  1470      # type: () -> float
  1471      return self.completed_work + self.remaining_work
  1472  
  1473    @property
  1474    def fraction_completed(self):
  1475      # type: () -> float
  1476      if self._fraction is not None:
  1477        return self._fraction
  1478      else:
  1479        return float(self._completed) / self.total_work
  1480  
  1481    @property
  1482    def fraction_remaining(self):
  1483      # type: () -> float
  1484      if self._fraction is not None:
  1485        return 1 - self._fraction
  1486      else:
  1487        return float(self._remaining) / self.total_work
  1488  
  1489    def with_completed(self, completed):
  1490      # type: (int) -> RestrictionProgress
  1491      return RestrictionProgress(
  1492          fraction=self._fraction, remaining=self._remaining, completed=completed)
  1493  
  1494  
  1495  class _SDFBoundedSourceRestriction(object):
  1496    """ A restriction wraps SourceBundle and RangeTracker. """
  1497    def __init__(self, source_bundle, range_tracker=None):
  1498      self._source_bundle = source_bundle
  1499      self._range_tracker = range_tracker
  1500  
  1501    def __reduce__(self):
  1502      # The instance of RangeTracker shouldn't be serialized.
  1503      return (self.__class__, (self._source_bundle, ))
  1504  
  1505    def range_tracker(self):
  1506      if not self._range_tracker:
  1507        self._range_tracker = self._source_bundle.source.get_range_tracker(
  1508            self._source_bundle.start_position, self._source_bundle.stop_position)
  1509      return self._range_tracker
  1510  
  1511    def weight(self):
  1512      return self._source_bundle.weight
  1513  
  1514    def source(self):
  1515      return self._source_bundle.source
  1516  
  1517    def try_split(self, fraction_of_remainder):
  1518      try:
  1519        consumed_fraction = self.range_tracker().fraction_consumed()
  1520        fraction = (
  1521            consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
  1522        position = self.range_tracker().position_at_fraction(fraction)
  1523        # Need to stash current stop_pos before splitting since
  1524        # range_tracker.split will update its stop_pos if splits
  1525        # successfully.
  1526        stop_pos = self._source_bundle.stop_position
  1527        split_result = self.range_tracker().try_split(position)
  1528        if split_result:
  1529          split_pos, split_fraction = split_result
  1530          primary_weight = self._source_bundle.weight * split_fraction
  1531          residual_weight = self._source_bundle.weight - primary_weight
  1532          # Update self to primary weight and end position.
  1533          self._source_bundle = SourceBundle(
  1534              primary_weight,
  1535              self._source_bundle.source,
  1536              self._source_bundle.start_position,
  1537              split_pos)
  1538          return (
  1539              self,
  1540              _SDFBoundedSourceRestriction(
  1541                  SourceBundle(
  1542                      residual_weight,
  1543                      self._source_bundle.source,
  1544                      split_pos,
  1545                      stop_pos)))
  1546      except Exception:
  1547        # For any exceptions from underlying trySplit calls, the wrapper will
  1548        # think that the source refuses to split at this point. In this case,
  1549        # no split happens at the wrapper level.
  1550        return None
  1551  
  1552  
  1553  class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
  1554    """An `iobase.RestrictionTracker` implementations for wrapping BoundedSource
  1555    with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
  1556    wraps SourceBundle and RangeTracker.
  1557  
  1558    Delegated RangeTracker guarantees synchronization safety.
  1559    """
  1560    def __init__(self, restriction):
  1561      if not isinstance(restriction, _SDFBoundedSourceRestriction):
  1562        raise ValueError(
  1563            'Initializing SDFBoundedSourceRestrictionTracker'
  1564            ' requires a _SDFBoundedSourceRestriction. Got %s instead.' %
  1565            restriction)
  1566      self.restriction = restriction
  1567  
  1568    def current_progress(self):
  1569      # type: () -> RestrictionProgress
  1570      return RestrictionProgress(
  1571          fraction=self.restriction.range_tracker().fraction_consumed())
  1572  
  1573    def current_restriction(self):
  1574      self.restriction.range_tracker()
  1575      return self.restriction
  1576  
  1577    def start_pos(self):
  1578      return self.restriction.range_tracker().start_position()
  1579  
  1580    def stop_pos(self):
  1581      return self.restriction.range_tracker().stop_position()
  1582  
  1583    def try_claim(self, position):
  1584      return self.restriction.range_tracker().try_claim(position)
  1585  
  1586    def try_split(self, fraction_of_remainder):
  1587      return self.restriction.try_split(fraction_of_remainder)
  1588  
  1589    def check_done(self):
  1590      return self.restriction.range_tracker().fraction_consumed() >= 1.0
  1591  
  1592    def is_bounded(self):
  1593      return True
  1594  
  1595  
  1596  class _SDFBoundedSourceWrapperRestrictionCoder(coders.Coder):
  1597    def decode(self, value):
  1598      return _SDFBoundedSourceRestriction(SourceBundle(*pickler.loads(value)))
  1599  
  1600    def encode(self, restriction):
  1601      return pickler.dumps((
  1602          restriction._source_bundle.weight,
  1603          restriction._source_bundle.source,
  1604          restriction._source_bundle.start_position,
  1605          restriction._source_bundle.stop_position))
  1606  
  1607  
  1608  class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
  1609    """
  1610    A `RestrictionProvider` that is used by SDF for `BoundedSource`.
  1611  
  1612    This restriction provider initializes restriction based on input
  1613    element that is expected to be of BoundedSource type.
  1614    """
  1615    def __init__(self, desired_chunk_size=None, restriction_coder=None):
  1616      self._desired_chunk_size = desired_chunk_size
  1617      self._restriction_coder = (
  1618          restriction_coder or _SDFBoundedSourceWrapperRestrictionCoder())
  1619  
  1620    def _check_source(self, src):
  1621      if not isinstance(src, BoundedSource):
  1622        raise RuntimeError(
  1623            'SDFBoundedSourceRestrictionProvider can only utilize BoundedSource')
  1624  
  1625    def initial_restriction(self, element_source: BoundedSource):
  1626      self._check_source(element_source)
  1627      range_tracker = element_source.get_range_tracker(None, None)
  1628      return _SDFBoundedSourceRestriction(
  1629          SourceBundle(
  1630              None,
  1631              element_source,
  1632              range_tracker.start_position(),
  1633              range_tracker.stop_position()))
  1634  
  1635    def create_tracker(self, restriction):
  1636      return _SDFBoundedSourceRestrictionTracker(restriction)
  1637  
  1638    def split(self, element, restriction):
  1639      if self._desired_chunk_size is None:
  1640        try:
  1641          estimated_size = restriction.source().estimate_size()
  1642        except NotImplementedError:
  1643          estimated_size = None
  1644        self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size)
  1645  
  1646      # Invoke source.split to get initial splitting results.
  1647      source_bundles = restriction.source().split(self._desired_chunk_size)
  1648      for source_bundle in source_bundles:
  1649        yield _SDFBoundedSourceRestriction(source_bundle)
  1650  
  1651    def restriction_size(self, element, restriction):
  1652      return restriction.weight()
  1653  
  1654    def restriction_coder(self):
  1655      return self._restriction_coder
  1656  
  1657  
  1658  class SDFBoundedSourceReader(PTransform):
  1659    """A ``PTransform`` that uses SDF to read from each ``BoundedSource`` in a
  1660    PCollection.
  1661  
  1662    NOTE: This transform can only be used with beam_fn_api enabled.
  1663    """
  1664    def __init__(self, data_to_display=None):
  1665      self._data_to_display = data_to_display or {}
  1666      super().__init__()
  1667  
  1668    def _create_sdf_bounded_source_dofn(self):
  1669      class SDFBoundedSourceDoFn(core.DoFn):
  1670        def __init__(self, dd):
  1671          self._dd = dd
  1672  
  1673        def display_data(self):
  1674          return self._dd
  1675  
  1676        def process(
  1677            self,
  1678            unused_element,
  1679            restriction_tracker=core.DoFn.RestrictionParam(
  1680                _SDFBoundedSourceRestrictionProvider())):
  1681          current_restriction = restriction_tracker.current_restriction()
  1682          assert isinstance(current_restriction, _SDFBoundedSourceRestriction)
  1683  
  1684          return current_restriction.source().read(
  1685              current_restriction.range_tracker())
  1686  
  1687      return SDFBoundedSourceDoFn(self._data_to_display)
  1688  
  1689    def expand(self, pvalue):
  1690      return pvalue | core.ParDo(self._create_sdf_bounded_source_dofn())
  1691  
  1692    def get_windowing(self, unused_inputs):
  1693      return core.Windowing(window.GlobalWindows())
  1694  
  1695    def display_data(self):
  1696      return self._data_to_display