github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/xlang_kafkaio_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Integration test for Python cross-language pipelines for Java KafkaIO."""

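# These integration tests are skipped unless the LOCAL_KAFKA_JAR environment
# variable points to a self-contained Kafka/Zookeeper jar (see the skipUnless
# guard below). One way to run them, assuming apache_beam is importable; the
# jar path is a placeholder for your own build:
#
#   LOCAL_KAFKA_JAR=/path/to/local-kafka.jar \
#       python -m apache_beam.io.external.xlang_kafkaio_it_test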
import contextlib
import logging
import os
import socket
import subprocess
import sys
import time
import typing
import unittest
import uuid

import apache_beam as beam
from apache_beam.coders.coders import VarIntCoder
from apache_beam.io.kafka import ReadFromKafka
from apache_beam.io.kafka import WriteToKafka
from apache_beam.metrics import Metrics
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec

# Total number of records each test writes to Kafka and expects to read back.
NUM_RECORDS = 1000


class CollectingFn(beam.DoFn):
  """Stateful DoFn that buffers the integer payload of each (key, value)
  record and emits the sum once NUM_RECORDS values have been seen."""
  BUFFER_STATE = BagStateSpec('buffer', VarIntCoder())
  COUNT_STATE = CombiningValueStateSpec('count', sum)

  def process(
      self,
      element,
      buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
      count_state=beam.DoFn.StateParam(COUNT_STATE)):
    # element is a (key, value) pair; the value holds the encoded integer.
    value = int(element[1].decode())
    buffer_state.add(value)

    count_state.add(1)
    count = count_state.read()

    if count >= NUM_RECORDS:
      yield sum(buffer_state.read())
      count_state.clear()
      buffer_state.clear()


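# A minimal local sketch of how CollectingFn behaves, independent of Kafka.
# Everything here is illustrative: the in-memory input stands in for the
# (key bytes, value bytes) records that ReadFromKafka produces, and the
# helper name is hypothetical (it is not used by the tests below).
def _example_collecting_fn_locally():
  with beam.Pipeline() as pipeline:
    sums = (
        pipeline
        | beam.Create([(b'key', str(i).encode()) for i in range(NUM_RECORDS)])
        | beam.ParDo(CollectingFn()))
    # All NUM_RECORDS values share one key, so a single sum is emitted.
    assert_that(sums, equal_to([sum(range(NUM_RECORDS))]))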
class CrossLanguageKafkaIO(object):
  """Builds write and read pipelines for a Kafka topic using the
  cross-language WriteToKafka and ReadFromKafka transforms."""
  def __init__(
      self, bootstrap_servers, topic, null_key, expansion_service=None):
    self.bootstrap_servers = bootstrap_servers
    self.topic = topic
    self.null_key = null_key
    self.expansion_service = expansion_service
    self.sum_counter = Metrics.counter('source', 'elements_sum')

  def build_write_pipeline(self, pipeline):
    _ = (
        pipeline
        | 'Generate' >> beam.Create(range(NUM_RECORDS))  # pylint: disable=bad-option-value
        | 'MakeKV' >> beam.Map(
            lambda x: (None if self.null_key else b'key', str(x).encode())).
        with_output_types(typing.Tuple[typing.Optional[bytes], bytes])
        | 'WriteToKafka' >> WriteToKafka(
            producer_config={'bootstrap.servers': self.bootstrap_servers},
            topic=self.topic,
            expansion_service=self.expansion_service))

  def build_read_pipeline(self, pipeline, max_num_records=None):
    kafka_records = (
        pipeline
        | 'ReadFromKafka' >> ReadFromKafka(
            consumer_config={
                'bootstrap.servers': self.bootstrap_servers,
                'auto.offset.reset': 'earliest'
            },
            topics=[self.topic],
            max_num_records=max_num_records,
            expansion_service=self.expansion_service))

    if max_num_records:
      # Bounded read: hand back the raw records for direct assertions.
      return kafka_records

    return (
        kafka_records
        | 'CalculateSum' >> beam.ParDo(CollectingFn())
        | 'SetSumCounter' >> beam.Map(self.sum_counter.inc))

  def run_xlang_kafkaio(self, pipeline):
    self.build_write_pipeline(pipeline)
    self.build_read_pipeline(pipeline)
    # Run without the test-runner API round-trip.
    pipeline.run(False)


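# A hedged end-to-end sketch of driving the helper above outside unittest.
# The broker address and topic name are assumptions; the tests below instead
# spin up a local Kafka service and derive both values from it. Note the read
# side is unbounded here, so real runs bound or cancel it externally.
def _example_run_xlang_kafkaio():
  pipeline = TestPipeline()
  pipeline.not_use_test_runner_api = True
  CrossLanguageKafkaIO(
      bootstrap_servers='localhost:9092',
      topic='xlang_kafkaio_example_{}'.format(uuid.uuid4()),
      null_key=False).run_xlang_kafkaio(pipeline)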
@unittest.skipUnless(
    os.environ.get('LOCAL_KAFKA_JAR'),
    "LOCAL_KAFKA_JAR environment var is not provided.")
class CrossLanguageKafkaIOTest(unittest.TestCase):
  def test_kafkaio_populated_key(self):
    kafka_topic = 'xlang_kafkaio_test_populated_key_{}'.format(uuid.uuid4())
    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
    with self.local_kafka_service(local_kafka_jar) as kafka_port:
      bootstrap_servers = '{}:{}'.format(
          self.get_platform_localhost(), kafka_port)
      pipeline_creator = CrossLanguageKafkaIO(
          bootstrap_servers, kafka_topic, False)

      self.run_kafka_write(pipeline_creator)
      self.run_kafka_read(pipeline_creator, b'key')

  def test_kafkaio_null_key(self):
    kafka_topic = 'xlang_kafkaio_test_null_key_{}'.format(uuid.uuid4())
    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
    with self.local_kafka_service(local_kafka_jar) as kafka_port:
      bootstrap_servers = '{}:{}'.format(
          self.get_platform_localhost(), kafka_port)
      pipeline_creator = CrossLanguageKafkaIO(
          bootstrap_servers, kafka_topic, True)

      self.run_kafka_write(pipeline_creator)
      self.run_kafka_read(pipeline_creator, None)

  def run_kafka_write(self, pipeline_creator):
    with TestPipeline() as pipeline:
      # Cross-language transforms may not survive the test-runner API
      # round-trip, so it is disabled for these pipelines.
      pipeline.not_use_test_runner_api = True
      pipeline_creator.build_write_pipeline(pipeline)

  def run_kafka_read(self, pipeline_creator, expected_key):
    with TestPipeline() as pipeline:
      pipeline.not_use_test_runner_api = True
      result = pipeline_creator.build_read_pipeline(pipeline, NUM_RECORDS)
      assert_that(
          result,
          equal_to([(expected_key, str(i).encode())
                    for i in range(NUM_RECORDS)]))

  def get_platform_localhost(self):
    # On macOS the SDK harness containers cannot use host networking, so they
    # reach the host's Kafka broker via host.docker.internal.
    if sys.platform == 'darwin':
      return 'host.docker.internal'
    else:
      return 'localhost'

  def get_open_port(self):
    s = None
    try:
      s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    except:  # pylint: disable=bare-except
      # Above call will fail for nodes that only support IPv6.
      s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    s.bind(('localhost', 0))
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port

  @contextlib.contextmanager
  def local_kafka_service(self, local_kafka_jar_file):
    kafka_port = str(self.get_open_port())
    zookeeper_port = str(self.get_open_port())
    kafka_server = None
    try:
      kafka_server = subprocess.Popen(
          ['java', '-jar', local_kafka_jar_file, kafka_port, zookeeper_port])
      # Give the local Kafka/Zookeeper service a moment to start up.
      time.sleep(3)
      yield kafka_port
    finally:
      if kafka_server:
        kafka_server.kill()


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()