github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/fastavro_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""End-to-end test for Avro IO's fastavro support.

Writes a configurable number of records to a temporary location with fastavro,
then reads them back from that location, joins the generated records with the
records read back, and verifies that the two collections contain the same
elements.

Usage:

  DataflowRunner:
    pytest apache_beam/examples/fastavro_it_test.py \
        --test-pipeline-options="
          --runner=TestDataflowRunner
          --project=...
          --region=...
          --staging_location=gs://...
          --temp_location=gs://...
          --output=gs://...
          --sdk_location=...
        "

  DirectRunner:
    pytest apache_beam/examples/fastavro_it_test.py \
      --test-pipeline-options="
        --output=/tmp
        --records=5000
      "
"""

# pytype: skip-file

import json
import logging
import unittest
import uuid

import pytest
from fastavro import parse_schema

from apache_beam.io.avroio import ReadAllFromAvro
from apache_beam.io.avroio import WriteToAvro
from apache_beam.runners.runner import PipelineState
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_utils import delete_files
from apache_beam.testing.util import BeamAssertException
from apache_beam.transforms.core import Create
from apache_beam.transforms.core import FlatMap
from apache_beam.transforms.core import Map
from apache_beam.transforms.util import CoGroupByKey

LABELS = ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx']
COLORS = ['RED', 'ORANGE', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', None]


def record(i):
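  # For example, record(0) == {'label': 'abc', 'number': 0, 'number_str': '0',
  # 'color': 'RED'}; the color cycles through COLORS, which includes None.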
  return {
      'label': LABELS[i % len(LABELS)],
      'number': i,
      'number_str': str(i),
      'color': COLORS[i % len(COLORS)]
  }


def assertEqual(l, r):
  if l != r:
    raise BeamAssertException('Assertion failed: %s == %s' % (l, r))


def check(element):
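  # Sanity-checks the shape of a single generated record; not referenced by
  # the pipeline assembled in the test below.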
  assert element['color'] in COLORS
  assert element['label'] in LABELS
  assertEqual(
      sorted(element.keys()), ['color', 'label', 'number', 'number_str'])


class FastavroIT(unittest.TestCase):
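  # Avro schema describing the dicts produced by record(): 'label' is
  # required, while 'number', 'number_str', and 'color' are nullable unions.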

  SCHEMA_STRING = '''
    {"namespace": "example.avro",
     "type": "record",
     "name": "User",
     "fields": [
         {"name": "label", "type": "string"},
         {"name": "number",  "type": ["int", "null"]},
         {"name": "number_str", "type": ["string", "null"]},
         {"name": "color", "type": ["string", "null"]}
     ]
    }
    '''

  def setUp(self):
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self.uuid = str(uuid.uuid4())
    self.output = '/'.join([self.test_pipeline.get_option('output'), self.uuid])

  @pytest.mark.it_postcommit
  def test_avro_it(self):
    num_records = self.test_pipeline.get_option('records')
    num_records = int(num_records) if num_records else 1000000
    fastavro_output = '/'.join([self.output, 'fastavro'])

    # Seed a `PCollection` with indices that will each be FlatMap'd into
    # `batch_size` records, to avoid having a too-large list in memory at
    # the outset
    batch_size = self.test_pipeline.get_option('batch-size')
    batch_size = int(batch_size) if batch_size else 10000

    # pylint: disable=bad-option-value
    batches = range(int(num_records / batch_size))
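    # With the defaults (num_records=1000000, batch_size=10000) this yields
    # 100 batch indices; batch_indices below expands, e.g., batch 3 into the
    # record indices 30000..39999.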

    def batch_indices(start):
      # pylint: disable=bad-option-value
      return range(start * batch_size, (start + 1) * batch_size)

    # A `PCollection` with `num_records` avro records
    records_pcoll = \
        self.test_pipeline \
        | 'create-batches' >> Create(batches) \
        | 'expand-batches' >> FlatMap(batch_indices) \
        | 'create-records' >> Map(record)

    # pylint: disable=expression-not-assigned
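    # Write the generated records as sharded Avro files under
    # `fastavro_output`, using the fastavro-parsed version of SCHEMA_STRING.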
    records_pcoll \
    | 'write_fastavro' >> WriteToAvro(
        fastavro_output,
        parse_schema(json.loads(self.SCHEMA_STRING)),
    )
    result = self.test_pipeline.run()
    result.wait_until_finish()
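    # Read back every shard written above by expanding the '<output>*' file
    # pattern; ReadAllFromAvro yields each stored record as a dict.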
    fastavro_pcoll = self.test_pipeline \
                     | 'create-fastavro' >> Create(['%s*' % fastavro_output]) \
                     | 'read-fastavro' >> ReadAllFromAvro()

    mapped_fastavro_pcoll = fastavro_pcoll | "map_fastavro" >> Map(
        lambda x: (x['number'], x))
    mapped_record_pcoll = records_pcoll | "map_record" >> Map(
        lambda x: (x['number'], x))

    def validate_record(elem):
      v = elem[1]
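      # Each joined element has the form
      # (number, {'record_pcoll': [...], 'fastavro': [...]}); both groups
      # must contain exactly one copy of the same record.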

      def assertEqual(l, r):
        if l != r:
          raise BeamAssertException('Assertion failed: %s == %s' % (l, r))

      assertEqual(sorted(v.keys()), ['fastavro', 'record_pcoll'])
      record_pcoll_values = v['record_pcoll']
      fastavro_values = v['fastavro']
      assertEqual(record_pcoll_values, fastavro_values)
      assertEqual(len(record_pcoll_values), 1)

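    # Join the generated records with the records read back from Avro, keyed
    # by 'number', and validate each group.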
    {
        "record_pcoll": mapped_record_pcoll, "fastavro": mapped_fastavro_pcoll
    } | CoGroupByKey() | Map(validate_record)

    result = self.test_pipeline.run()
    result.wait_until_finish()

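    # Remove the temporary Avro shards once the test finishes, and confirm
    # the verification pipeline completed successfully.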
    self.addCleanup(delete_files, [self.output])
    assert result.state == PipelineState.DONE


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()