github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/expansion_service_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

import argparse
import logging
import pickle
import signal
import sys
import typing

import grpc

import apache_beam as beam
import apache_beam.transforms.combiners as combine
from apache_beam.coders import RowCoder
from apache_beam.pipeline import PipelineOptions
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
from apache_beam.portability.api import beam_expansion_api_pb2_grpc
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.runners.portability import artifact_service
from apache_beam.runners.portability import expansion_service
from apache_beam.runners.portability.stager import Stager
from apache_beam.transforms import fully_qualified_named_transform
from apache_beam.transforms import ptransform
from apache_beam.transforms.environments import PyPIArtifactRegistry
from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
from apache_beam.utils import thread_pool_executor

# This script provides an expansion service and example ptransforms for running
# external transform test cases. See external_test.py for details.

_LOGGER = logging.getLogger(__name__)

TEST_PREFIX_URN = "beam:transforms:xlang:test:prefix"
TEST_MULTI_URN = "beam:transforms:xlang:test:multi"
TEST_GBK_URN = "beam:transforms:xlang:test:gbk"
TEST_CGBK_URN = "beam:transforms:xlang:test:cgbk"
TEST_COMGL_URN = "beam:transforms:xlang:test:comgl"
TEST_COMPK_URN = "beam:transforms:xlang:test:compk"
TEST_FLATTEN_URN = "beam:transforms:xlang:test:flatten"
TEST_PARTITION_URN = "beam:transforms:xlang:test:partition"
TEST_PYTHON_BS4_URN = "beam:transforms:xlang:test:python_bs4"

# A transform that does not produce an output.
TEST_NO_OUTPUT_URN = "beam:transforms:xlang:test:nooutput"
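# For orientation only: external_test.py reaches the transforms registered
# below through beam.ExternalTransform pointed at a running instance of this
# service, roughly as sketched here (the service address is hypothetical):
#
#   with beam.Pipeline() as p:
#     res = (
#         p
#         | beam.Create(['a', 'b'])
#         | beam.ExternalTransform(
#             TEST_PREFIX_URN,
#             ImplicitSchemaPayloadBuilder({'data': '0'}),
#             'localhost:8097'))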


@ptransform.PTransform.register_urn('beam:transforms:xlang:count', None)
class CountPerElementTransform(ptransform.PTransform):
  def expand(self, pcoll):
    return pcoll | combine.Count.PerElement()

  def to_runner_api_parameter(self, unused_context):
    return 'beam:transforms:xlang:count', None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return CountPerElementTransform()


@ptransform.PTransform.register_urn(
    'beam:transforms:xlang:filter_less_than_eq', bytes)
class FilterLessThanTransform(ptransform.PTransform):
  def __init__(self, payload):
    self._payload = payload

  def expand(self, pcoll):
    return (
        pcoll | beam.Filter(
            lambda elem, target: elem <= target, int(ord(self._payload[0]))))

  def to_runner_api_parameter(self, unused_context):
    return (
        'beam:transforms:xlang:filter_less_than', self._payload.encode('utf8'))

  @staticmethod
  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
    return FilterLessThanTransform(payload.decode('utf8'))


@ptransform.PTransform.register_urn(TEST_PREFIX_URN, None)
@beam.typehints.with_output_types(str)
class PrefixTransform(ptransform.PTransform):
  def __init__(self, payload):
    self._payload = payload

  def expand(self, pcoll):
    return pcoll | 'TestLabel' >> beam.Map(
        lambda x: '{}{}'.format(self._payload, x))

  def to_runner_api_parameter(self, unused_context):
    return TEST_PREFIX_URN, ImplicitSchemaPayloadBuilder(
        {'data': self._payload}).payload()

  @staticmethod
  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
    return PrefixTransform(parse_string_payload(payload)['data'])


@ptransform.PTransform.register_urn(TEST_MULTI_URN, None)
class MutltiTransform(ptransform.PTransform):
  def expand(self, pcolls):
    return {
        'main': (pcolls['main1'], pcolls['main2'])
        | beam.Flatten()
        | beam.Map(lambda x, s: x + s, beam.pvalue.AsSingleton(
            pcolls['side'])).with_output_types(str),
        'side': pcolls['side']
        | beam.Map(lambda x: x + x).with_output_types(str),
    }

  def to_runner_api_parameter(self, unused_context):
    return TEST_MULTI_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return MutltiTransform()


@ptransform.PTransform.register_urn(TEST_GBK_URN, None)
class GBKTransform(ptransform.PTransform):
  def expand(self, pcoll):
    return pcoll | 'TestLabel' >> beam.GroupByKey()

  def to_runner_api_parameter(self, unused_context):
    return TEST_GBK_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return GBKTransform()


@ptransform.PTransform.register_urn(TEST_CGBK_URN, None)
class CoGBKTransform(ptransform.PTransform):
  class ConcatFn(beam.DoFn):
    def process(self, element):
      (k, v) = element
      return [(k, v['col1'] + v['col2'])]

  def expand(self, pcoll):
    return pcoll \
        | beam.CoGroupByKey() \
        | beam.ParDo(self.ConcatFn()).with_output_types(
            typing.Tuple[int, typing.Iterable[str]])

  def to_runner_api_parameter(self, unused_context):
    return TEST_CGBK_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return CoGBKTransform()
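# The next group of transforms (CombineGlobally, CombinePerKey, Flatten,
# Partition) exercises aggregation, multi-input, and multi-output expansion.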


@ptransform.PTransform.register_urn(TEST_COMGL_URN, None)
class CombineGloballyTransform(ptransform.PTransform):
  def expand(self, pcoll):
    return pcoll \
        | beam.CombineGlobally(sum).with_output_types(int)

  def to_runner_api_parameter(self, unused_context):
    return TEST_COMGL_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return CombineGloballyTransform()


@ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
class CombinePerKeyTransform(ptransform.PTransform):
  def expand(self, pcoll):
    output = pcoll \
        | beam.CombinePerKey(sum)
    # TODO: Use `with_output_types` instead of explicitly
    # assigning to `.element_type` after fixing BEAM-12872
    output.element_type = beam.typehints.Tuple[str, int]
    return output

  def to_runner_api_parameter(self, unused_context):
    return TEST_COMPK_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return CombinePerKeyTransform()


@ptransform.PTransform.register_urn(TEST_FLATTEN_URN, None)
class FlattenTransform(ptransform.PTransform):
  def expand(self, pcoll):
    return pcoll.values() | beam.Flatten().with_output_types(int)

  def to_runner_api_parameter(self, unused_context):
    return TEST_FLATTEN_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return FlattenTransform()


@ptransform.PTransform.register_urn(TEST_PARTITION_URN, None)
class PartitionTransform(ptransform.PTransform):
  def expand(self, pcoll):
    col1, col2 = pcoll | beam.Partition(
        lambda elem, n: 0 if elem % 2 == 0 else 1, 2)
    typed_col1 = col1 | beam.Map(lambda x: x).with_output_types(int)
    typed_col2 = col2 | beam.Map(lambda x: x).with_output_types(int)
    return {'0': typed_col1, '1': typed_col2}

  def to_runner_api_parameter(self, unused_context):
    return TEST_PARTITION_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return PartitionTransform()


class ExtractHtmlTitleDoFn(beam.DoFn):
  def process(self, element):
    from bs4 import BeautifulSoup
    soup = BeautifulSoup(element, 'html.parser')
    return [soup.title.string]


@ptransform.PTransform.register_urn(TEST_PYTHON_BS4_URN, None)
class ExtractHtmlTitleTransform(ptransform.PTransform):
  def expand(self, pcoll):
    return pcoll | beam.ParDo(ExtractHtmlTitleDoFn()).with_output_types(str)

  def to_runner_api_parameter(self, unused_context):
    return TEST_PYTHON_BS4_URN, None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    return ExtractHtmlTitleTransform()


@ptransform.PTransform.register_urn('payload', bytes)
class PayloadTransform(ptransform.PTransform):
  def __init__(self, payload):
    self._payload = payload

  def expand(self, pcoll):
    return pcoll | beam.Map(lambda x, s: x + s, self._payload)

  def to_runner_api_parameter(self, unused_context):
    return b'payload', self._payload.encode('ascii')

  @staticmethod
  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
    return PayloadTransform(payload.decode('ascii'))
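# For orientation only: unlike PrefixTransform, 'payload' round-trips a raw
# bytes payload rather than a row-schema payload. Expanded with payload b'x'
# and applied to ['a', 'b'], it should yield ['ax', 'bx'].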


@ptransform.PTransform.register_urn('map_to_union_types', None)
class MapToUnionTypesTransform(ptransform.PTransform):
  class CustomDoFn(beam.DoFn):
    def process(self, element):
      if element == 1:
        return ['1']
      elif element == 2:
        return [2]
      else:
        return [3.0]

  def expand(self, pcoll):
    return pcoll | beam.ParDo(self.CustomDoFn())

  def to_runner_api_parameter(self, unused_context):
    return b'map_to_union_types', None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_payload, unused_context):
    return MapToUnionTypesTransform()


@ptransform.PTransform.register_urn('fib', bytes)
class FibTransform(ptransform.PTransform):
  def __init__(self, level):
    self._level = level

  def expand(self, p):
    if self._level <= 2:
      return p | beam.Create([1])
    else:
      a = p | 'A' >> beam.ExternalTransform(
          'fib',
          str(self._level - 1).encode('ascii'),
          expansion_service.ExpansionServiceServicer())
      b = p | 'B' >> beam.ExternalTransform(
          'fib',
          str(self._level - 2).encode('ascii'),
          expansion_service.ExpansionServiceServicer())
      return ((a, b)
              | beam.Flatten()
              | beam.CombineGlobally(sum).without_defaults())

  def to_runner_api_parameter(self, unused_context):
    return 'fib', str(self._level).encode('ascii')

  @staticmethod
  def from_runner_api_parameter(unused_ptransform, level, unused_context):
    return FibTransform(int(level.decode('ascii')))


@ptransform.PTransform.register_urn(TEST_NO_OUTPUT_URN, None)
class NoOutputTransform(ptransform.PTransform):
  def expand(self, pcoll):
    def log_val(val):
      logging.debug('Got value: %r', val)

    # Logging without returning anything
    _ = (pcoll | 'TestLabel' >> beam.ParDo(log_val))

  def to_runner_api_parameter(self, unused_context):
    return TEST_NO_OUTPUT_URN, None

  @staticmethod
  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
    return NoOutputTransform(parse_string_payload(payload)['data'])


def parse_string_payload(input_byte):
  payload = external_transforms_pb2.ExternalConfigurationPayload()
  payload.ParseFromString(input_byte)

  return RowCoder(payload.schema).decode(payload.payload)._asdict()


def create_test_sklearn_model(file_name):
  from sklearn import svm
  x = [[0, 0], [1, 1]]
  y = [0, 1]
  model = svm.SVC()
  model.fit(x, y)
  with open(file_name, 'wb') as file:
    pickle.dump(model, file)


def update_sklearn_model_dependency(env):
  model_file = "/tmp/sklearn_test_model"
  staged_name = "sklearn_model"
  create_test_sklearn_model(model_file)
  env._artifacts.append(
      Stager._create_file_stage_to_artifact(model_file, staged_name))


server = None


def cleanup(unused_signum, unused_frame):
  _LOGGER.info('Shutting down expansion service.')
  server.stop(None)
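# main() below wires the ExpansionServiceServicer and an
# ArtifactRetrievalService into a single gRPC server, stages the pickled
# sklearn model as an extra artifact of the default environment, and blocks
# until SIGTERM/SIGINT triggers cleanup().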


def main(unused_argv):
  # TODO: use the regular expansion service (expansion_service_main) instead of
  # this custom service for testing.
  PyPIArtifactRegistry.register_artifact('beautifulsoup4', '>=4.9,<5.0')
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-p', '--port', type=int, help='port to serve the expansion service on')
  parser.add_argument('--fully_qualified_name_glob', default=None)
  options = parser.parse_args()

  global server
  with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
      options.fully_qualified_name_glob):
    server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    expansion_servicer = expansion_service.ExpansionServiceServicer(
        PipelineOptions([
            "--experiments",
            "beam_fn_api",
            "--sdk_location",
            "container",
            "--pickle_library",
            "cloudpickle"
        ]))
    update_sklearn_model_dependency(expansion_servicer._default_environment)
    beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
        expansion_servicer, server)
    beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
        artifact_service.ArtifactRetrievalService(
            artifact_service.BeamFilesystemHandler(None).file_reader),
        server)
    server.add_insecure_port('localhost:{}'.format(options.port))
    server.start()
    _LOGGER.info('Listening for expansion requests at %d', options.port)

    signal.signal(signal.SIGTERM, cleanup)
    signal.signal(signal.SIGINT, cleanup)
    # Block the main thread until a signal triggers cleanup().
    signal.pause()


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  main(sys.argv)
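# Typical invocation (the port and glob values below are illustrative):
#   python expansion_service_test.py -p 8097 --fully_qualified_name_glob '*'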