github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/base_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for apache_beam.ml.inference.base."""
import math
import os
import pickle
import time
import unittest
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Sequence

import pytest

import apache_beam as beam
from apache_beam.examples.inference import run_inference_side_inputs
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.inference import base
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.periodicsequence import TimestampedValue


class FakeModel:
  def predict(self, example: int) -> int:
    return example + 1


class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
  def __init__(
      self,
      clock=None,
      min_batch_size=1,
      max_batch_size=9999,
      multi_process_shared=False,
      **kwargs):
    self._fake_clock = clock
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._env_vars = kwargs.get('env_vars', {})
    self._multi_process_shared = multi_process_shared

  def load_model(self):
    if self._fake_clock:
      self._fake_clock.current_time_ns += 500_000_000  # 500ms
    return FakeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[int]:
    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
    if self._multi_process_shared != multi_process_shared_loaded:
      raise Exception(
          f'Loaded model of type {type(model)}, was' +
          f'{"" if self._multi_process_shared else " not"} ' +
          'expecting multi_process_shared_model')
    if self._fake_clock:
      self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
    for example in batch:
      yield model.predict(example)

  def update_model_path(self, model_path: Optional[str] = None):
    pass

  def batch_elements_kwargs(self):
    return {
        'min_batch_size': self._min_batch_size,
        'max_batch_size': self._max_batch_size
    }

  def share_model_across_processes(self):
    return self._multi_process_shared
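

# FakeModelHandlerReturnsPredictionResult wraps each prediction in
# base.PredictionResult so tests can assert on the model_id attached to each
# result; update_model_path swaps model_id, which is how the
# model_metadata_pcoll side-input tests observe model updates.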
class FakeModelHandlerReturnsPredictionResult(
    base.ModelHandler[int, base.PredictionResult, FakeModel]):
  def __init__(
      self,
      clock=None,
      model_id='fake_model_id_default',
      multi_process_shared=False):
    self.model_id = model_id
    self._fake_clock = clock
    self._env_vars = {}
    self._multi_process_shared = multi_process_shared

  def load_model(self):
    return FakeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[base.PredictionResult]:
    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
    if self._multi_process_shared != multi_process_shared_loaded:
      raise Exception(
          f'Loaded model of type {type(model)}, was' +
          f'{"" if self._multi_process_shared else " not"} ' +
          'expecting multi_process_shared_model')
    for example in batch:
      yield base.PredictionResult(
          model_id=self.model_id,
          example=example,
          inference=model.predict(example))

  def update_model_path(self, model_path: Optional[str] = None):
    self.model_id = model_path if model_path else self.model_id

  def share_model_across_processes(self):
    return self._multi_process_shared
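

# Unlike the handlers above, FakeModelHandlerNoEnvVars never sets
# self._env_vars, presumably to cover ModelHandler subclasses that predate
# env_vars support; see test_child_class_without_env_vars below.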
class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]):
  def __init__(
      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
    self._fake_clock = clock
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size

  def load_model(self):
    if self._fake_clock:
      self._fake_clock.current_time_ns += 500_000_000  # 500ms
    return FakeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[int]:
    if self._fake_clock:
      self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
    for example in batch:
      yield model.predict(example)

  def update_model_path(self, model_path: Optional[str] = None):
    pass

  def batch_elements_kwargs(self):
    return {
        'min_batch_size': self._min_batch_size,
        'max_batch_size': self._max_batch_size
    }


class FakeClock:
  def __init__(self):
    # Start at 10 seconds.
    self.current_time_ns = 10_000_000_000

  def time_ns(self) -> int:
    return self.current_time_ns


class ExtractInferences(beam.DoFn):
  def process(self, prediction_result):
    yield prediction_result.inference


class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
  def run_inference(self, batch, unused_model, inference_args=None):
    if len(batch) < 100:
      raise ValueError('Unexpectedly small batch')
    return batch

  def batch_elements_kwargs(self):
    return {'min_batch_size': 9999}


class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
  def run_inference(self, batch, unused_model, inference_args=None):
    raise ValueError(
        'run_inference should not be called because error should already be '
        'thrown from the validate_inference_args check.')


class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
  def run_inference(self, batch, unused_model, inference_args=None):
    if not inference_args:
      raise ValueError('inference_args should exist')
    return batch

  def validate_inference_args(self, inference_args):
    pass
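

# Illustrative sketch only (not invoked by any test in this module): the
# minimal way the fakes above plug into RunInference. FakeModel.predict adds
# one to each input, so [1, 2, 3] should come back as [2, 3, 4].
def _example_run_inference_sketch():
  with TestPipeline() as pipeline:
    predictions = (
        pipeline
        | beam.Create([1, 2, 3])
        | base.RunInference(FakeModelHandler()))
    assert_that(predictions, equal_to([2, 3, 4]))

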
class RunInferenceBaseTest(unittest.TestCase):
  def test_run_inference_impl_simple_examples(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = [example + 1 for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(FakeModelHandler())
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_impl_simple_examples_multi_process_shared(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = [example + 1 for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandler(multi_process_shared=True))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_impl_with_keyed_examples(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [(i, example + 1) for i, example in enumerate(examples)]
      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
      actual = pcoll | base.RunInference(
          base.KeyedModelHandler(FakeModelHandler()))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_impl_with_maybe_keyed_examples(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [example + 1 for example in examples]
      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())

      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
      assert_that(actual, equal_to(expected), label='CheckUnkeyed')

      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
          model_handler)
      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [(i, example + 1) for i, example in enumerate(examples)]
      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
      actual = pcoll | base.RunInference(
          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
      self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [example + 1 for example in examples]
      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
      model_handler = base.MaybeKeyedModelHandler(
          FakeModelHandler(multi_process_shared=True))

      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
      assert_that(actual, equal_to(expected), label='CheckUnkeyed')

      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
          model_handler)
      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

  def test_run_inference_preprocessing(self):
    def mult_two(example: str) -> int:
      return int(example) * 2

    with TestPipeline() as pipeline:
      examples = ["1", "5", "3", "10"]
      expected = [int(example) * 2 + 1 for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandler().with_preprocess_fn(mult_two))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_preprocessing_multiple_fns(self):
    def add_one(example: str) -> int:
      return int(example) + 1

    def mult_two(example: int) -> int:
      return example * 2

    with TestPipeline() as pipeline:
      examples = ["1", "5", "3", "10"]
      expected = [(int(example) + 1) * 2 + 1 for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandler().with_preprocess_fn(mult_two).with_preprocess_fn(
              add_one))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_postprocessing(self):
    def mult_two(example: int) -> str:
      return str(example * 2)

    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = [str((example + 1) * 2) for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandler().with_postprocess_fn(mult_two))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_postprocessing_multiple_fns(self):
    def add_one(example: int) -> str:
      return str(int(example) + 1)

    def mult_two(example: int) -> int:
      return example * 2

    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = [str(((example + 1) * 2) + 1) for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandler().with_postprocess_fn(mult_two).with_postprocess_fn(
              add_one))
      assert_that(actual, equal_to(expected), label='assert:inferences')
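
  # The *_dlq tests below exercise the dead-letter-queue pattern:
  # with_exception_handling() returns the main results plus named failure
  # collections (failed_inferences, failed_preprocessing[i],
  # failed_postprocessing[i]), where each failed record is an (element, error)
  # tuple, hence the `lambda x: x[0]` projections in the assertions.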
  def test_run_inference_preprocessing_dlq(self):
    def mult_two(example: str) -> int:
      if example == "5":
        raise Exception("TEST")
      return int(example) * 2

    with TestPipeline() as pipeline:
      examples = ["1", "5", "3", "10"]
      expected = [3, 7, 21]
      expected_bad = ["5"]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      main, other = pcoll | base.RunInference(
          FakeModelHandler().with_preprocess_fn(mult_two)
      ).with_exception_handling()
      assert_that(main, equal_to(expected), label='assert:inferences')
      assert_that(
          other.failed_inferences, equal_to([]), label='assert:bad_infer')

      # bad will be in form [element, error]. Just pull out bad element.
      bad_without_error = other.failed_preprocessing[0] | beam.Map(
          lambda x: x[0])
      assert_that(
          bad_without_error, equal_to(expected_bad), label='assert:failures')

  def test_run_inference_postprocessing_dlq(self):
    def mult_two(example: int) -> str:
      if example == 6:
        raise Exception("TEST")
      return str(example * 2)

    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = ["4", "8", "22"]
      expected_bad = [6]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      main, other = pcoll | base.RunInference(
          FakeModelHandler().with_postprocess_fn(mult_two)
      ).with_exception_handling()
      assert_that(main, equal_to(expected), label='assert:inferences')
      assert_that(
          other.failed_inferences, equal_to([]), label='assert:bad_infer')

      # bad will be in form [element, error]. Just pull out bad element.
      bad_without_error = other.failed_postprocessing[0] | beam.Map(
          lambda x: x[0])
      assert_that(
          bad_without_error, equal_to(expected_bad), label='assert:failures')

  def test_run_inference_pre_and_post_processing_dlq(self):
    def mult_two_pre(example: str) -> int:
      if example == "5":
        raise Exception("TEST")
      return int(example) * 2

    def mult_two_post(example: int) -> str:
      if example == 7:
        raise Exception("TEST")
      return str(example * 2)

    with TestPipeline() as pipeline:
      examples = ["1", "5", "3", "10"]
      expected = ["6", "42"]
      expected_bad_pre = ["5"]
      expected_bad_post = [7]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      main, other = pcoll | base.RunInference(
          FakeModelHandler().with_preprocess_fn(
              mult_two_pre
          ).with_postprocess_fn(
              mult_two_post
          )).with_exception_handling()
      assert_that(main, equal_to(expected), label='assert:inferences')
      assert_that(
          other.failed_inferences, equal_to([]), label='assert:bad_infer')

      # bad will be in form [elements, error]. Just pull out bad element.
      bad_without_error_pre = other.failed_preprocessing[0] | beam.Map(
          lambda x: x[0])
      assert_that(
          bad_without_error_pre,
          equal_to(expected_bad_pre),
          label='assert:failures_pre')

      # bad will be in form [elements, error]. Just pull out bad element.
      bad_without_error_post = other.failed_postprocessing[0] | beam.Map(
          lambda x: x[0])
      assert_that(
          bad_without_error_post,
          equal_to(expected_bad_post),
          label='assert:failures_post')

  def test_run_inference_keyed_pre_and_post_processing(self):
    def mult_two(element):
      return (element[0], element[1] * 2)

    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [
          (i, ((example * 2) + 1) * 2) for i, example in enumerate(examples)
      ]
      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
      actual = pcoll | base.RunInference(
          base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn(
              mult_two).with_postprocess_fn(mult_two))
      assert_that(actual, equal_to(expected), label='assert:inferences')

  def test_run_inference_maybe_keyed_pre_and_post_processing(self):
    def mult_two(element):
      return element * 2

    def mult_two_keyed(element):
      return (element[0], element[1] * 2)

    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      keyed_examples = [(i, example) for i, example in enumerate(examples)]
      expected = [((2 * example) + 1) * 2 for example in examples]
      keyed_expected = [
          (i, ((2 * example) + 1) * 2) for i, example in enumerate(examples)
      ]
      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())

      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
      actual = pcoll | 'RunUnkeyed' >> base.RunInference(
          model_handler.with_preprocess_fn(mult_two).with_postprocess_fn(
              mult_two))
      assert_that(actual, equal_to(expected), label='CheckUnkeyed')

      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
          model_handler.with_preprocess_fn(mult_two_keyed).with_postprocess_fn(
              mult_two_keyed))
      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

  def test_run_inference_impl_dlq(self):
    with TestPipeline() as pipeline:
      examples = [1, 'TEST', 3, 10, 'TEST2']
      expected_good = [2, 4, 11]
      expected_bad = ['TEST', 'TEST2']
      pcoll = pipeline | 'start' >> beam.Create(examples)
      main, other = pcoll | base.RunInference(
          FakeModelHandler(
              min_batch_size=1,
              max_batch_size=1
          )).with_exception_handling()
      assert_that(main, equal_to(expected_good), label='assert:inferences')

      # bad.failed_inferences will be in form [batch[elements], error].
      # Just pull out bad element.
      bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0])
      assert_that(
          bad_without_error, equal_to(expected_bad), label='assert:failures')

  def test_run_inference_impl_inference_args(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      inference_args = {'key': True}
      actual = pcoll | base.RunInference(
          FakeModelHandlerExpectedInferenceArgs(),
          inference_args=inference_args)
      assert_that(actual, equal_to(examples), label='assert:inferences')

  def test_run_inference_metrics_with_custom_namespace(self):
    metrics_namespace = 'my_custom_namespace'
    pipeline = TestPipeline()
    examples = [1, 5, 3, 10]
    pcoll = pipeline | 'start' >> beam.Create(examples)
    _ = pcoll | base.RunInference(
        FakeModelHandler(), metrics_namespace=metrics_namespace)
    result = pipeline.run()
    result.wait_until_finish()

    metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
    metrics = result.metrics().query(metrics_filter)
    assert len(metrics['counters']) != 0
    assert len(metrics['distributions']) != 0

    metrics_filter = MetricsFilter().with_namespace(namespace='fake_namespace')
    metrics = result.metrics().query(metrics_filter)
    assert len(metrics['counters']) == len(metrics['distributions']) == 0

  def test_unexpected_inference_args_passed(self):
    with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
      with TestPipeline() as pipeline:
        examples = [1, 5, 3, 10]
        pcoll = pipeline | 'start' >> beam.Create(examples)
        inference_args = {'key': True}
        _ = pcoll | base.RunInference(
            FakeModelHandlerFailsOnInferenceArgs(),
            inference_args=inference_args)

  def test_increment_failed_batches_counter(self):
    with self.assertRaises(ValueError):
      with TestPipeline() as pipeline:
        examples = [7]
        pcoll = pipeline | 'start' >> beam.Create(examples)
        _ = pcoll | base.RunInference(FakeModelHandlerExpectedInferenceArgs())
        run_result = pipeline.run()
        run_result.wait_until_finish()

        metric_results = (
            run_result.metrics().query(
                MetricsFilter().with_name('failed_batches_counter')))
        num_failed_batches_counter = metric_results['counters'][0]
        self.assertEqual(num_failed_batches_counter.committed, 3)
        # !!!: The above will need to be updated if retry behavior changes

  def test_failed_batches_counter_no_failures(self):
    pipeline = TestPipeline()
    examples = [7]
    pcoll = pipeline | 'start' >> beam.Create(examples)
    inference_args = {'key': True}
    _ = pcoll | base.RunInference(
        FakeModelHandlerExpectedInferenceArgs(), inference_args=inference_args)
    run_result = pipeline.run()
    run_result.wait_until_finish()

    metric_results = (
        run_result.metrics().query(
            MetricsFilter().with_name('failed_batches_counter')))
    self.assertEqual(len(metric_results['counters']), 0)
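
  # The next two tests cover the metrics RunInference reports:
  # test_counted_metrics checks num_inferences, inference_request_batch_size,
  # inference_request_batch_byte_size and model_byte_size, while
  # test_timing_metrics relies on FakeClock advancing 500 ms per load_model
  # call and 3 ms per run_inference call, giving the expected means of 500 ms
  # and 3000 microseconds.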
  def test_counted_metrics(self):
    pipeline = TestPipeline()
    examples = [1, 5, 3, 10]
    pcoll = pipeline | 'start' >> beam.Create(examples)
    _ = pcoll | base.RunInference(FakeModelHandler())
    run_result = pipeline.run()
    run_result.wait_until_finish()

    metric_results = (
        run_result.metrics().query(MetricsFilter().with_name('num_inferences')))
    num_inferences_counter = metric_results['counters'][0]
    self.assertEqual(num_inferences_counter.committed, 4)

    inference_request_batch_size = run_result.metrics().query(
        MetricsFilter().with_name('inference_request_batch_size'))
    self.assertTrue(inference_request_batch_size['distributions'])
    self.assertEqual(
        inference_request_batch_size['distributions'][0].result.sum, 4)
    inference_request_batch_byte_size = run_result.metrics().query(
        MetricsFilter().with_name('inference_request_batch_byte_size'))
    self.assertTrue(inference_request_batch_byte_size['distributions'])
    self.assertGreaterEqual(
        inference_request_batch_byte_size['distributions'][0].result.sum,
        len(pickle.dumps(examples)))
    inference_request_batch_byte_size = run_result.metrics().query(
        MetricsFilter().with_name('model_byte_size'))
    self.assertTrue(inference_request_batch_byte_size['distributions'])

  def test_timing_metrics(self):
    pipeline = TestPipeline()
    examples = [1, 5, 3, 10]
    pcoll = pipeline | 'start' >> beam.Create(examples)
    fake_clock = FakeClock()
    _ = pcoll | base.RunInference(
        FakeModelHandler(clock=fake_clock), clock=fake_clock)
    res = pipeline.run()
    res.wait_until_finish()

    metric_results = (
        res.metrics().query(
            MetricsFilter().with_name('inference_batch_latency_micro_secs')))
    batch_latency = metric_results['distributions'][0]
    self.assertEqual(batch_latency.result.count, 3)
    self.assertEqual(batch_latency.result.mean, 3000)

    metric_results = (
        res.metrics().query(
            MetricsFilter().with_name('load_model_latency_milli_secs')))
    load_model_latency = metric_results['distributions'][0]
    self.assertEqual(load_model_latency.result.count, 1)
    self.assertEqual(load_model_latency.result.mean, 500)

  def test_forwards_batch_args(self):
    examples = list(range(100))
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch())
      assert_that(actual, equal_to(examples), label='assert:inferences')

  def test_run_inference_unkeyed_examples_with_keyed_model_handler(self):
    pipeline = TestPipeline()
    with self.assertRaises(TypeError):
      examples = [1, 3, 5]
      model_handler = base.KeyedModelHandler(FakeModelHandler())
      _ = (
          pipeline | 'Unkeyed' >> beam.Create(examples)
          | 'RunUnkeyed' >> base.RunInference(model_handler))
      pipeline.run()

  def test_run_inference_keyed_examples_with_unkeyed_model_handler(self):
    pipeline = TestPipeline()
    examples = [1, 3, 5]
    keyed_examples = [(i, example) for i, example in enumerate(examples)]
    model_handler = FakeModelHandler()
    with self.assertRaises(TypeError):
      _ = (
          pipeline | 'keyed' >> beam.Create(keyed_examples)
          | 'RunKeyed' >> base.RunInference(model_handler))
      pipeline.run()

  def test_model_handler_compatibility(self):
    # ** IMPORTANT ** Do not change this test to make your PR pass without
    # first reading below.
    # Be certain that the modification will not break third party
    # implementations of ModelHandler.
    # See issue https://github.com/apache/beam/issues/23484
    # If this test fails, likely third party implementations of
    # ModelHandler will break.
    class ThirdPartyHandler(base.ModelHandler[int, int, FakeModel]):
      def __init__(self, custom_parameter=None):
        pass

      def load_model(self) -> FakeModel:
        return FakeModel()

      def run_inference(
          self,
          batch: Sequence[int],
          model: FakeModel,
          inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
        yield 0

      def get_num_bytes(self, batch: Sequence[int]) -> int:
        return 1

      def get_metrics_namespace(self) -> str:
        return 'ThirdParty'

      def get_resource_hints(self) -> dict:
        return {}

      def batch_elements_kwargs(self) -> Mapping[str, Any]:
        return {}

      def validate_inference_args(
          self, inference_args: Optional[Dict[str, Any]]):
        pass

    # This test passes if calling these methods does not cause
    # any runtime exceptions.
    third_party_model_handler = ThirdPartyHandler(custom_parameter=0)
    fake_model = third_party_model_handler.load_model()
    third_party_model_handler.run_inference([], fake_model)
    fake_inference_args = {'some_arg': 1}
    third_party_model_handler.run_inference([],
                                            fake_model,
                                            inference_args=fake_inference_args)
    third_party_model_handler.get_num_bytes([1, 2, 3])
    third_party_model_handler.get_metrics_namespace()
    third_party_model_handler.get_resource_hints()
    third_party_model_handler.batch_elements_kwargs()
    third_party_model_handler.validate_inference_args({})

  def test_run_inference_prediction_result_with_model_id(self):
    examples = [1, 5, 3, 10]
    expected = [
        base.PredictionResult(
            example=example,
            inference=example + 1,
            model_id='fake_model_id_default') for example in examples
    ]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          FakeModelHandlerReturnsPredictionResult())
      assert_that(actual, equal_to(expected), label='assert:inferences')
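
  # model_metadata_pcoll is read as a singleton side input, so the next two
  # tests feed it a global-window PCollection containing two ModelMetadata
  # elements and expect the "more than one element accessed as a singleton
  # view" ValueError asserted at the end.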
  def test_run_inference_with_iterable_side_input(self):
    test_pipeline = TestPipeline()
    side_input = (
        test_pipeline | "CreateDummySideInput" >> beam.Create(
            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
        | "ApplySideInputWindow" >> beam.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))

    test_pipeline.options.view_as(StandardOptions).streaming = True
    with self.assertRaises(ValueError) as e:
      _ = (
          test_pipeline
          | beam.Create([1, 2, 3, 4])
          | base.RunInference(
              FakeModelHandler(), model_metadata_pcoll=side_input))
      test_pipeline.run()

    self.assertTrue(
        'PCollection of size 2 with more than one element accessed as a '
        'singleton view. First two elements encountered are' in str(
            e.exception))

  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
    test_pipeline = TestPipeline()
    side_input = (
        test_pipeline | "CreateDummySideInput" >> beam.Create(
            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
        | "ApplySideInputWindow" >> beam.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))

    test_pipeline.options.view_as(StandardOptions).streaming = True
    with self.assertRaises(ValueError) as e:
      _ = (
          test_pipeline
          | beam.Create([1, 2, 3, 4])
          | base.RunInference(
              FakeModelHandler(multi_process_shared=True),
              model_metadata_pcoll=side_input))
      test_pipeline.run()

    self.assertTrue(
        'PCollection of size 2 with more than one element accessed as a '
        'singleton view. First two elements encountered are' in str(
            e.exception))

  def test_run_inference_empty_side_input(self):
    model_handler = FakeModelHandlerReturnsPredictionResult()
    main_input_elements = [1, 2]
    with TestPipeline(is_integration_test=False) as pipeline:
      side_pcoll = pipeline | "side" >> beam.Create([])
      result_pcoll = (
          pipeline
          | beam.Create(main_input_elements)
          | base.RunInference(model_handler, model_metadata_pcoll=side_pcoll))
      expected = [
          base.PredictionResult(ele, ele + 1, 'fake_model_id_default')
          for ele in main_input_elements
      ]

      assert_that(result_pcoll, equal_to(expected))
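
  # In the next two tests, main inputs and ModelMetadata side inputs share the
  # same fixed event-time windows; each element is expected to be tagged with
  # the model_id from the latest side-input update seen up to its window, with
  # an empty model_id falling back to the handler default (see
  # expected_model_id_order).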
  def test_run_inference_side_input_in_batch(self):
    first_ts = math.floor(time.time()) - 30
    interval = 7

    sample_main_input_elements = ([
        first_ts - 2,
        first_ts + 1,
        first_ts + 8,
        first_ts + 15,
        first_ts + 22,
    ])

    sample_side_input_elements = [
        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
        # if model_id is empty string, we use the default model
        # handler model URI.
        (
            first_ts + 8,
            base.ModelMetadata(
                model_id='fake_model_id_1', model_name='fake_model_id_1')),
        (
            first_ts + 15,
            base.ModelMetadata(
                model_id='fake_model_id_2', model_name='fake_model_id_2'))
    ]

    model_handler = FakeModelHandlerReturnsPredictionResult()

    # applying GroupByKey to utilize windowing according to
    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
    class _EmitElement(beam.DoFn):
      def process(self, element):
        for e in element:
          yield e

    with TestPipeline() as pipeline:
      side_input = (
          pipeline
          |
          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
          | beam.WindowInto(
              window.FixedWindows(interval),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | beam.Map(lambda x: ('key', x))
          | beam.GroupByKey()
          | beam.Map(lambda x: x[1])
          | "EmitSideInput" >> beam.ParDo(_EmitElement()))

      result_pcoll = (
          pipeline
          | beam.Create(sample_main_input_elements)
          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
          | beam.Map(lambda x: ('key', x))
          | "MainInputGBK" >> beam.GroupByKey()
          | beam.Map(lambda x: x[1])
          | beam.ParDo(_EmitElement())
          | "RunInference" >> base.RunInference(
              model_handler, model_metadata_pcoll=side_input))

      expected_model_id_order = [
          'fake_model_id_default',
          'fake_model_id_default',
          'fake_model_id_1',
          'fake_model_id_2',
          'fake_model_id_2'
      ]
      expected_result = [
          base.PredictionResult(
              example=sample_main_input_elements[i],
              inference=sample_main_input_elements[i] + 1,
              model_id=expected_model_id_order[i]) for i in range(5)
      ]

      assert_that(result_pcoll, equal_to(expected_result))

  def test_run_inference_side_input_in_batch_multi_process_shared(self):
    first_ts = math.floor(time.time()) - 30
    interval = 7

    sample_main_input_elements = ([
        first_ts - 2,
        first_ts + 1,
        first_ts + 8,
        first_ts + 15,
        first_ts + 22,
    ])

    sample_side_input_elements = [
        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
        # if model_id is empty string, we use the default model
        # handler model URI.
        (
            first_ts + 8,
            base.ModelMetadata(
                model_id='fake_model_id_1', model_name='fake_model_id_1')),
        (
            first_ts + 15,
            base.ModelMetadata(
                model_id='fake_model_id_2', model_name='fake_model_id_2'))
    ]

    model_handler = FakeModelHandlerReturnsPredictionResult(
        multi_process_shared=True)

    # applying GroupByKey to utilize windowing according to
    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
    class _EmitElement(beam.DoFn):
      def process(self, element):
        for e in element:
          yield e

    with TestPipeline() as pipeline:
      side_input = (
          pipeline
          |
          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
          | beam.WindowInto(
              window.FixedWindows(interval),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | beam.Map(lambda x: ('key', x))
          | beam.GroupByKey()
          | beam.Map(lambda x: x[1])
          | "EmitSideInput" >> beam.ParDo(_EmitElement()))

      result_pcoll = (
          pipeline
          | beam.Create(sample_main_input_elements)
          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
          | beam.Map(lambda x: ('key', x))
          | "MainInputGBK" >> beam.GroupByKey()
          | beam.Map(lambda x: x[1])
          | beam.ParDo(_EmitElement())
          | "RunInference" >> base.RunInference(
              model_handler, model_metadata_pcoll=side_input))

      expected_model_id_order = [
          'fake_model_id_default',
          'fake_model_id_default',
          'fake_model_id_1',
          'fake_model_id_2',
          'fake_model_id_2'
      ]
      expected_result = [
          base.PredictionResult(
              example=sample_main_input_elements[i],
              inference=sample_main_input_elements[i] + 1,
              model_id=expected_model_id_order[i]) for i in range(5)
      ]

      assert_that(result_pcoll, equal_to(expected_result))

  @unittest.skipIf(
      not TestPipeline().get_pipeline_options().view_as(
          StandardOptions).streaming,
      "SideInputs to RunInference are only supported in streaming mode.")
  @pytest.mark.it_postcommit
  @pytest.mark.sickbay_direct
  @pytest.mark.it_validatesrunner
  def test_run_inference_with_side_inputin_streaming(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    test_pipeline.options.view_as(StandardOptions).streaming = True
    run_inference_side_inputs.run(
        test_pipeline.get_full_options_as_args(), save_main_session=False)

  def test_env_vars_set_correctly(self):
    handler_with_vars = FakeModelHandler(env_vars={'FOO': 'bar'})
    os.environ.pop('FOO', None)
    self.assertFalse('FOO' in os.environ)
    with TestPipeline() as pipeline:
      examples = [1, 2, 3]
      _ = (
          pipeline
          | 'start' >> beam.Create(examples)
          | base.RunInference(handler_with_vars))
      pipeline.run()
    self.assertTrue('FOO' in os.environ)
    self.assertTrue((os.environ['FOO']) == 'bar')

  def test_child_class_without_env_vars(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
      expected = [example + 1 for example in examples]
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
      assert_that(actual, equal_to(expected), label='assert:inferences')


if __name__ == '__main__':
  unittest.main()