github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/base_test.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.ml.base."""
    19  import math
    20  import os
    21  import pickle
    22  import time
    23  import unittest
    24  from typing import Any
    25  from typing import Dict
    26  from typing import Iterable
    27  from typing import Mapping
    28  from typing import Optional
    29  from typing import Sequence
    30  
    31  import pytest
    32  
    33  import apache_beam as beam
    34  from apache_beam.examples.inference import run_inference_side_inputs
    35  from apache_beam.metrics.metric import MetricsFilter
    36  from apache_beam.ml.inference import base
    37  from apache_beam.options.pipeline_options import StandardOptions
    38  from apache_beam.testing.test_pipeline import TestPipeline
    39  from apache_beam.testing.util import assert_that
    40  from apache_beam.testing.util import equal_to
    41  from apache_beam.transforms import trigger
    42  from apache_beam.transforms import window
    43  from apache_beam.transforms.periodicsequence import TimestampedValue
    44  
    45  
    46  class FakeModel:
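           """A trivial model whose prediction is the input plus one."""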
    47    def predict(self, example: int) -> int:
    48      return example + 1
    49  
    50  
    51  class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
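           """FakeModel handler with configurable batching, clock, and sharing."""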
    52    def __init__(
    53        self,
    54        clock=None,
    55        min_batch_size=1,
    56        max_batch_size=9999,
    57        multi_process_shared=False,
    58        **kwargs):
    59      self._fake_clock = clock
    60      self._min_batch_size = min_batch_size
    61      self._max_batch_size = max_batch_size
    62      self._env_vars = kwargs.get('env_vars', {})
    63      self._multi_process_shared = multi_process_shared
    64  
    65    def load_model(self):
    66      if self._fake_clock:
    67        self._fake_clock.current_time_ns += 500_000_000  # 500ms
    68      return FakeModel()
    69  
    70    def run_inference(
    71        self,
    72        batch: Sequence[int],
    73        model: FakeModel,
    74        inference_args=None) -> Iterable[int]:
    75      multi_process_shared_loaded = "multi_process_shared" in str(type(model))
    76      if self._multi_process_shared != multi_process_shared_loaded:
    77        raise Exception(
    78            f'Loaded model of type {type(model)}, was' +
    79            f'{"" if self._multi_process_shared else " not"} ' +
    80            'expecting multi_process_shared_model')
    81      if self._fake_clock:
    82        self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
    83      for example in batch:
    84        yield model.predict(example)
    85  
    86    def update_model_path(self, model_path: Optional[str] = None):
    87      pass
    88  
    89    def batch_elements_kwargs(self):
    90      return {
    91          'min_batch_size': self._min_batch_size,
    92          'max_batch_size': self._max_batch_size
    93      }
    94  
    95    def share_model_across_processes(self):
    96      return self._multi_process_shared
    97  
    98  
    99  class FakeModelHandlerReturnsPredictionResult(
   100      base.ModelHandler[int, base.PredictionResult, FakeModel]):
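           """Handler that yields PredictionResults tagged with a model_id."""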
   101    def __init__(
   102        self,
   103        clock=None,
   104        model_id='fake_model_id_default',
   105        multi_process_shared=False):
   106      self.model_id = model_id
   107      self._fake_clock = clock
   108      self._env_vars = {}
   109      self._multi_process_shared = multi_process_shared
   110  
   111    def load_model(self):
   112      return FakeModel()
   113  
   114    def run_inference(
   115        self,
   116        batch: Sequence[int],
   117        model: FakeModel,
   118        inference_args=None) -> Iterable[base.PredictionResult]:
   119      multi_process_shared_loaded = "multi_process_shared" in str(type(model))
   120      if self._multi_process_shared != multi_process_shared_loaded:
   121        raise Exception(
   122            f'Loaded model of type {type(model)}, was' +
   123            f'{"" if self._multi_process_shared else " not"} ' +
   124            'expecting multi_process_shared_model')
   125      for example in batch:
   126        yield base.PredictionResult(
   127            model_id=self.model_id,
   128            example=example,
   129            inference=model.predict(example))
   130  
   131    def update_model_path(self, model_path: Optional[str] = None):
   132      self.model_id = model_path if model_path else self.model_id
   133  
   134    def share_model_across_processes(self):
   135      return self._multi_process_shared
   136  
   137  
   138  class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]):
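           """FakeModel handler that never sets the _env_vars attribute."""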
   139    def __init__(
   140        self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
   141      self._fake_clock = clock
   142      self._min_batch_size = min_batch_size
   143      self._max_batch_size = max_batch_size
   144  
   145    def load_model(self):
   146      if self._fake_clock:
   147        self._fake_clock.current_time_ns += 500_000_000  # 500ms
   148      return FakeModel()
   149  
   150    def run_inference(
   151        self,
   152        batch: Sequence[int],
   153        model: FakeModel,
   154        inference_args=None) -> Iterable[int]:
   155      if self._fake_clock:
   156        self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
   157      for example in batch:
   158        yield model.predict(example)
   159  
   160    def update_model_path(self, model_path: Optional[str] = None):
   161      pass
   162  
   163    def batch_elements_kwargs(self):
   164      return {
   165          'min_batch_size': self._min_batch_size,
   166          'max_batch_size': self._max_batch_size
   167      }
   168  
   169  
   170  class FakeClock:
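           """Manually advanced clock used to make latency metrics deterministic."""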
   171    def __init__(self):
   172      # Start at 10 seconds.
   173      self.current_time_ns = 10_000_000_000
   174  
   175    def time_ns(self) -> int:
   176      return self.current_time_ns
   177  
   178  
   179  class ExtractInferences(beam.DoFn):
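           """Yields the inference field of each PredictionResult."""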
   180    def process(self, prediction_result):
   181      yield prediction_result.inference
   182  
   183  
   184  class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
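           """Handler that rejects batches with fewer than 100 elements."""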
   185    def run_inference(self, batch, unused_model, inference_args=None):
   186      if len(batch) < 100:
   187        raise ValueError('Unexpectedly small batch')
   188      return batch
   189  
   190    def batch_elements_kwargs(self):
   191      return {'min_batch_size': 9999}
   192  
   193  
   194  class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
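           """Handler that raises if run_inference is ever reached."""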
   195    def run_inference(self, batch, unused_model, inference_args=None):
   196      raise ValueError(
    197          'run_inference should not be called because an error should '
    198          'already be thrown by the validate_inference_args check.')
   199  
   200  
   201  class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
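           """Handler that requires inference_args to be passed."""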
   202    def run_inference(self, batch, unused_model, inference_args=None):
   203      if not inference_args:
   204        raise ValueError('inference_args should exist')
   205      return batch
   206  
   207    def validate_inference_args(self, inference_args):
   208      pass
   209  
   210  
   211  class RunInferenceBaseTest(unittest.TestCase):
   212    def test_run_inference_impl_simple_examples(self):
   213      with TestPipeline() as pipeline:
   214        examples = [1, 5, 3, 10]
   215        expected = [example + 1 for example in examples]
   216        pcoll = pipeline | 'start' >> beam.Create(examples)
   217        actual = pcoll | base.RunInference(FakeModelHandler())
   218        assert_that(actual, equal_to(expected), label='assert:inferences')
   219  
   220    def test_run_inference_impl_simple_examples_multi_process_shared(self):
   221      with TestPipeline() as pipeline:
   222        examples = [1, 5, 3, 10]
   223        expected = [example + 1 for example in examples]
   224        pcoll = pipeline | 'start' >> beam.Create(examples)
   225        actual = pcoll | base.RunInference(
   226            FakeModelHandler(multi_process_shared=True))
   227        assert_that(actual, equal_to(expected), label='assert:inferences')
   228  
   229    def test_run_inference_impl_with_keyed_examples(self):
   230      with TestPipeline() as pipeline:
   231        examples = [1, 5, 3, 10]
   232        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   233        expected = [(i, example + 1) for i, example in enumerate(examples)]
   234        pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
   235        actual = pcoll | base.RunInference(
   236            base.KeyedModelHandler(FakeModelHandler()))
   237        assert_that(actual, equal_to(expected), label='assert:inferences')
   238  
   239    def test_run_inference_impl_with_maybe_keyed_examples(self):
   240      with TestPipeline() as pipeline:
   241        examples = [1, 5, 3, 10]
   242        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   243        expected = [example + 1 for example in examples]
   244        keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
   245        model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
   246  
   247        pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
   248        actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
   249        assert_that(actual, equal_to(expected), label='CheckUnkeyed')
   250  
   251        keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
   252        keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
   253            model_handler)
   254        assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
   255  
   256    def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
   257      with TestPipeline() as pipeline:
   258        examples = [1, 5, 3, 10]
   259        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   260        expected = [(i, example + 1) for i, example in enumerate(examples)]
   261        pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
   262        actual = pcoll | base.RunInference(
   263            base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
   264        assert_that(actual, equal_to(expected), label='assert:inferences')
   265  
   266    def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
   267        self):
   268      with TestPipeline() as pipeline:
   269        examples = [1, 5, 3, 10]
   270        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   271        expected = [example + 1 for example in examples]
   272        keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
   273        model_handler = base.MaybeKeyedModelHandler(
   274            FakeModelHandler(multi_process_shared=True))
   275  
   276        pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
   277        actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
   278        assert_that(actual, equal_to(expected), label='CheckUnkeyed')
   279  
   280        keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
   281        keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
   282            model_handler)
   283        assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
   284  
   285    def test_run_inference_preprocessing(self):
   286      def mult_two(example: str) -> int:
   287        return int(example) * 2
   288  
   289      with TestPipeline() as pipeline:
   290        examples = ["1", "5", "3", "10"]
   291        expected = [int(example) * 2 + 1 for example in examples]
   292        pcoll = pipeline | 'start' >> beam.Create(examples)
   293        actual = pcoll | base.RunInference(
   294            FakeModelHandler().with_preprocess_fn(mult_two))
   295        assert_that(actual, equal_to(expected), label='assert:inferences')
   296  
   297    def test_run_inference_preprocessing_multiple_fns(self):
   298      def add_one(example: str) -> int:
   299        return int(example) + 1
   300  
   301      def mult_two(example: int) -> int:
   302        return example * 2
   303  
   304      with TestPipeline() as pipeline:
   305        examples = ["1", "5", "3", "10"]
   306        expected = [(int(example) + 1) * 2 + 1 for example in examples]
   307        pcoll = pipeline | 'start' >> beam.Create(examples)
   308        actual = pcoll | base.RunInference(
   309            FakeModelHandler().with_preprocess_fn(mult_two).with_preprocess_fn(
   310                add_one))
   311        assert_that(actual, equal_to(expected), label='assert:inferences')
   312  
   313    def test_run_inference_postprocessing(self):
   314      def mult_two(example: int) -> str:
   315        return str(example * 2)
   316  
   317      with TestPipeline() as pipeline:
   318        examples = [1, 5, 3, 10]
   319        expected = [str((example + 1) * 2) for example in examples]
   320        pcoll = pipeline | 'start' >> beam.Create(examples)
   321        actual = pcoll | base.RunInference(
   322            FakeModelHandler().with_postprocess_fn(mult_two))
   323        assert_that(actual, equal_to(expected), label='assert:inferences')
   324  
   325    def test_run_inference_postprocessing_multiple_fns(self):
   326      def add_one(example: int) -> str:
   327        return str(int(example) + 1)
   328  
   329      def mult_two(example: int) -> int:
   330        return example * 2
   331  
   332      with TestPipeline() as pipeline:
   333        examples = [1, 5, 3, 10]
   334        expected = [str(((example + 1) * 2) + 1) for example in examples]
   335        pcoll = pipeline | 'start' >> beam.Create(examples)
   336        actual = pcoll | base.RunInference(
   337            FakeModelHandler().with_postprocess_fn(mult_two).with_postprocess_fn(
   338                add_one))
   339        assert_that(actual, equal_to(expected), label='assert:inferences')
   340  
   341    def test_run_inference_preprocessing_dlq(self):
   342      def mult_two(example: str) -> int:
   343        if example == "5":
   344          raise Exception("TEST")
   345        return int(example) * 2
   346  
   347      with TestPipeline() as pipeline:
   348        examples = ["1", "5", "3", "10"]
   349        expected = [3, 7, 21]
   350        expected_bad = ["5"]
   351        pcoll = pipeline | 'start' >> beam.Create(examples)
   352        main, other = pcoll | base.RunInference(
   353            FakeModelHandler().with_preprocess_fn(mult_two)
   354            ).with_exception_handling()
   355        assert_that(main, equal_to(expected), label='assert:inferences')
   356        assert_that(
   357            other.failed_inferences, equal_to([]), label='assert:bad_infer')
   358  
   359        # bad will be in form [element, error]. Just pull out bad element.
   360        bad_without_error = other.failed_preprocessing[0] | beam.Map(
   361            lambda x: x[0])
   362        assert_that(
   363            bad_without_error, equal_to(expected_bad), label='assert:failures')
   364  
   365    def test_run_inference_postprocessing_dlq(self):
   366      def mult_two(example: int) -> str:
   367        if example == 6:
   368          raise Exception("TEST")
   369        return str(example * 2)
   370  
   371      with TestPipeline() as pipeline:
   372        examples = [1, 5, 3, 10]
   373        expected = ["4", "8", "22"]
   374        expected_bad = [6]
   375        pcoll = pipeline | 'start' >> beam.Create(examples)
   376        main, other = pcoll | base.RunInference(
   377            FakeModelHandler().with_postprocess_fn(mult_two)
   378            ).with_exception_handling()
   379        assert_that(main, equal_to(expected), label='assert:inferences')
   380        assert_that(
   381            other.failed_inferences, equal_to([]), label='assert:bad_infer')
   382  
   383        # bad will be in form [element, error]. Just pull out bad element.
   384        bad_without_error = other.failed_postprocessing[0] | beam.Map(
   385            lambda x: x[0])
   386        assert_that(
   387            bad_without_error, equal_to(expected_bad), label='assert:failures')
   388  
   389    def test_run_inference_pre_and_post_processing_dlq(self):
   390      def mult_two_pre(example: str) -> int:
   391        if example == "5":
   392          raise Exception("TEST")
   393        return int(example) * 2
   394  
   395      def mult_two_post(example: int) -> str:
   396        if example == 7:
   397          raise Exception("TEST")
   398        return str(example * 2)
   399  
   400      with TestPipeline() as pipeline:
   401        examples = ["1", "5", "3", "10"]
   402        expected = ["6", "42"]
   403        expected_bad_pre = ["5"]
   404        expected_bad_post = [7]
   405        pcoll = pipeline | 'start' >> beam.Create(examples)
   406        main, other = pcoll | base.RunInference(
   407            FakeModelHandler().with_preprocess_fn(
   408              mult_two_pre
   409              ).with_postprocess_fn(
   410                mult_two_post
   411                )).with_exception_handling()
   412        assert_that(main, equal_to(expected), label='assert:inferences')
   413        assert_that(
   414            other.failed_inferences, equal_to([]), label='assert:bad_infer')
   415  
    416        # bad will be in form [element, error]. Just pull out bad element.
   417        bad_without_error_pre = other.failed_preprocessing[0] | beam.Map(
   418            lambda x: x[0])
   419        assert_that(
   420            bad_without_error_pre,
   421            equal_to(expected_bad_pre),
   422            label='assert:failures_pre')
   423  
    424        # bad will be in form [element, error]. Just pull out bad element.
   425        bad_without_error_post = other.failed_postprocessing[0] | beam.Map(
   426            lambda x: x[0])
   427        assert_that(
   428            bad_without_error_post,
   429            equal_to(expected_bad_post),
   430            label='assert:failures_post')
   431  
   432    def test_run_inference_keyed_pre_and_post_processing(self):
   433      def mult_two(element):
   434        return (element[0], element[1] * 2)
   435  
   436      with TestPipeline() as pipeline:
   437        examples = [1, 5, 3, 10]
   438        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   439        expected = [
   440            (i, ((example * 2) + 1) * 2) for i, example in enumerate(examples)
   441        ]
   442        pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
   443        actual = pcoll | base.RunInference(
   444            base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn(
   445                mult_two).with_postprocess_fn(mult_two))
   446        assert_that(actual, equal_to(expected), label='assert:inferences')
   447  
   448    def test_run_inference_maybe_keyed_pre_and_post_processing(self):
   449      def mult_two(element):
   450        return element * 2
   451  
   452      def mult_two_keyed(element):
   453        return (element[0], element[1] * 2)
   454  
   455      with TestPipeline() as pipeline:
   456        examples = [1, 5, 3, 10]
   457        keyed_examples = [(i, example) for i, example in enumerate(examples)]
   458        expected = [((2 * example) + 1) * 2 for example in examples]
   459        keyed_expected = [
   460            (i, ((2 * example) + 1) * 2) for i, example in enumerate(examples)
   461        ]
   462        model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
   463  
   464        pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
   465        actual = pcoll | 'RunUnkeyed' >> base.RunInference(
   466            model_handler.with_preprocess_fn(mult_two).with_postprocess_fn(
   467                mult_two))
   468        assert_that(actual, equal_to(expected), label='CheckUnkeyed')
   469  
   470        keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
   471        keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
   472            model_handler.with_preprocess_fn(mult_two_keyed).with_postprocess_fn(
   473                mult_two_keyed))
   474        assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
   475  
   476    def test_run_inference_impl_dlq(self):
   477      with TestPipeline() as pipeline:
   478        examples = [1, 'TEST', 3, 10, 'TEST2']
   479        expected_good = [2, 4, 11]
   480        expected_bad = ['TEST', 'TEST2']
   481        pcoll = pipeline | 'start' >> beam.Create(examples)
   482        main, other = pcoll | base.RunInference(
   483            FakeModelHandler(
   484              min_batch_size=1,
   485              max_batch_size=1
   486            )).with_exception_handling()
   487        assert_that(main, equal_to(expected_good), label='assert:inferences')
   488  
    489        # other.failed_inferences will be in form [batch[elements], error].
   490        # Just pull out bad element.
   491        bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0])
   492        assert_that(
   493            bad_without_error, equal_to(expected_bad), label='assert:failures')
   494  
   495    def test_run_inference_impl_inference_args(self):
   496      with TestPipeline() as pipeline:
   497        examples = [1, 5, 3, 10]
   498        pcoll = pipeline | 'start' >> beam.Create(examples)
   499        inference_args = {'key': True}
   500        actual = pcoll | base.RunInference(
   501            FakeModelHandlerExpectedInferenceArgs(),
   502            inference_args=inference_args)
   503        assert_that(actual, equal_to(examples), label='assert:inferences')
   504  
   505    def test_run_inference_metrics_with_custom_namespace(self):
   506      metrics_namespace = 'my_custom_namespace'
   507      pipeline = TestPipeline()
   508      examples = [1, 5, 3, 10]
   509      pcoll = pipeline | 'start' >> beam.Create(examples)
   510      _ = pcoll | base.RunInference(
   511          FakeModelHandler(), metrics_namespace=metrics_namespace)
   512      result = pipeline.run()
   513      result.wait_until_finish()
   514  
   515      metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
   516      metrics = result.metrics().query(metrics_filter)
   517      assert len(metrics['counters']) != 0
   518      assert len(metrics['distributions']) != 0
   519  
   520      metrics_filter = MetricsFilter().with_namespace(namespace='fake_namespace')
   521      metrics = result.metrics().query(metrics_filter)
   522      assert len(metrics['counters']) == len(metrics['distributions']) == 0
   523  
   524    def test_unexpected_inference_args_passed(self):
   525      with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
   526        with TestPipeline() as pipeline:
   527          examples = [1, 5, 3, 10]
   528          pcoll = pipeline | 'start' >> beam.Create(examples)
   529          inference_args = {'key': True}
   530          _ = pcoll | base.RunInference(
   531              FakeModelHandlerFailsOnInferenceArgs(),
   532              inference_args=inference_args)
   533  
   534    def test_increment_failed_batches_counter(self):
   535      with self.assertRaises(ValueError):
   536        with TestPipeline() as pipeline:
   537          examples = [7]
   538          pcoll = pipeline | 'start' >> beam.Create(examples)
   539          _ = pcoll | base.RunInference(FakeModelHandlerExpectedInferenceArgs())
   540          run_result = pipeline.run()
   541          run_result.wait_until_finish()
   542  
   543          metric_results = (
   544              run_result.metrics().query(
   545                  MetricsFilter().with_name('failed_batches_counter')))
   546          num_failed_batches_counter = metric_results['counters'][0]
   547          self.assertEqual(num_failed_batches_counter.committed, 3)
   548          # !!!: The above will need to be updated if retry behavior changes
   549  
   550    def test_failed_batches_counter_no_failures(self):
   551      pipeline = TestPipeline()
   552      examples = [7]
   553      pcoll = pipeline | 'start' >> beam.Create(examples)
   554      inference_args = {'key': True}
   555      _ = pcoll | base.RunInference(
   556          FakeModelHandlerExpectedInferenceArgs(), inference_args=inference_args)
   557      run_result = pipeline.run()
   558      run_result.wait_until_finish()
   559  
   560      metric_results = (
   561          run_result.metrics().query(
   562              MetricsFilter().with_name('failed_batches_counter')))
   563      self.assertEqual(len(metric_results['counters']), 0)
   564  
   565    def test_counted_metrics(self):
   566      pipeline = TestPipeline()
   567      examples = [1, 5, 3, 10]
   568      pcoll = pipeline | 'start' >> beam.Create(examples)
   569      _ = pcoll | base.RunInference(FakeModelHandler())
   570      run_result = pipeline.run()
   571      run_result.wait_until_finish()
   572  
   573      metric_results = (
   574          run_result.metrics().query(MetricsFilter().with_name('num_inferences')))
   575      num_inferences_counter = metric_results['counters'][0]
   576      self.assertEqual(num_inferences_counter.committed, 4)
   577  
   578      inference_request_batch_size = run_result.metrics().query(
   579          MetricsFilter().with_name('inference_request_batch_size'))
   580      self.assertTrue(inference_request_batch_size['distributions'])
   581      self.assertEqual(
   582          inference_request_batch_size['distributions'][0].result.sum, 4)
   583      inference_request_batch_byte_size = run_result.metrics().query(
   584          MetricsFilter().with_name('inference_request_batch_byte_size'))
   585      self.assertTrue(inference_request_batch_byte_size['distributions'])
   586      self.assertGreaterEqual(
   587          inference_request_batch_byte_size['distributions'][0].result.sum,
   588          len(pickle.dumps(examples)))
   589      inference_request_batch_byte_size = run_result.metrics().query(
   590          MetricsFilter().with_name('model_byte_size'))
   591      self.assertTrue(inference_request_batch_byte_size['distributions'])
   592  
   593    def test_timing_metrics(self):
   594      pipeline = TestPipeline()
   595      examples = [1, 5, 3, 10]
   596      pcoll = pipeline | 'start' >> beam.Create(examples)
   597      fake_clock = FakeClock()
   598      _ = pcoll | base.RunInference(
   599          FakeModelHandler(clock=fake_clock), clock=fake_clock)
   600      res = pipeline.run()
   601      res.wait_until_finish()
   602  
   603      metric_results = (
   604          res.metrics().query(
   605              MetricsFilter().with_name('inference_batch_latency_micro_secs')))
   606      batch_latency = metric_results['distributions'][0]
   607      self.assertEqual(batch_latency.result.count, 3)
   608      self.assertEqual(batch_latency.result.mean, 3000)
   609  
   610      metric_results = (
   611          res.metrics().query(
   612              MetricsFilter().with_name('load_model_latency_milli_secs')))
   613      load_model_latency = metric_results['distributions'][0]
   614      self.assertEqual(load_model_latency.result.count, 1)
   615      self.assertEqual(load_model_latency.result.mean, 500)
   616  
   617    def test_forwards_batch_args(self):
   618      examples = list(range(100))
   619      with TestPipeline() as pipeline:
   620        pcoll = pipeline | 'start' >> beam.Create(examples)
   621        actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch())
   622        assert_that(actual, equal_to(examples), label='assert:inferences')
   623  
   624    def test_run_inference_unkeyed_examples_with_keyed_model_handler(self):
   625      pipeline = TestPipeline()
   626      with self.assertRaises(TypeError):
   627        examples = [1, 3, 5]
   628        model_handler = base.KeyedModelHandler(FakeModelHandler())
   629        _ = (
   630            pipeline | 'Unkeyed' >> beam.Create(examples)
   631            | 'RunUnkeyed' >> base.RunInference(model_handler))
   632        pipeline.run()
   633  
   634    def test_run_inference_keyed_examples_with_unkeyed_model_handler(self):
   635      pipeline = TestPipeline()
   636      examples = [1, 3, 5]
   637      keyed_examples = [(i, example) for i, example in enumerate(examples)]
   638      model_handler = FakeModelHandler()
   639      with self.assertRaises(TypeError):
   640        _ = (
   641            pipeline | 'keyed' >> beam.Create(keyed_examples)
   642            | 'RunKeyed' >> base.RunInference(model_handler))
   643        pipeline.run()
   644  
   645    def test_model_handler_compatibility(self):
   646      # ** IMPORTANT ** Do not change this test to make your PR pass without
   647      # first reading below.
   648      # Be certain that the modification will not break third party
   649      # implementations of ModelHandler.
   650      # See issue https://github.com/apache/beam/issues/23484
    651      # If this test fails, third party implementations of ModelHandler
    652      # will likely break.
   653      class ThirdPartyHandler(base.ModelHandler[int, int, FakeModel]):
   654        def __init__(self, custom_parameter=None):
   655          pass
   656  
   657        def load_model(self) -> FakeModel:
   658          return FakeModel()
   659  
   660        def run_inference(
   661            self,
   662            batch: Sequence[int],
   663            model: FakeModel,
   664            inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
   665          yield 0
   666  
   667        def get_num_bytes(self, batch: Sequence[int]) -> int:
   668          return 1
   669  
   670        def get_metrics_namespace(self) -> str:
   671          return 'ThirdParty'
   672  
   673        def get_resource_hints(self) -> dict:
   674          return {}
   675  
   676        def batch_elements_kwargs(self) -> Mapping[str, Any]:
   677          return {}
   678  
   679        def validate_inference_args(
   680            self, inference_args: Optional[Dict[str, Any]]):
   681          pass
   682  
   683      # This test passes if calling these methods does not cause
   684      # any runtime exceptions.
   685      third_party_model_handler = ThirdPartyHandler(custom_parameter=0)
   686      fake_model = third_party_model_handler.load_model()
   687      third_party_model_handler.run_inference([], fake_model)
   688      fake_inference_args = {'some_arg': 1}
   689      third_party_model_handler.run_inference([],
   690                                              fake_model,
   691                                              inference_args=fake_inference_args)
   692      third_party_model_handler.get_num_bytes([1, 2, 3])
   693      third_party_model_handler.get_metrics_namespace()
   694      third_party_model_handler.get_resource_hints()
   695      third_party_model_handler.batch_elements_kwargs()
   696      third_party_model_handler.validate_inference_args({})
   697  
   698    def test_run_inference_prediction_result_with_model_id(self):
   699      examples = [1, 5, 3, 10]
   700      expected = [
   701          base.PredictionResult(
   702              example=example,
   703              inference=example + 1,
   704              model_id='fake_model_id_default') for example in examples
   705      ]
   706      with TestPipeline() as pipeline:
   707        pcoll = pipeline | 'start' >> beam.Create(examples)
   708        actual = pcoll | base.RunInference(
   709            FakeModelHandlerReturnsPredictionResult())
   710        assert_that(actual, equal_to(expected), label='assert:inferences')
   711  
   712    def test_run_inference_with_iterable_side_input(self):
   713      test_pipeline = TestPipeline()
   714      side_input = (
   715          test_pipeline | "CreateDummySideInput" >> beam.Create(
   716              [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
   717          | "ApplySideInputWindow" >> beam.WindowInto(
   718              window.GlobalWindows(),
   719              trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
   720              accumulation_mode=trigger.AccumulationMode.DISCARDING))
   721  
   722      test_pipeline.options.view_as(StandardOptions).streaming = True
   723      with self.assertRaises(ValueError) as e:
   724        _ = (
   725            test_pipeline
   726            | beam.Create([1, 2, 3, 4])
   727            | base.RunInference(
   728                FakeModelHandler(), model_metadata_pcoll=side_input))
   729        test_pipeline.run()
   730  
   731      self.assertTrue(
   732          'PCollection of size 2 with more than one element accessed as a '
   733          'singleton view. First two elements encountered are' in str(
   734              e.exception))
   735  
   736    def test_run_inference_with_iterable_side_input_multi_process_shared(self):
   737      test_pipeline = TestPipeline()
   738      side_input = (
   739          test_pipeline | "CreateDummySideInput" >> beam.Create(
   740              [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
   741          | "ApplySideInputWindow" >> beam.WindowInto(
   742              window.GlobalWindows(),
   743              trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
   744              accumulation_mode=trigger.AccumulationMode.DISCARDING))
   745  
   746      test_pipeline.options.view_as(StandardOptions).streaming = True
   747      with self.assertRaises(ValueError) as e:
   748        _ = (
   749            test_pipeline
   750            | beam.Create([1, 2, 3, 4])
   751            | base.RunInference(
   752                FakeModelHandler(multi_process_shared=True),
   753                model_metadata_pcoll=side_input))
   754        test_pipeline.run()
   755  
   756      self.assertTrue(
   757          'PCollection of size 2 with more than one element accessed as a '
   758          'singleton view. First two elements encountered are' in str(
   759              e.exception))
   760  
   761    def test_run_inference_empty_side_input(self):
   762      model_handler = FakeModelHandlerReturnsPredictionResult()
   763      main_input_elements = [1, 2]
   764      with TestPipeline(is_integration_test=False) as pipeline:
   765        side_pcoll = pipeline | "side" >> beam.Create([])
   766        result_pcoll = (
   767            pipeline
   768            | beam.Create(main_input_elements)
   769            | base.RunInference(model_handler, model_metadata_pcoll=side_pcoll))
   770        expected = [
   771            base.PredictionResult(ele, ele + 1, 'fake_model_id_default')
   772            for ele in main_input_elements
   773        ]
   774  
   775        assert_that(result_pcoll, equal_to(expected))
   776  
   777    def test_run_inference_side_input_in_batch(self):
   778      first_ts = math.floor(time.time()) - 30
   779      interval = 7
   780  
   781      sample_main_input_elements = ([
   782          first_ts - 2,
   783          first_ts + 1,
   784          first_ts + 8,
   785          first_ts + 15,
   786          first_ts + 22,
   787      ])
   788  
   789      sample_side_input_elements = [
   790          (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
   791          # if model_id is empty string, we use the default model
   792          # handler model URI.
   793          (
   794              first_ts + 8,
   795              base.ModelMetadata(
   796                  model_id='fake_model_id_1', model_name='fake_model_id_1')),
   797          (
   798              first_ts + 15,
   799              base.ModelMetadata(
   800                  model_id='fake_model_id_2', model_name='fake_model_id_2'))
   801      ]
   802  
   803      model_handler = FakeModelHandlerReturnsPredictionResult()
   804  
    805      # Apply GroupByKey so that the assigned windows take effect; see
    806      # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
   807      class _EmitElement(beam.DoFn):
   808        def process(self, element):
   809          for e in element:
   810            yield e
   811  
   812      with TestPipeline() as pipeline:
   813        side_input = (
   814            pipeline
   815            |
   816            "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
   817            | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
   818            | beam.WindowInto(
   819                window.FixedWindows(interval),
   820                accumulation_mode=trigger.AccumulationMode.DISCARDING)
   821            | beam.Map(lambda x: ('key', x))
   822            | beam.GroupByKey()
   823            | beam.Map(lambda x: x[1])
   824            | "EmitSideInput" >> beam.ParDo(_EmitElement()))
   825  
   826        result_pcoll = (
   827            pipeline
   828            | beam.Create(sample_main_input_elements)
   829            | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
   830            | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
   831            | beam.Map(lambda x: ('key', x))
   832            | "MainInputGBK" >> beam.GroupByKey()
   833            | beam.Map(lambda x: x[1])
   834            | beam.ParDo(_EmitElement())
   835            | "RunInference" >> base.RunInference(
   836                model_handler, model_metadata_pcoll=side_input))
   837  
   838        expected_model_id_order = [
   839            'fake_model_id_default',
   840            'fake_model_id_default',
   841            'fake_model_id_1',
   842            'fake_model_id_2',
   843            'fake_model_id_2'
   844        ]
   845        expected_result = [
   846            base.PredictionResult(
   847                example=sample_main_input_elements[i],
   848                inference=sample_main_input_elements[i] + 1,
   849                model_id=expected_model_id_order[i]) for i in range(5)
   850        ]
   851  
   852        assert_that(result_pcoll, equal_to(expected_result))
   853  
   854    def test_run_inference_side_input_in_batch_multi_process_shared(self):
   855      first_ts = math.floor(time.time()) - 30
   856      interval = 7
   857  
   858      sample_main_input_elements = ([
   859          first_ts - 2,
   860          first_ts + 1,
   861          first_ts + 8,
   862          first_ts + 15,
   863          first_ts + 22,
   864      ])
   865  
   866      sample_side_input_elements = [
   867          (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
   868          # if model_id is empty string, we use the default model
   869          # handler model URI.
   870          (
   871              first_ts + 8,
   872              base.ModelMetadata(
   873                  model_id='fake_model_id_1', model_name='fake_model_id_1')),
   874          (
   875              first_ts + 15,
   876              base.ModelMetadata(
   877                  model_id='fake_model_id_2', model_name='fake_model_id_2'))
   878      ]
   879  
   880      model_handler = FakeModelHandlerReturnsPredictionResult(
   881          multi_process_shared=True)
   882  
    883      # Apply GroupByKey so that the assigned windows take effect; see
    884      # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
   885      class _EmitElement(beam.DoFn):
   886        def process(self, element):
   887          for e in element:
   888            yield e
   889  
   890      with TestPipeline() as pipeline:
   891        side_input = (
   892            pipeline
   893            |
   894            "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
   895            | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
   896            | beam.WindowInto(
   897                window.FixedWindows(interval),
   898                accumulation_mode=trigger.AccumulationMode.DISCARDING)
   899            | beam.Map(lambda x: ('key', x))
   900            | beam.GroupByKey()
   901            | beam.Map(lambda x: x[1])
   902            | "EmitSideInput" >> beam.ParDo(_EmitElement()))
   903  
   904        result_pcoll = (
   905            pipeline
   906            | beam.Create(sample_main_input_elements)
   907            | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
   908            | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
   909            | beam.Map(lambda x: ('key', x))
   910            | "MainInputGBK" >> beam.GroupByKey()
   911            | beam.Map(lambda x: x[1])
   912            | beam.ParDo(_EmitElement())
   913            | "RunInference" >> base.RunInference(
   914                model_handler, model_metadata_pcoll=side_input))
   915  
   916        expected_model_id_order = [
   917            'fake_model_id_default',
   918            'fake_model_id_default',
   919            'fake_model_id_1',
   920            'fake_model_id_2',
   921            'fake_model_id_2'
   922        ]
   923        expected_result = [
   924            base.PredictionResult(
   925                example=sample_main_input_elements[i],
   926                inference=sample_main_input_elements[i] + 1,
   927                model_id=expected_model_id_order[i]) for i in range(5)
   928        ]
   929  
   930        assert_that(result_pcoll, equal_to(expected_result))
   931  
   932    @unittest.skipIf(
   933        not TestPipeline().get_pipeline_options().view_as(
   934            StandardOptions).streaming,
   935        "SideInputs to RunInference are only supported in streaming mode.")
   936    @pytest.mark.it_postcommit
   937    @pytest.mark.sickbay_direct
   938    @pytest.mark.it_validatesrunner
    939    def test_run_inference_with_side_input_in_streaming(self):
   940      test_pipeline = TestPipeline(is_integration_test=True)
   941      test_pipeline.options.view_as(StandardOptions).streaming = True
   942      run_inference_side_inputs.run(
   943          test_pipeline.get_full_options_as_args(), save_main_session=False)
   944  
   945    def test_env_vars_set_correctly(self):
   946      handler_with_vars = FakeModelHandler(env_vars={'FOO': 'bar'})
   947      os.environ.pop('FOO', None)
    948      self.assertNotIn('FOO', os.environ)
   949      with TestPipeline() as pipeline:
   950        examples = [1, 2, 3]
   951        _ = (
   952            pipeline
   953            | 'start' >> beam.Create(examples)
   954            | base.RunInference(handler_with_vars))
   955        pipeline.run()
    956        self.assertIn('FOO', os.environ)
    957        self.assertEqual(os.environ['FOO'], 'bar')
   958  
   959    def test_child_class_without_env_vars(self):
   960      with TestPipeline() as pipeline:
   961        examples = [1, 5, 3, 10]
   962        expected = [example + 1 for example in examples]
   963        pcoll = pipeline | 'start' >> beam.Create(examples)
   964        actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
   965        assert_that(actual, equal_to(expected), label='assert:inferences')
   966  
   967  
   968  if __name__ == '__main__':
   969    unittest.main()