github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/parquetio.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """``PTransforms`` for reading from and writing to Parquet files.
    19  
    20  Provides two read ``PTransform``\\s, ``ReadFromParquet`` and
    21  ``ReadAllFromParquet``, that each produce a ``PCollection`` of records.
    22  Each record of this ``PCollection`` will contain a single record read from
    23  a Parquet file. Records that are of simple types will be mapped into
    24  corresponding Python types. The actual Parquet file operations are done by
    25  pyarrow. Source splitting is supported at row group granularity.
    26  
    27  Additionally, this module provides a write ``PTransform`` ``WriteToParquet``
    28  that can be used to write a given ``PCollection`` of Python objects to a
    29  Parquet file.
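
        For example, assuming Parquet files matching the placeholder pattern
        ``/tmp/input*.parquet`` whose records have a string ``name`` column and an
        int64 ``age`` column, the records could be read and written back out under
        a new (placeholder) output prefix as follows::

          import pyarrow
          import apache_beam as beam

          with beam.Pipeline() as p:
            _ = (
                p
                | beam.io.ReadFromParquet('/tmp/input*.parquet')
                | beam.io.WriteToParquet(
                    '/tmp/output',
                    pyarrow.schema(
                        [('name', pyarrow.string()), ('age', pyarrow.int64())])))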
    30  """
    31  # pytype: skip-file
    32  
    33  from functools import partial
    34  
    35  from pkg_resources import parse_version
    36  
    37  from apache_beam.io import filebasedsink
    38  from apache_beam.io import filebasedsource
    39  from apache_beam.io.filesystem import CompressionTypes
    40  from apache_beam.io.iobase import RangeTracker
    41  from apache_beam.io.iobase import Read
    42  from apache_beam.io.iobase import Write
    43  from apache_beam.transforms import DoFn
    44  from apache_beam.transforms import ParDo
    45  from apache_beam.transforms import PTransform
    46  from apache_beam.transforms import window
    47  
    48  try:
    49    import pyarrow as pa
    50    import pyarrow.parquet as pq
    51  except ImportError:
    52    pa = None
    53    pq = None
    54    ARROW_MAJOR_VERSION = None
    55  else:
    56    base_pa_version = parse_version(pa.__version__).base_version
    57    ARROW_MAJOR_VERSION, _, _ = map(int, base_pa_version.split('.'))
    58  
    59  __all__ = [
    60      'ReadFromParquet',
    61      'ReadAllFromParquet',
    62      'ReadFromParquetBatched',
    63      'ReadAllFromParquetBatched',
    64      'WriteToParquet',
    65      'WriteToParquetBatched'
    66  ]
    67  
    68  
    69  class _ArrowTableToRowDictionaries(DoFn):
    70    """ A DoFn that consumes an Arrow table and yields a Python dictionary for
    71    each row in the table."""
    72    def process(self, table, with_filename=False):
    73      if with_filename:
    74        file_name = table[0]
    75        table = table[1]
    76      num_rows = table.num_rows
    77      data_items = table.to_pydict().items()
    78      for n in range(num_rows):
    79        row = {}
    80        for column, values in data_items:
    81          row[column] = values[n]
    82        if with_filename:
    83          yield (file_name, row)
    84        else:
    85          yield row
    86  
    87  
    88  class _RowDictionariesToArrowTable(DoFn):
    89    """ A DoFn that consumes Python dictionaries and yields a pyarrow table."""
    90    def __init__(
    91        self,
    92        schema,
    93        row_group_buffer_size=64 * 1024 * 1024,
    94        record_batch_size=1000):
    95      self._schema = schema
    96      self._row_group_buffer_size = row_group_buffer_size
    97      self._buffer = [[] for _ in range(len(schema.names))]
    98      self._buffer_size = record_batch_size
    99      self._record_batches = []
   100      self._record_batches_byte_size = 0
   101  
   102    def process(self, row):
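            # Rows are buffered column by column: once the buffer holds
            # `record_batch_size` rows it is flushed into a pyarrow.RecordBatch,
            # and once the buffered record batches reach `row_group_buffer_size`
            # bytes they are emitted as a single pyarrow.Table.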
   103      if len(self._buffer[0]) >= self._buffer_size:
   104        self._flush_buffer()
   105  
   106      if self._record_batches_byte_size >= self._row_group_buffer_size:
   107        table = self._create_table()
   108        yield table
   109  
   110      # reorder the data into columnar format.
   111      for i, n in enumerate(self._schema.names):
   112        self._buffer[i].append(row[n])
   113  
   114    def finish_bundle(self):
   115      if len(self._buffer[0]) > 0:
   116        self._flush_buffer()
   117      if self._record_batches_byte_size > 0:
   118        table = self._create_table()
   119        yield window.GlobalWindows.windowed_value_at_end_of_window(table)
   120  
   121    def display_data(self):
   122      res = super().display_data()
   123      res['row_group_buffer_size'] = str(self._row_group_buffer_size)
   124      res['buffer_size'] = str(self._buffer_size)
   125  
   126      return res
   127  
   128    def _create_table(self):
   129      table = pa.Table.from_batches(self._record_batches, schema=self._schema)
   130      self._record_batches = []
   131      self._record_batches_byte_size = 0
   132      return table
   133  
   134    def _flush_buffer(self):
   135      arrays = [[] for _ in range(len(self._schema.names))]
   136      for x, y in enumerate(self._buffer):
   137        arrays[x] = pa.array(y, type=self._schema.types[x])
   138        self._buffer[x] = []
   139      rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
   140      self._record_batches.append(rb)
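            # Track the uncompressed in-memory size of the buffered record batches
            # by summing the sizes of the Arrow buffers backing each column.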
   141      size = 0
   142      for x in arrays:
   143        for b in x.buffers():
   144          if b is not None:
   145            size = size + b.size
   146      self._record_batches_byte_size = self._record_batches_byte_size + size
   147  
   148  
   149  class ReadFromParquetBatched(PTransform):
   150    """A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
   151       Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
   152       currently experimental. No backward-compatibility guarantees."""
   153    def __init__(
   154        self, file_pattern=None, min_bundle_size=0, validate=True, columns=None):
   155      """ Initializes :class:`~ReadFromParquetBatched`
   156  
   157      An alternative to :class:`~ReadFromParquet` that yields each row group from
   158      the Parquet file as a `pyarrow.Table`.  These Table instances can be
   159      processed directly, or converted to a pandas DataFrame for processing.  For
   160      more information on supported types and schema, please see the pyarrow
   161      documentation.
   162  
   163      .. testcode::
   164  
   165        with beam.Pipeline() as p:
   166          dataframes = p \\
   167              | 'Read' >> beam.io.ReadFromParquetBatched('/mypath/mypqfiles*') \\
   168              | 'Convert to pandas' >> beam.Map(lambda table: table.to_pandas())
   169  
   170      .. NOTE: We're not actually interested in this error; but if we get here,
   171         it means that the way of calling this transform hasn't changed.
   172  
   173      .. testoutput::
   174        :hide:
   175  
   176        Traceback (most recent call last):
   177         ...
   178        OSError: No files found based on the file pattern
   179  
   180      See also: :class:`~ReadFromParquet`.
   181  
   182      Args:
   183        file_pattern (str): the file glob to read
   184        min_bundle_size (int): the minimum size in bytes, to be considered when
   185          splitting the input into bundles.
   186        validate (bool): flag to verify that the files exist during the pipeline
   187          creation time.
   188        columns (List[str]): list of columns that will be read from files.
   189          A column name may be a prefix of a nested field, e.g. 'a' will select
   190          'a.b', 'a.c', and 'a.d.e'
   191      """
   192  
   193      super().__init__()
   194      self._source = _create_parquet_source(
   195          file_pattern,
   196          min_bundle_size,
   197          validate=validate,
   198          columns=columns,
   199      )
   200  
   201    def expand(self, pvalue):
   202      return pvalue.pipeline | Read(self._source)
   203  
   204    def display_data(self):
   205      return {'source_dd': self._source}
   206  
   207  
   208  class ReadFromParquet(PTransform):
   209    """A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
   210       Parquet files as a `PCollection` of dictionaries. This `PTransform` is
   211       currently experimental. No backward-compatibility guarantees."""
   212    def __init__(
   213        self, file_pattern=None, min_bundle_size=0, validate=True, columns=None):
   214      """Initializes :class:`ReadFromParquet`.
   215  
   216      Uses source ``_ParquetSource`` to read a set of Parquet files defined by
   217      a given file pattern.
   218  
   219      If ``/mypath/myparquetfiles*`` is a file-pattern that points to a set of
   220      Parquet files, a :class:`~apache_beam.pvalue.PCollection` for the records in
   221      these Parquet files can be created in the following manner.
   222  
   223      .. testcode::
   224  
   225        with beam.Pipeline() as p:
   226          records = p | 'Read' >> beam.io.ReadFromParquet('/mypath/mypqfiles*')
   227  
   228      .. NOTE: We're not actually interested in this error; but if we get here,
   229         it means that the way of calling this transform hasn't changed.
   230  
   231      .. testoutput::
   232        :hide:
   233  
   234        Traceback (most recent call last):
   235         ...
   236        OSError: No files found based on the file pattern
   237  
   238      Each element of this :class:`~apache_beam.pvalue.PCollection` will contain
   239      a Python dictionary representing a single record. The keys will be of type
   240      :class:`str` and named after their corresponding column names. The values
   241      will be of the type defined in the corresponding Parquet schema. Records
   242      that are of simple types will be mapped into corresponding Python types.
   243      Records that are of complex types like list and struct will be mapped to
   244      Python list and dictionary respectively. For more information on supported
   245      types and schema, please see the pyarrow documentation.
   246  
   247      See also: :class:`~ReadFromParquetBatched`.
   248  
   249      Args:
   250        file_pattern (str): the file glob to read
   251        min_bundle_size (int): the minimum size in bytes, to be considered when
   252          splitting the input into bundles.
   253        validate (bool): flag to verify that the files exist during the pipeline
   254          creation time.
   255        columns (List[str]): list of columns that will be read from files.
   256          A column name may be a prefix of a nested field, e.g. 'a' will select
   257          'a.b', 'a.c', and 'a.d.e'
   258      """
   259      super().__init__()
   260      self._source = _create_parquet_source(
   261          file_pattern,
   262          min_bundle_size,
   263          validate=validate,
   264          columns=columns,
   265      )
   266  
   267    def expand(self, pvalue):
   268      return pvalue | Read(self._source) | ParDo(_ArrowTableToRowDictionaries())
   269  
   270    def display_data(self):
   271      return {'source_dd': self._source}
   272  
   273  
   274  class ReadAllFromParquetBatched(PTransform):
   275    """A ``PTransform`` for reading a ``PCollection`` of Parquet files.
   276  
   277     Uses source ``_ParquetSource`` to read a ``PCollection`` of Parquet files or
   278     file patterns and produce a ``PCollection`` of ``pyarrow.Table``, one for
   279     each Parquet file row group. This ``PTransform`` is currently experimental.
   280     No backward-compatibility guarantees.
   281    """
   282  
   283    DEFAULT_DESIRED_BUNDLE_SIZE = 64 * 1024 * 1024  # 64MB
   284  
   285    def __init__(
   286        self,
   287        min_bundle_size=0,
   288        desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE,
   289        columns=None,
   290        with_filename=False,
   291        label='ReadAllFiles'):
   292      """Initializes ``ReadAllFromParquetBatched``.
   293  
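            This transform is applied to a ``PCollection`` of file names or file
            patterns. For example (the file names below are placeholders)::

              with beam.Pipeline() as p:
                tables = (
                    p
                    | beam.Create(['/tmp/part-00000.parquet',
                                   '/tmp/part-00001.parquet'])
                    | beam.io.ReadAllFromParquetBatched())
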
   294      Args:
   295        min_bundle_size: the minimum size in bytes, to be considered when
   296                         splitting the input into bundles.
   297        desired_bundle_size: the desired size in bytes, to be considered when
   298                         splitting the input into bundles.
   299        columns: list of columns that will be read from files. A column name
   300                         may be a prefix of a nested field, e.g. 'a' will select
   301                         'a.b', 'a.c', and 'a.d.e'
   302        with_filename: If True, returns a key-value pair with the key being
   303          the file name and the value being the actual data. If False, it only
   304          returns the data.
   305      """
   306      super().__init__()
   307      source_from_file = partial(
   308          _create_parquet_source,
   309          min_bundle_size=min_bundle_size,
   310          columns=columns)
   311      self._read_all_files = filebasedsource.ReadAllFiles(
   312          True,
   313          CompressionTypes.UNCOMPRESSED,
   314          desired_bundle_size,
   315          min_bundle_size,
   316          source_from_file,
   317          with_filename)
   318  
   319      self.label = label
   320  
   321    def expand(self, pvalue):
   322      return pvalue | self.label >> self._read_all_files
   323  
   324  
   325  class ReadAllFromParquet(PTransform):
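          """A ``PTransform`` for reading a ``PCollection`` of Parquet files.

          Behaves like :class:`~ReadAllFromParquetBatched` and accepts the same
          keyword arguments, but converts each row group into individual records,
          yielding one Python dictionary per row, or a ``(file name, record)``
          tuple per row when ``with_filename`` is ``True``.
          """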
   326    def __init__(self, with_filename=False, **kwargs):
   327      self._with_filename = with_filename
   328      self._read_batches = ReadAllFromParquetBatched(
   329          with_filename=self._with_filename, **kwargs)
   330  
   331    def expand(self, pvalue):
   332      return pvalue | self._read_batches | ParDo(
   333          _ArrowTableToRowDictionaries(), with_filename=self._with_filename)
   334  
   335  
   336  def _create_parquet_source(
   337      file_pattern=None, min_bundle_size=0, validate=False, columns=None):
   338    return \
   339      _ParquetSource(
   340          file_pattern=file_pattern,
   341          min_bundle_size=min_bundle_size,
   342          validate=validate,
   343          columns=columns,
   344      )
   345  
   346  
   347  class _ParquetUtils(object):
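          """Helpers for locating row group boundaries in Parquet file metadata."""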
   348    @staticmethod
   349    def find_first_row_group_index(pf, start_offset):
   350      for i in range(_ParquetUtils.get_number_of_row_groups(pf)):
   351        row_group_start_offset = _ParquetUtils.get_offset(pf, i)
   352        if row_group_start_offset >= start_offset:
   353          return i
   354      return -1
   355  
   356    @staticmethod
   357    def get_offset(pf, row_group_index):
   358      first_column_metadata =\
   359        pf.metadata.row_group(row_group_index).column(0)
   360      if first_column_metadata.has_dictionary_page:
   361        return first_column_metadata.dictionary_page_offset
   362      else:
   363        return first_column_metadata.data_page_offset
   364  
   365    @staticmethod
   366    def get_number_of_row_groups(pf):
   367      return pf.metadata.num_row_groups
   368  
   369  
   370  class _ParquetSource(filebasedsource.FileBasedSource):
   371    """A source for reading Parquet files.
   372    """
   373    def __init__(self, file_pattern, min_bundle_size, validate, columns):
   374      super().__init__(
   375          file_pattern=file_pattern,
   376          min_bundle_size=min_bundle_size,
   377          validate=validate)
   378      self._columns = columns
   379  
   380    def read_records(self, file_name, range_tracker):
   381      next_block_start = -1
   382  
   383      def split_points_unclaimed(stop_position):
   384        if next_block_start >= stop_position:
   385          # Next block starts at or after the suggested stop position. Hence
   386          # there will not be split points to be claimed for the range ending at
   387          # suggested stop position.
   388          return 0
   389        return RangeTracker.SPLIT_POINTS_UNKNOWN
   390  
   391      range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed)
   392  
   393      start_offset = range_tracker.start_position()
   394      if start_offset is None:
   395        start_offset = 0
   396  
   397      with self.open_file(file_name) as f:
   398        pf = pq.ParquetFile(f)
   399  
   400        # find the first dictionary page (or data page if no dictionary page
   401        # is available) offset at or after the given start_offset. This offset
   402        # is also the starting offset of a row group, since the Parquet
   403        # specification requires that data pages always come before the
   404        # metadata in each row group.
   405        index = _ParquetUtils.find_first_row_group_index(pf, start_offset)
   406        if index != -1:
   407          next_block_start = _ParquetUtils.get_offset(pf, index)
   408        else:
   409          next_block_start = range_tracker.stop_position()
   410        number_of_row_groups = _ParquetUtils.get_number_of_row_groups(pf)
   411  
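              # Claim and read consecutive row groups for as long as this source
              # still owns the range; try_claim() returns False once the next row
              # group starts at or beyond the tracker's stop position (for
              # example, after a dynamic split).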
   412        while range_tracker.try_claim(next_block_start):
   413          table = pf.read_row_group(index, self._columns)
   414  
   415          if index + 1 < number_of_row_groups:
   416            index = index + 1
   417            next_block_start = _ParquetUtils.get_offset(pf, index)
   418          else:
   419            next_block_start = range_tracker.stop_position()
   420  
   421          yield table
   422  
   423  
   424  class WriteToParquet(PTransform):
   425    """A ``PTransform`` for writing parquet files.
   426  
   427      This ``PTransform`` is currently experimental. No backward-compatibility
   428      guarantees.
   429    """
   430    def __init__(
   431        self,
   432        file_path_prefix,
   433        schema,
   434        row_group_buffer_size=64 * 1024 * 1024,
   435        record_batch_size=1000,
   436        codec='none',
   437        use_deprecated_int96_timestamps=False,
   438        use_compliant_nested_type=False,
   439        file_name_suffix='',
   440        num_shards=0,
   441        shard_name_template=None,
   442        mime_type='application/x-parquet'):
   443      """Initialize a WriteToParquet transform.
   444  
   445      Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
   446      records. Each record is a dictionary with keys of a string type that
   447      represent column names. Schema must be specified like the example below.
   448      represent column names. The schema must be specified as shown below.
   449      .. testsetup::
   450  
   451        from tempfile import NamedTemporaryFile
   452        import glob
   453        import os
   454        import pyarrow
   455  
   456        filename = NamedTemporaryFile(delete=False).name
   457  
   458      .. testcode::
   459  
   460        with beam.Pipeline() as p:
   461          records = p | 'Read' >> beam.Create(
   462              [{'name': 'foo', 'age': 10}, {'name': 'bar', 'age': 20}]
   463          )
   464          _ = records | 'Write' >> beam.io.WriteToParquet(filename,
   465              pyarrow.schema(
   466                  [('name', pyarrow.binary()), ('age', pyarrow.int64())]
   467              )
   468          )
   469  
   470      .. testcleanup::
   471  
   472        for output in glob.glob('{}*'.format(filename)):
   473          os.remove(output)
   474  
   475      For more information on supported types and schema, please see the pyarrow
   476      documentation.
   477  
   478      Args:
   479        file_path_prefix: The file path to write to. The files written will begin
   480          with this prefix, followed by a shard identifier (see num_shards), and
   481          end in a common extension, if given by file_name_suffix. In most cases,
   482          only this argument is specified and num_shards, shard_name_template, and
   483          file_name_suffix use default values.
   484        schema: The schema to use, as type of ``pyarrow.Schema``.
   485        row_group_buffer_size: The byte size of the row group buffer. Note that
   486          this size is for uncompressed data in memory and is normally much
   487          bigger than the actual row group size written to a file.
   488        record_batch_size: The number of records in each record batch. A record
   489          batch is the basic unit used for storing data in the row group buffer.
   490          A higher record batch size implies coarser granularity for the row
   491          group buffer size. To configure the row group size based on the number
   492          of records, set ``row_group_buffer_size`` to 1 and use
   493          ``record_batch_size`` to adjust the value.
   494        codec: The codec to use for block-level compression. Any string supported
   495          by the pyarrow specification is accepted.
   496        use_deprecated_int96_timestamps: Write nanosecond resolution timestamps to
   497          INT96 Parquet format. Defaults to False.
   498        use_compliant_nested_type: Write compliant Parquet nested type (lists).
   499        file_name_suffix: Suffix for the files written.
   500        num_shards: The number of files (shards) used for output. If not set, the
   501          service will decide on the optimal number of shards.
   502          Constraining the number of shards is likely to reduce
   503          the performance of a pipeline.  Setting this value is not recommended
   504          unless you require a specific number of output files.
   505        shard_name_template: A template string containing placeholders for
   506          the shard number and shard count. When constructing a filename for a
   507          particular shard number, the upper-case letters 'S' and 'N' are
   508          replaced with the 0-padded shard number and shard count respectively.
   509          This argument can be '' in which case it behaves as if num_shards was
   510          set to 1 and only one file will be generated. The default pattern used
   511          is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
   512        mime_type: The MIME type to use for the produced files, if the filesystem
   513          supports specifying MIME types.
   514  
   515      Returns:
   516        A WriteToParquet transform usable for writing.
   517      """
   518      super().__init__()
   519      self._schema = schema
   520      self._row_group_buffer_size = row_group_buffer_size
   521      self._record_batch_size = record_batch_size
   522  
   523      self._sink = \
   524        _create_parquet_sink(
   525            file_path_prefix,
   526            schema,
   527            codec,
   528            use_deprecated_int96_timestamps,
   529            use_compliant_nested_type,
   530            file_name_suffix,
   531            num_shards,
   532            shard_name_template,
   533            mime_type
   534        )
   535  
   536    def expand(self, pcoll):
   537      return pcoll | ParDo(
   538          _RowDictionariesToArrowTable(
   539              self._schema, self._row_group_buffer_size,
   540              self._record_batch_size)) | Write(self._sink)
   541  
   542    def display_data(self):
   543      return {
   544          'sink_dd': self._sink,
   545          'row_group_buffer_size': str(self._row_group_buffer_size)
   546      }
   547  
   548  
   549  class WriteToParquetBatched(PTransform):
   550    """A ``PTransform`` for writing parquet files from a `PCollection` of
   551      `pyarrow.Table`.
   552  
   553      This ``PTransform`` is currently experimental. No backward-compatibility
   554      guarantees.
   555    """
   556    def __init__(
   557        self,
   558        file_path_prefix,
   559        schema=None,
   560        codec='none',
   561        use_deprecated_int96_timestamps=False,
   562        use_compliant_nested_type=False,
   563        file_name_suffix='',
   564        num_shards=0,
   565        shard_name_template=None,
   566        mime_type='application/x-parquet',
   567    ):
   568      """Initialize a WriteToParquetBatched transform.
   569  
   570      Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
   571      records. Each record is a ``pyarrow.Table``. The schema must be specified
   572      as in the example below.
   573  
   574      .. testsetup:: batched
   575  
   576        from tempfile import NamedTemporaryFile
   577        import glob
   578        import os
   579        import pyarrow
   580  
   581        filename = NamedTemporaryFile(delete=False).name
   582  
   583      .. testcode:: batched
   584  
   585        table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10},
   586                                           {'name': 'bar', 'age': 20}])
   587        with beam.Pipeline() as p:
   588          records = p | 'Read' >> beam.Create([table])
   589          _ = records | 'Write' >> beam.io.WriteToParquetBatched(filename,
   590              pyarrow.schema(
   591                  [('name', pyarrow.string()), ('age', pyarrow.int64())]
   592              )
   593          )
   594  
   595      .. testcleanup:: batched
   596  
   597        for output in glob.glob('{}*'.format(filename)):
   598          os.remove(output)
   599  
   600      For more information on supported types and schema, please see the pyarrow
   601      documentation.
   602  
   603      Args:
   604        file_path_prefix: The file path to write to. The files written will begin
   605          with this prefix, followed by a shard identifier (see num_shards), and
   606          end in a common extension, if given by file_name_suffix. In most cases,
   607          only this argument is specified and num_shards, shard_name_template, and
   608          file_name_suffix use default values.
   609        schema: The schema to use, as type of ``pyarrow.Schema``.
   610        codec: The codec to use for block-level compression. Any string supported
   611          by the pyarrow specification is accepted.
   612        use_deprecated_int96_timestamps: Write nanosecond resolution timestamps to
   613          INT96 Parquet format. Defaults to False.
   614        use_compliant_nested_type: Write compliant Parquet nested type (lists).
   615        file_name_suffix: Suffix for the files written.
   616        num_shards: The number of files (shards) used for output. If not set, the
   617          service will decide on the optimal number of shards.
   618          Constraining the number of shards is likely to reduce
   619          the performance of a pipeline.  Setting this value is not recommended
   620          unless you require a specific number of output files.
   621        shard_name_template: A template string containing placeholders for
   622          the shard number and shard count. When constructing a filename for a
   623          particular shard number, the upper-case letters 'S' and 'N' are
   624          replaced with the 0-padded shard number and shard count respectively.
   625          This argument can be '' in which case it behaves as if num_shards was
   626          set to 1 and only one file will be generated. The default pattern used
   627          is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
   628        mime_type: The MIME type to use for the produced files, if the filesystem
   629          supports specifying MIME types.
   630  
   631      Returns:
   632        A WriteToParquetBatched transform usable for writing.
   633      """
   634      super().__init__()
   635      self._sink = \
   636        _create_parquet_sink(
   637            file_path_prefix,
   638            schema,
   639            codec,
   640            use_deprecated_int96_timestamps,
   641            use_compliant_nested_type,
   642            file_name_suffix,
   643            num_shards,
   644            shard_name_template,
   645            mime_type
   646        )
   647  
   648    def expand(self, pcoll):
   649      return pcoll | Write(self._sink)
   650  
   651    def display_data(self):
   652      return {'sink_dd': self._sink}
   653  
   654  
   655  def _create_parquet_sink(
   656      file_path_prefix,
   657      schema,
   658      codec,
   659      use_deprecated_int96_timestamps,
   660      use_compliant_nested_type,
   661      file_name_suffix,
   662      num_shards,
   663      shard_name_template,
   664      mime_type):
   665    return \
   666      _ParquetSink(
   667          file_path_prefix,
   668          schema,
   669          codec,
   670          use_deprecated_int96_timestamps,
   671          use_compliant_nested_type,
   672          file_name_suffix,
   673          num_shards,
   674          shard_name_template,
   675          mime_type
   676      )
   677  
   678  
   679  class _ParquetSink(filebasedsink.FileBasedSink):
   680    """A sink for Parquet files. Each record written is a ``pyarrow.Table``."""
   681    def __init__(
   682        self,
   683        file_path_prefix,
   684        schema,
   685        codec,
   686        use_deprecated_int96_timestamps,
   687        use_compliant_nested_type,
   688        file_name_suffix,
   689        num_shards,
   690        shard_name_template,
   691        mime_type):
   692      super().__init__(
   693          file_path_prefix,
   694          file_name_suffix=file_name_suffix,
   695          num_shards=num_shards,
   696          shard_name_template=shard_name_template,
   697          coder=None,
   698          mime_type=mime_type,
   699          # Compression happens at the block level using the supplied codec, and
   700          # not at the file level.
   701          compression_type=CompressionTypes.UNCOMPRESSED)
   702      self._schema = schema
   703      self._codec = codec
   704      if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4":
   705        raise ValueError(
   706            "Due to ARROW-9424, writing with LZ4 compression is not supported in "
   707            "pyarrow 1.x, please use a different pyarrow version or a different "
   708            f"codec. Your pyarrow version: {pa.__version__}")
   709      self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
   710      if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
   711        raise ValueError(
   712            "With ARROW-11497, use_compliant_nested_type is only supported in "
   713            "pyarrow version >= 4.x, please use a different pyarrow version. "
   714            f"Your pyarrow version: {pa.__version__}")
   715      self._use_compliant_nested_type = use_compliant_nested_type
   716      self._file_handle = None
   717  
   718    def open(self, temp_path):
   719      self._file_handle = super().open(temp_path)
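            # use_compliant_nested_type is only supported by pyarrow >= 4 (see
            # ARROW-11497), so it is omitted when running with older pyarrow
            # versions.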
   720      if ARROW_MAJOR_VERSION < 4:
   721        return pq.ParquetWriter(
   722            self._file_handle,
   723            self._schema,
   724            compression=self._codec,
   725            use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps)
   726      return pq.ParquetWriter(
   727          self._file_handle,
   728          self._schema,
   729          compression=self._codec,
   730          use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
   731          use_compliant_nested_type=self._use_compliant_nested_type)
   732  
   733    def write_record(self, writer, table: pa.Table):
   734      writer.write_table(table)
   735  
   736    def close(self, writer):
   737      writer.close()
   738      if self._file_handle:
   739        self._file_handle.close()
   740        self._file_handle = None
   741  
   742    def display_data(self):
   743      res = super().display_data()
   744      res['codec'] = str(self._codec)
   745      res['schema'] = str(self._schema)
   746      return res