github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/base.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# TODO: https://github.com/apache/beam/issues/21822
# mypy: ignore-errors

"""An extensible run inference transform.

Users of this module can extend the ModelHandler class for any machine learning
framework. A ModelHandler implementation is a required parameter of
RunInference.

The transform handles standard inference functionality, like metric
collection, sharing models between threads, and batching elements.
"""

import logging
import os
import pickle
import sys
import threading
import time
import uuid
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union

import apache_beam as beam
from apache_beam.utils import multi_process_shared
from apache_beam.utils import shared

try:
  # pylint: disable=wrong-import-order, wrong-import-position
  import resource
except ImportError:
  resource = None  # type: ignore[assignment]

_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000

ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PreProcessT = TypeVar('PreProcessT')
PredictionT = TypeVar('PredictionT')
PostProcessT = TypeVar('PostProcessT')
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
KeyT = TypeVar('KeyT')


# We use NamedTuple to define the structure of the PredictionResult. However,
# since support for generic NamedTuples is not available in Python versions
# prior to 3.11, we use the __new__ method to provide default values for the
# fields while maintaining backwards compatibility.
class PredictionResult(NamedTuple('PredictionResult',
                                  [('example', _INPUT_TYPE),
                                   ('inference', _OUTPUT_TYPE),
                                   ('model_id', Optional[str])])):
  __slots__ = ()

  def __new__(cls, example, inference, model_id=None):
    return super().__new__(cls, example, inference, model_id)


PredictionResult.__doc__ = """A NamedTuple containing both input and output
from the inference."""
PredictionResult.example.__doc__ = """The input example."""
PredictionResult.inference.__doc__ = """Results for the inference on the model
for the given example."""
PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""


class ModelMetadata(NamedTuple):
  model_id: str
  model_name: str


class RunInferenceDLQ(NamedTuple):
  failed_inferences: beam.PCollection
  failed_preprocessing: Sequence[beam.PCollection]
  failed_postprocessing: Sequence[beam.PCollection]


ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be
a file path or a URL where the model can be accessed. It is used to load
the model for inference."""
ModelMetadata.model_name.__doc__ = """Human-readable name for the model. This
can be used to identify the model in the metrics generated by the
RunInference transform."""
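

# Illustrative example (the values and model_id below are hypothetical, not
# part of this module): because __new__ supplies a default for model_id, a
# PredictionResult can be constructed with or without it.
#
#   result = PredictionResult(example=[1.0, 2.0], inference=0.73)
#   result.model_id  # None
#   versioned = PredictionResult([1.0, 2.0], 0.73, model_id='model-v2.pt')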


def _to_milliseconds(time_ns: int) -> int:
  return int(time_ns / _NANOSECOND_TO_MILLISECOND)


def _to_microseconds(time_ns: int) -> int:
  return int(time_ns / _NANOSECOND_TO_MICROSECOND)


class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
  """Has the ability to load and apply an ML model."""
  def __init__(self):
    """Environment variables are set using a dict named 'env_vars' before
    loading the model. Child classes can accept this dict as a kwarg."""
    self._env_vars = {}

  def load_model(self) -> ModelT:
    """Loads and initializes a model for processing."""
    raise NotImplementedError(type(self))

  def run_inference(
      self,
      batch: Sequence[ExampleT],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
    """Runs inferences on a batch of examples.

    Args:
      batch: A sequence of examples or features.
      model: The model used to make inferences.
      inference_args: Extra arguments for models whose inference call requires
        extra parameters.

    Returns:
      An Iterable of Predictions.
    """
    raise NotImplementedError(type(self))

  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    return len(pickle.dumps(batch))

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'RunInference'

  def get_resource_hints(self) -> dict:
    """
    Returns:
      Resource hints for the transform.
    """
    return {}

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """
    Returns:
      kwargs suitable for beam.BatchElements.
    """
    return {}

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    """Validates inference_args passed in the inference call.

    Because most frameworks do not need extra arguments in their predict()
    call, the default behavior is to error out if inference_args are present.
    """
    if inference_args:
      raise ValueError(
          'inference_args were provided, but should be None because this '
          'framework does not expect extra arguments on inferences.')

  def update_model_path(self, model_path: Optional[str] = None):
    """Updates the model path produced by side inputs."""
    pass

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Gets all preprocessing functions to be run before batching/inference.
    Functions are in the order that they should be applied."""
    return []

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Gets all postprocessing functions to be run after inference.
    Functions are in the order that they should be applied."""
    return []

  def set_environment_vars(self):
    """Sets environment variables using a dictionary provided via kwargs.
    Keys are the env variable names, and values are the env variable values.
    Child ModelHandler classes should set _env_vars via kwargs in __init__,
    or else call super().__init__()."""
    env_vars = getattr(self, '_env_vars', {})
    for env_variable, env_value in env_vars.items():
      os.environ[env_variable] = env_value

  def with_preprocess_fn(
      self, fn: Callable[[PreProcessT], ExampleT]
  ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
    """Returns a new ModelHandler with a preprocessing function
    associated with it. The preprocessing function will be run
    before batching/inference and should map your input PCollection
    to the base ModelHandler's input type. If you apply multiple
    preprocessing functions, they will be run on your original
    PCollection in order from last applied to first applied."""
    return _PreProcessingModelHandler(self, fn)

  def with_postprocess_fn(
      self, fn: Callable[[PredictionT], PostProcessT]
  ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
    """Returns a new ModelHandler with a postprocessing function
    associated with it. The postprocessing function will be run
    after inference and should map the base ModelHandler's output
    type to your desired output type. If you apply multiple
    postprocessing functions, they will be run on your original
    inference result in order from first applied to last applied."""
    return _PostProcessingModelHandler(self, fn)

  def share_model_across_processes(self) -> bool:
    """Returns a boolean representing whether or not a model should
    be shared across multiple processes instead of being loaded per process.
    This is primarily useful for large models that can't fit multiple copies
    in memory. Multi-process support may vary by runner, but this will fall
    back to loading per process as necessary. See
    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
    return False
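

# A minimal sketch of a concrete handler (FakeModel and FakeModelHandler are
# illustrative names, not part of this module). Subclasses implement
# load_model() and run_inference(), and can override batch_elements_kwargs()
# to tune the beam.BatchElements step that RunInference applies:
#
#   class FakeModel:
#     def predict(self, example: int) -> int:
#       return example * 2
#
#   class FakeModelHandler(ModelHandler[int, PredictionResult, FakeModel]):
#     def load_model(self) -> FakeModel:
#       return FakeModel()
#
#     def run_inference(self, batch, model, inference_args=None):
#       for example in batch:
#         yield PredictionResult(example, model.predict(example))
#
#     def batch_elements_kwargs(self):
#       return {'min_batch_size': 1, 'max_batch_size': 64}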


class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                        ModelHandler[Tuple[KeyT, ExampleT],
                                     Tuple[KeyT, PredictionT],
                                     ModelT]):
  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
    """A ModelHandler that takes keyed examples and returns keyed predictions.

    For example, if the original model is used with RunInference to take a
    PCollection[E] to a PCollection[P], this ModelHandler would take a
    PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
    to use the key to associate the outputs with the inputs.

    Args:
      unkeyed: An implementation of ModelHandler that does not require keys.
    """
    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
      raise Exception(
          'Cannot make an unkeyed model handler with pre or '
          'postprocessing functions defined into a keyed model handler. All '
          'pre/postprocessing functions must be defined on the outer model '
          'handler.')
    self._unkeyed = unkeyed
    self._env_vars = unkeyed._env_vars

  def load_model(self) -> ModelT:
    return self._unkeyed.load_model()

  def run_inference(
      self,
      batch: Sequence[Tuple[KeyT, ExampleT]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[Tuple[KeyT, PredictionT]]:
    keys, unkeyed_batch = zip(*batch)
    return zip(
        keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))

  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
    keys, unkeyed_batch = zip(*batch)
    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)

  def get_metrics_namespace(self) -> str:
    return self._unkeyed.get_metrics_namespace()

  def get_resource_hints(self):
    return self._unkeyed.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._unkeyed.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._unkeyed.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._unkeyed.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_postprocess_fns()

  def share_model_across_processes(self) -> bool:
    return self._unkeyed.share_model_across_processes()
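

# Illustrative usage of KeyedModelHandler, assuming the hypothetical
# FakeModelHandler sketched above: keys pass through unchanged while
# inference runs on the unkeyed values.
#
#   keyed_handler = KeyedModelHandler(FakeModelHandler())
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([('a', 1), ('b', 2)])
#         | RunInference(keyed_handler))
#   # Emits ('a', PredictionResult(1, 2)), ('b', PredictionResult(2, 4)).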


class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                             ModelHandler[Union[ExampleT, Tuple[KeyT,
                                                                ExampleT]],
                                          Union[PredictionT,
                                                Tuple[KeyT, PredictionT]],
                                          ModelT]):
  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
    """A ModelHandler that takes examples that might have keys and returns
    predictions that might have keys.

    For example, if the original model is used with RunInference to take a
    PCollection[E] to a PCollection[P], this ModelHandler would take either
    PCollection[E] to a PCollection[P] or PCollection[Tuple[K, E]] to a
    PCollection[Tuple[K, P]], depending on whether the elements are
    tuples. This pattern makes it possible to associate the outputs with the
    inputs based on the key.

    Note that you cannot use this ModelHandler if E is a tuple type.
    In addition, either all examples should be keyed, or none of them should
    be.

    Args:
      unkeyed: An implementation of ModelHandler that does not require keys.
    """
    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
      raise Exception(
          'Cannot make an unkeyed model handler with pre or '
          'postprocessing functions defined into a keyed model handler. All '
          'pre/postprocessing functions must be defined on the outer model '
          'handler.')
    self._unkeyed = unkeyed
    self._env_vars = unkeyed._env_vars

  def load_model(self) -> ModelT:
    return self._unkeyed.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    # Really the input should be
    #    Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
    # but there's not a good way to express (or check) that.
    if isinstance(batch[0], tuple):
      is_keyed = True
      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
    else:
      is_keyed = False
      unkeyed_batch = batch  # type: ignore[assignment]
    unkeyed_results = self._unkeyed.run_inference(
        unkeyed_batch, model, inference_args)
    if is_keyed:
      return zip(keys, unkeyed_results)
    else:
      return unkeyed_results

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    # MyPy can't follow the branching logic.
    if isinstance(batch[0], tuple):
      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
      return len(
          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
    else:
      return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]

  def get_metrics_namespace(self) -> str:
    return self._unkeyed.get_metrics_namespace()

  def get_resource_hints(self):
    return self._unkeyed.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._unkeyed.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._unkeyed.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._unkeyed.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_postprocess_fns()

  def share_model_across_processes(self) -> bool:
    return self._unkeyed.share_model_across_processes()
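

# Illustrative composition, assuming the hypothetical FakeModelHandler above.
# Preprocessing functions run before batching, from the last one attached to
# the first, so the final .with_preprocess_fn() sees the raw input;
# postprocessing functions run after inference in the order they were
# attached.
#
#   handler = (
#       FakeModelHandler()
#       .with_preprocess_fn(lambda x: x + 1)     # runs second
#       .with_preprocess_fn(lambda s: int(s))    # attached last, runs first
#       .with_postprocess_fn(lambda pr: pr.inference))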


class _PreProcessingModelHandler(Generic[ExampleT,
                                         PredictionT,
                                         ModelT,
                                         PreProcessT],
                                 ModelHandler[PreProcessT, PredictionT,
                                              ModelT]):
  def __init__(
      self,
      base: ModelHandler[ExampleT, PredictionT, ModelT],
      preprocess_fn: Callable[[PreProcessT], ExampleT]):
    """A ModelHandler that has a preprocessing function associated with it.

    Args:
      base: An implementation of the underlying model handler.
      preprocess_fn: the preprocessing function to use.
    """
    self._base = base
    self._env_vars = base._env_vars
    self._preprocess_fn = preprocess_fn

  def load_model(self) -> ModelT:
    return self._base.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    return self._base.run_inference(batch, model, inference_args)

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    return self._base.get_num_bytes(batch)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()

  def get_resource_hints(self):
    return self._base.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._base.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._base.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._base.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return [self._preprocess_fn] + self._base.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._base.get_postprocess_fns()


class _PostProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
                                          ModelT,
                                          PostProcessT],
                                  ModelHandler[ExampleT, PostProcessT, ModelT]):
  def __init__(
      self,
      base: ModelHandler[ExampleT, PredictionT, ModelT],
      postprocess_fn: Callable[[PredictionT], PostProcessT]):
    """A ModelHandler that has a postprocessing function associated with it.

    Args:
      base: An implementation of the underlying model handler.
      postprocess_fn: the postprocessing function to use.
    """
    self._base = base
    self._env_vars = base._env_vars
    self._postprocess_fn = postprocess_fn

  def load_model(self) -> ModelT:
    return self._base.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    return self._base.run_inference(batch, model, inference_args)

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    return self._base.get_num_bytes(batch)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()

  def get_resource_hints(self):
    return self._base.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._base.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._base.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._base.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._base.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._base.get_postprocess_fns() + [self._postprocess_fn]


class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                   beam.PCollection[PredictionT]]):
  def __init__(
      self,
      model_handler: ModelHandler[ExampleT, PredictionT, Any],
      clock=time,
      inference_args: Optional[Dict[str, Any]] = None,
      metrics_namespace: Optional[str] = None,
      *,
      model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
      watch_model_pattern: Optional[str] = None,
      **kwargs):
    """
    A transform that takes a PCollection of examples (or features) for use
    on an ML model. The transform then outputs inferences (or predictions) for
    those examples in a PCollection of PredictionResults that contains the
    input examples and the output inferences.

    Models for supported frameworks can be loaded using a URI. Supported
    services can also be used.

    This transform attempts to batch examples using the beam.BatchElements
    transform. Batching can be configured using the ModelHandler.

    Args:
      model_handler: An implementation of ModelHandler.
      clock: A clock implementing time_ns. *Used for unit testing.*
      inference_args: Extra arguments for models whose inference call requires
        extra parameters.
      metrics_namespace: Namespace of the transform to collect metrics.
      model_metadata_pcoll: PCollection that emits a singleton ModelMetadata
        containing the model path and model name, which is used as a side
        input to the _RunInferenceDoFn.
      watch_model_pattern: A glob pattern used to watch a directory
        for automatic model refresh.
    """
    self._model_handler = model_handler
    self._inference_args = inference_args
    self._clock = clock
    self._metrics_namespace = metrics_namespace
    self._model_metadata_pcoll = model_metadata_pcoll
    self._enable_side_input_loading = self._model_metadata_pcoll is not None
    self._with_exception_handling = False
    self._watch_model_pattern = watch_model_pattern
    self._kwargs = kwargs
    # Generate a random tag to use for shared.py and multi_process_shared.py to
    # allow us to effectively disambiguate in multi-model settings.
    self._model_tag = uuid.uuid4().hex

  def _get_model_metadata_pcoll(self, pipeline):
    # Avoid circular imports.
    # pylint: disable=wrong-import-position
    from apache_beam.ml.inference.utils import WatchFilePattern
    extra_params = {}
    if 'interval' in self._kwargs:
      extra_params['interval'] = self._kwargs['interval']
    if 'stop_timestamp' in self._kwargs:
      extra_params['stop_timestamp'] = self._kwargs['stop_timestamp']

    return (
        pipeline | WatchFilePattern(
            file_pattern=self._watch_model_pattern, **extra_params))

  # TODO(BEAM-14046): Add and link to help documentation.
  @classmethod
  def from_callable(cls, model_handler_provider, **kwargs):
    """Multi-language friendly constructor.

    Use this constructor with fully_qualified_named_transform to
    initialize the RunInference transform from a PythonCallableSource provided
    by foreign SDKs.

    Args:
      model_handler_provider: A callable object that returns a ModelHandler.
      kwargs: Keyword arguments for model_handler_provider.
    """
    return cls(model_handler_provider(**kwargs))

  def _apply_fns(
      self,
      pcoll: beam.PCollection,
      fns: Iterable[Callable[[Any], Any]],
      step_prefix: str) -> Tuple[beam.PCollection, Iterable[beam.PCollection]]:
    bad_preprocessed = []
    for idx in range(len(fns)):
      fn = fns[idx]
      if self._with_exception_handling:
        pcoll, bad = (pcoll
                      | f"{step_prefix}-{idx}" >> beam.Map(
                          fn).with_exception_handling(
                              exc_class=self._exc_class,
                              use_subprocess=self._use_subprocess,
                              threshold=self._threshold))
        bad_preprocessed.append(bad)
      else:
        pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn)

    return pcoll, bad_preprocessed

  # TODO(https://github.com/apache/beam/issues/21447): Add batch_size back off
  # in the case there are functional reasons large batch sizes cannot be
  # handled.
  def expand(
      self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
    self._model_handler.validate_inference_args(self._inference_args)
    # DLQ pcollections
    bad_preprocessed = []
    bad_inference = None
    bad_postprocessed = []
    preprocess_fns = self._model_handler.get_preprocess_fns()
    postprocess_fns = self._model_handler.get_postprocess_fns()

    pcoll, bad_preprocessed = self._apply_fns(
        pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess')

    resource_hints = self._model_handler.get_resource_hints()

    # Check for the side input.
    if self._watch_model_pattern:
      self._model_metadata_pcoll = self._get_model_metadata_pcoll(
          pcoll.pipeline)

    batched_elements_pcoll = (
        pcoll
        # TODO(https://github.com/apache/beam/issues/21440): Hook into the
        # batching DoFn APIs.
        | beam.BatchElements(**self._model_handler.batch_elements_kwargs()))

    run_inference_pardo = beam.ParDo(
        _RunInferenceDoFn(
            self._model_handler,
            self._clock,
            self._metrics_namespace,
            self._enable_side_input_loading,
            self._model_tag),
        self._inference_args,
        beam.pvalue.AsSingleton(
            self._model_metadata_pcoll,
        ) if self._enable_side_input_loading else None).with_resource_hints(
            **resource_hints)

    if self._with_exception_handling:
      results, bad_inference = (
          batched_elements_pcoll
          | 'BeamML_RunInference' >>
          run_inference_pardo.with_exception_handling(
              exc_class=self._exc_class,
              use_subprocess=self._use_subprocess,
              threshold=self._threshold))
    else:
      results = (
          batched_elements_pcoll
          | 'BeamML_RunInference' >> run_inference_pardo)

    results, bad_postprocessed = self._apply_fns(
        results, postprocess_fns, 'BeamML_RunInference_Postprocess')

    if self._with_exception_handling:
      dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed)
      return results, dlq

    return results

  def with_exception_handling(
      self, *, exc_class=Exception, use_subprocess=False, threshold=1):
    """Automatically provides a dead letter output for skipping bad records.
    This can allow a pipeline to continue successfully rather than fail or
    continuously throw errors on retry when bad elements are encountered.

    This returns a tagged output with two PCollections, the first being the
    results of successfully processing the input PCollection, and the second
    being the set of bad batches of records (those which threw exceptions
    during processing) along with information about the errors raised.

    For example, one would write::

        main, other = RunInference(
          maybe_error_raising_model_handler
        ).with_exception_handling()

    and `main` will be a PCollection of PredictionResults and `other` will
    contain a `RunInferenceDLQ` object with PCollections containing failed
    records for each failed inference, preprocess operation, or postprocess
    operation. To access each collection of failed records, one would write::

        failed_inferences = other.failed_inferences
        failed_preprocessing = other.failed_preprocessing
        failed_postprocessing = other.failed_postprocessing

    failed_inferences is in the form
    PCollection[Tuple[failed batch, exception]].

    failed_preprocessing is in the form
    List[PCollection[Tuple[failed record, exception]]], where each element of
    the list corresponds to a preprocess function. These PCollections are
    in the same order that the preprocess functions are applied.

    failed_postprocessing is in the form
    List[PCollection[Tuple[failed record, exception]]], where each element of
    the list corresponds to a postprocess function. These PCollections are
    in the same order that the postprocess functions are applied.

    Args:
      exc_class: An exception class, or tuple of exception classes, to catch.
        Optional, defaults to 'Exception'.
      use_subprocess: Whether to execute the DoFn logic in a subprocess. This
        allows one to recover from errors that can crash the calling process
        (e.g. from an underlying library causing a segfault), but is
        slower as elements and results must cross a process boundary. Note
        that this starts up a long-running process that is used to handle
        all the elements (until hard failure, which should be rare) rather
        than a new process per element, so the overhead should be minimal
        (and can be amortized if there's any per-process or per-bundle
        initialization that needs to be done). Optional, defaults to False.
      threshold: An upper bound on the ratio of records that can be bad before
        aborting the entire pipeline. Optional, defaults to 1.0 (meaning
        up to 100% of records can be bad and the pipeline will still succeed).
    """
    self._with_exception_handling = True
    self._exc_class = exc_class
    self._use_subprocess = use_subprocess
    self._threshold = threshold
    return self
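

# Usage sketch for automatic model refresh (the glob pattern and the use of a
# streaming pipeline are assumptions for illustration, and FakeModelHandler is
# the hypothetical handler sketched earlier): passing watch_model_pattern
# makes RunInference build a ModelMetadata side input via WatchFilePattern and
# reload the model when new matching files appear.
#
#   with beam.Pipeline(options=streaming_options) as p:
#     _ = (
#         p
#         | beam.Create([1, 2, 3])
#         | RunInference(
#             FakeModelHandler(),
#             watch_model_pattern='gs://some-bucket/models/*'))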


class _MetricsCollector:
  """A metrics collector that tracks ML related performance and memory usage."""
  def __init__(self, namespace: str, prefix: str = ''):
    """
    Args:
      namespace: Namespace for the metrics.
      prefix: Unique identifier for metrics, used when models
        are updated using side input.
    """
    # Metrics
    if prefix:
      prefix = f'{prefix}_'
    self._inference_counter = beam.metrics.Metrics.counter(
        namespace, prefix + 'num_inferences')
    self.failed_batches_counter = beam.metrics.Metrics.counter(
        namespace, prefix + 'failed_batches_counter')
    self._inference_request_batch_size = beam.metrics.Metrics.distribution(
        namespace, prefix + 'inference_request_batch_size')
    self._inference_request_batch_byte_size = (
        beam.metrics.Metrics.distribution(
            namespace, prefix + 'inference_request_batch_byte_size'))
    # Batch inference latency in microseconds.
    self._inference_batch_latency_micro_secs = (
        beam.metrics.Metrics.distribution(
            namespace, prefix + 'inference_batch_latency_micro_secs'))
    self._model_byte_size = beam.metrics.Metrics.distribution(
        namespace, prefix + 'model_byte_size')
    # Model load latency in milliseconds.
    self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution(
        namespace, prefix + 'load_model_latency_milli_secs')

    # Metrics cache
    self._load_model_latency_milli_secs_cache = None
    self._model_byte_size_cache = None

  def update_metrics_with_cache(self):
    if self._load_model_latency_milli_secs_cache is not None:
      self._load_model_latency_milli_secs.update(
          self._load_model_latency_milli_secs_cache)
      self._load_model_latency_milli_secs_cache = None
    if self._model_byte_size_cache is not None:
      self._model_byte_size.update(self._model_byte_size_cache)
      self._model_byte_size_cache = None

  def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size):
    self._load_model_latency_milli_secs_cache = load_model_latency_ms
    self._model_byte_size_cache = model_byte_size

  def update(
      self,
      examples_count: int,
      examples_byte_size: int,
      latency_micro_secs: int):
    self._inference_batch_latency_micro_secs.update(latency_micro_secs)
    self._inference_counter.inc(examples_count)
    self._inference_request_batch_size.update(examples_count)
    self._inference_request_batch_byte_size.update(examples_byte_size)
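

# Illustrative sketch: the counters and distributions above are reported under
# the handler's metrics namespace ('RunInference' by default) and can be read
# from the pipeline result after it completes, for example:
#
#   from apache_beam.metrics.metric import MetricsFilter
#
#   result = pipeline.run()
#   result.wait_until_finish()
#   metrics = result.metrics().query(
#       MetricsFilter().with_namespace('RunInference'))
#   # metrics['counters'] includes num_inferences; metrics['distributions']
#   # includes inference_batch_latency_micro_secs and
#   # load_model_latency_milli_secs, among others.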


class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
  def __init__(
      self,
      model_handler: ModelHandler[ExampleT, PredictionT, Any],
      clock,
      metrics_namespace,
      enable_side_input_loading: bool = False,
      model_tag: str = "RunInference"):
    """A DoFn implementation generic to frameworks.

    Args:
      model_handler: An implementation of ModelHandler.
      clock: A clock implementing time_ns. *Used for unit testing.*
      metrics_namespace: Namespace of the transform to collect metrics.
      enable_side_input_loading: Bool to indicate whether the model is
        updated via side inputs.
      model_tag: Tag to use to disambiguate models in multi-model settings.
    """
    self._model_handler = model_handler
    self._shared_model_handle = shared.Shared()
    self._clock = clock
    self._model = None
    self._metrics_namespace = metrics_namespace
    self._enable_side_input_loading = enable_side_input_loading
    self._side_input_path = None
    self._model_tag = model_tag

  def _load_model(self, side_input_model_path: Optional[str] = None):
    def load():
      """Function for constructing shared LoadedModel."""
      memory_before = _get_current_process_memory_in_bytes()
      start_time = _to_milliseconds(self._clock.time_ns())
      self._model_handler.update_model_path(side_input_model_path)
      model = self._model_handler.load_model()
      end_time = _to_milliseconds(self._clock.time_ns())
      memory_after = _get_current_process_memory_in_bytes()
      load_model_latency_ms = end_time - start_time
      model_byte_size = memory_after - memory_before
      self._metrics_collector.cache_load_model_metrics(
          load_model_latency_ms, model_byte_size)
      return model

    # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
    # model.
    if self._model_handler.share_model_across_processes():
      model = multi_process_shared.MultiProcessShared(
          load, tag=side_input_model_path or self._model_tag).acquire()
    else:
      model = self._shared_model_handle.acquire(
          load, tag=side_input_model_path or self._model_tag)
    # Since the shared model handle is shared across threads, the model path
    # in the model handler might not be updated: we may get a cached weak
    # reference to the model from the shared cache instead of calling load().
    # As a sanity check, call update_model_path again.
    self._model_handler.update_model_path(side_input_model_path)
    return model

  def get_metrics_collector(self, prefix: str = ''):
    """
    Args:
      prefix: Unique identifier for metrics, used when models
        are updated using side input.
    """
    metrics_namespace = (
        self._metrics_namespace) if self._metrics_namespace else (
            self._model_handler.get_metrics_namespace())
    return _MetricsCollector(metrics_namespace, prefix=prefix)

  def setup(self):
    self._metrics_collector = self.get_metrics_collector()
    self._model_handler.set_environment_vars()
    if not self._enable_side_input_loading:
      self._model = self._load_model()

  def update_model(self, side_input_model_path: Optional[str] = None):
    self._model = self._load_model(side_input_model_path=side_input_model_path)

  def _run_inference(self, batch, inference_args):
    start_time = _to_microseconds(self._clock.time_ns())
    try:
      result_generator = self._model_handler.run_inference(
          batch, self._model, inference_args)
    except BaseException as e:
      self._metrics_collector.failed_batches_counter.inc()
      raise e
    predictions = list(result_generator)

    end_time = _to_microseconds(self._clock.time_ns())
    inference_latency = end_time - start_time
    num_bytes = self._model_handler.get_num_bytes(batch)
    num_elements = len(batch)
    self._metrics_collector.update(num_elements, num_bytes, inference_latency)

    return predictions

  def process(
      self, batch, inference_args, si_model_metadata: Optional[ModelMetadata]):
    """
    When side input is enabled:
      The method checks if the side input model has been updated, and if so,
      updates the model and runs inference on the batch of data. If the
      side input is empty or the model has not been updated, the method
      simply runs inference on the batch of data.
    """
    if si_model_metadata:
      if isinstance(si_model_metadata, beam.pvalue.EmptySideInput):
        self.update_model(side_input_model_path=None)
        return self._run_inference(batch, inference_args)
      elif self._side_input_path != si_model_metadata.model_id:
        self._side_input_path = si_model_metadata.model_id
        self._metrics_collector = self.get_metrics_collector(
            prefix=si_model_metadata.model_name)
        with threading.Lock():
          self.update_model(si_model_metadata.model_id)
          return self._run_inference(batch, inference_args)

    return self._run_inference(batch, inference_args)

  def finish_bundle(self):
    # TODO(https://github.com/apache/beam/issues/21435): Figure out why there
    # is a cache.
    self._metrics_collector.update_metrics_with_cache()


def _is_darwin() -> bool:
  return sys.platform == 'darwin'


def _get_current_process_memory_in_bytes():
  """
  Returns:
    memory usage in bytes.
  """

  if resource is not None:
    usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if _is_darwin():
      return usage
    return usage * 1024
  else:
    logging.warning(
        'Resource module is not available for current platform, '
        'memory usage cannot be fetched.')
    return 0
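

# Illustrative note: handlers for models too large to hold one copy per
# process can opt into the multi_process_shared path used by _load_model above
# by overriding share_model_across_processes (FakeModelHandler is the
# hypothetical handler sketched earlier).
#
#   class LargeModelHandler(FakeModelHandler):
#     def share_model_across_processes(self) -> bool:
#       return True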