github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/xlang_kafkaio_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Integration test for Python cross-language pipelines for Java KafkaIO."""

import contextlib
import logging
import os
import socket
import subprocess
import sys
import time
import typing
import unittest
import uuid

import apache_beam as beam
from apache_beam.coders.coders import VarIntCoder
from apache_beam.io.kafka import ReadFromKafka
from apache_beam.io.kafka import WriteToKafka
from apache_beam.metrics import Metrics
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec

NUM_RECORDS = 1000


class CollectingFn(beam.DoFn):
  """Stateful DoFn that buffers the integer payloads of Kafka records and
  emits their sum once all NUM_RECORDS elements have been seen."""
  BUFFER_STATE = BagStateSpec('buffer', VarIntCoder())
  COUNT_STATE = CombiningValueStateSpec('count', sum)

  def process(
      self,
      element,
      buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
      count_state=beam.DoFn.StateParam(COUNT_STATE)):
    # element is a (key, value) pair; the value bytes hold the record number.
    value = int(element[1].decode())
    buffer_state.add(value)

    count_state.add(1)
    count = count_state.read()

    if count >= NUM_RECORDS:
      yield sum(buffer_state.read())
      count_state.clear()
      buffer_state.clear()


class CrossLanguageKafkaIO(object):
  """Builds the write and read halves of a cross-language KafkaIO pipeline."""
  def __init__(
      self, bootstrap_servers, topic, null_key, expansion_service=None):
    self.bootstrap_servers = bootstrap_servers
    self.topic = topic
    self.null_key = null_key
    self.expansion_service = expansion_service
    self.sum_counter = Metrics.counter('source', 'elements_sum')

  def build_write_pipeline(self, pipeline):
    _ = (
        pipeline
        | 'Generate' >> beam.Create(range(NUM_RECORDS))  # pylint: disable=bad-option-value
        | 'MakeKV' >> beam.Map(
            lambda x: (None if self.null_key else b'key', str(x).encode())).
        with_output_types(typing.Tuple[typing.Optional[bytes], bytes])
        | 'WriteToKafka' >> WriteToKafka(
            producer_config={'bootstrap.servers': self.bootstrap_servers},
            topic=self.topic,
            expansion_service=self.expansion_service))

  def build_read_pipeline(self, pipeline, max_num_records=None):
    kafka_records = (
        pipeline
        | 'ReadFromKafka' >> ReadFromKafka(
            consumer_config={
                'bootstrap.servers': self.bootstrap_servers,
                'auto.offset.reset': 'earliest'
            },
            topics=[self.topic],
            max_num_records=max_num_records,
            expansion_service=self.expansion_service))

    if max_num_records:
      return kafka_records

    return (
        kafka_records
        | 'CalculateSum' >> beam.ParDo(CollectingFn())
        | 'SetSumCounter' >> beam.Map(self.sum_counter.inc))

  def run_xlang_kafkaio(self, pipeline):
    self.build_write_pipeline(pipeline)
    self.build_read_pipeline(pipeline)
    pipeline.run(False)


@unittest.skipUnless(
    os.environ.get('LOCAL_KAFKA_JAR'),
    "LOCAL_KAFKA_JAR environment var is not provided.")
class CrossLanguageKafkaIOTest(unittest.TestCase):
  def test_kafkaio_populated_key(self):
    kafka_topic = 'xlang_kafkaio_test_populated_key_{}'.format(uuid.uuid4())
    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
    with self.local_kafka_service(local_kafka_jar) as kafka_port:
      bootstrap_servers = '{}:{}'.format(
          self.get_platform_localhost(), kafka_port)
      pipeline_creator = CrossLanguageKafkaIO(
          bootstrap_servers, kafka_topic, False)

      self.run_kafka_write(pipeline_creator)
      self.run_kafka_read(pipeline_creator, b'key')

  def test_kafkaio_null_key(self):
    kafka_topic = 'xlang_kafkaio_test_null_key_{}'.format(uuid.uuid4())
    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
    with self.local_kafka_service(local_kafka_jar) as kafka_port:
      bootstrap_servers = '{}:{}'.format(
          self.get_platform_localhost(), kafka_port)
      pipeline_creator = CrossLanguageKafkaIO(
          bootstrap_servers, kafka_topic, True)

      self.run_kafka_write(pipeline_creator)
      self.run_kafka_read(pipeline_creator, None)

  def run_kafka_write(self, pipeline_creator):
    with TestPipeline() as pipeline:
      pipeline.not_use_test_runner_api = True
      pipeline_creator.build_write_pipeline(pipeline)

  def run_kafka_read(self, pipeline_creator, expected_key):
    with TestPipeline() as pipeline:
      pipeline.not_use_test_runner_api = True
      result = pipeline_creator.build_read_pipeline(pipeline, NUM_RECORDS)
      assert_that(
          result,
          equal_to([(expected_key, str(i).encode())
                    for i in range(NUM_RECORDS)]))

  def get_platform_localhost(self):
    # On macOS the Kafka broker runs inside Docker, so it is reachable via
    # the special host.docker.internal name rather than localhost.
    if sys.platform == 'darwin':
      return 'host.docker.internal'
    else:
      return 'localhost'

  def get_open_port(self):
    s = None
    try:
      s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    except:  # pylint: disable=bare-except
      # Above call will fail for nodes that only support IPv6.
      s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    s.bind(('localhost', 0))
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port

  @contextlib.contextmanager
  def local_kafka_service(self, local_kafka_jar_file):
    kafka_port = str(self.get_open_port())
    zookeeper_port = str(self.get_open_port())
    kafka_server = None
    try:
      kafka_server = subprocess.Popen(
          ['java', '-jar', local_kafka_jar_file, kafka_port, zookeeper_port])
      # Give the local Kafka/Zookeeper pair a moment to start accepting
      # connections before the pipelines run.
      time.sleep(3)
      yield kafka_port
    finally:
      if kafka_server:
        kafka_server.kill()


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
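
For reference, a minimal sketch of driving the CrossLanguageKafkaIO helper outside the unittest harness, exercising only the write half. The broker address and topic name are placeholders, a JDK must be on PATH so that WriteToKafka can start Beam's default Java expansion service, and runner support for cross-language transforms varies:

    # Hypothetical standalone driver (not part of the test suite); the
    # broker address and topic name below are placeholders.
    import apache_beam as beam
    from apache_beam.io.external.xlang_kafkaio_it_test import CrossLanguageKafkaIO

    with beam.Pipeline() as pipeline:
      CrossLanguageKafkaIO(
          'localhost:9092', 'xlang-demo-topic', null_key=False
      ).build_write_pipeline(pipeline)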