github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/pytorch_inference_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import os
    21  import shutil
    22  import tempfile
    23  import unittest
    24  from collections import OrderedDict
    25  
    26  import numpy as np
    27  import pytest
    28  
    29  import apache_beam as beam
    30  from apache_beam.testing.test_pipeline import TestPipeline
    31  from apache_beam.testing.util import assert_that
    32  from apache_beam.testing.util import equal_to
    33  
    34  # Protect against environments where pytorch library is not available.
    35  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    36  try:
    37    import torch
    38    from apache_beam.ml.inference.base import PredictionResult
    39    from apache_beam.ml.inference.base import RunInference
    40    from apache_beam.ml.inference import pytorch_inference
    41    from apache_beam.ml.inference.pytorch_inference import default_keyed_tensor_inference_fn
    42    from apache_beam.ml.inference.pytorch_inference import default_tensor_inference_fn
    43    from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn
    44    from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn
    45    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
    46    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
    47  except ImportError:
    48    raise unittest.SkipTest('PyTorch dependencies are not installed')
    49  
    50  try:
    51    from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
    52  except ImportError:
    53    GCSFileSystem = None  # type: ignore
    54  
    55  TWO_FEATURES_EXAMPLES = [
    56      torch.from_numpy(np.array([1, 5], dtype="float32")),
    57      torch.from_numpy(np.array([3, 10], dtype="float32")),
    58      torch.from_numpy(np.array([-14, 0], dtype="float32")),
    59      torch.from_numpy(np.array([0.5, 0.5], dtype="float32")),
    60  ]
    61  
    62  TWO_FEATURES_PREDICTIONS = [
    63      PredictionResult(ex, pred) for ex,
    64      pred in zip(
    65          TWO_FEATURES_EXAMPLES,
    66          torch.Tensor(
    67              [f1 * 2.0 + f2 * 3 + 0.5
    68               for f1, f2 in TWO_FEATURES_EXAMPLES]).reshape(-1, 1))
    69  ]
    70  
    71  TWO_FEATURES_DICT_OUT_PREDICTIONS = [
    72      PredictionResult(
    73          p.example, {
    74              "output1": p.inference, "output2": p.inference
    75          }) for p in TWO_FEATURES_PREDICTIONS
    76  ]
    77  
    78  KEYED_TORCH_EXAMPLES = [
    79      {
    80          'k1': torch.from_numpy(np.array([1], dtype="float32")),
    81          'k2': torch.from_numpy(np.array([1.5], dtype="float32"))
    82      },
    83      {
    84          'k1': torch.from_numpy(np.array([5], dtype="float32")),
    85          'k2': torch.from_numpy(np.array([5.5], dtype="float32"))
    86      },
    87      {
    88          'k1': torch.from_numpy(np.array([-3], dtype="float32")),
    89          'k2': torch.from_numpy(np.array([-3.5], dtype="float32"))
    90      },
    91      {
    92          'k1': torch.from_numpy(np.array([10.0], dtype="float32")),
    93          'k2': torch.from_numpy(np.array([10.5], dtype="float32"))
    94      },
    95  ]
    96  
    97  KEYED_TORCH_PREDICTIONS = [
    98      PredictionResult(ex, pred) for ex,
    99      pred in zip(
   100          KEYED_TORCH_EXAMPLES,
   101          torch.Tensor([(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5)
   102                        for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1))
   103  ]
   104  
   105  KEYED_TORCH_HELPER_PREDICTIONS = [
   106      PredictionResult(ex, pred) for ex,
   107      pred in zip(
   108          KEYED_TORCH_EXAMPLES,
   109          torch.Tensor([(example['k1'] * 2.0 + 0.5) +
   110                        (example['k2'] * 2.0 + 0.5) + 0.5
   111                        for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1))
   112  ]
   113  
   114  KEYED_TORCH_DICT_OUT_PREDICTIONS = [
   115      PredictionResult(
   116          p.example, {
   117              "output1": p.inference, "output2": p.inference
   118          }) for p in KEYED_TORCH_PREDICTIONS
   119  ]
   120  
   121  
   122  class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
   123    def __init__(self, device, *, inference_fn=default_tensor_inference_fn):
   124      self._device = device
   125      self._inference_fn = inference_fn
   126      self._state_dict_path = None
   127      self._torch_script_model_path = None
   128  
   129  
   130  class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
   131      PytorchModelHandlerKeyedTensor):
   132    def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn):
   133      self._device = device
   134      self._inference_fn = inference_fn
   135      self._state_dict_path = None
   136      self._torch_script_model_path = None
   137  
   138  
   139  def _compare_prediction_result(x, y):
   140    if isinstance(x.example, dict):
   141      example_equals = all(
   142          torch.equal(x, y) for x,
   143          y in zip(x.example.values(), y.example.values()))
   144    else:
   145      example_equals = torch.equal(x.example, y.example)
   146    if not example_equals:
   147      return False
   148  
   149    if isinstance(x.inference, dict):
   150      return all(
   151          torch.equal(x, y) for x,
   152          y in zip(x.inference.values(), y.inference.values()))
   153  
   154    return torch.equal(x.inference, y.inference)
   155  
   156  
   157  def custom_tensor_inference_fn(
   158      batch, model, device, inference_args, model_id=None):
   159    predictions = [
   160        PredictionResult(ex, pred) for ex,
   161        pred in zip(
   162            batch,
   163            torch.Tensor([item * 2.0 + 1.5 for item in batch]).reshape(-1, 1))
   164    ]
   165    return predictions
   166  
   167  
   168  class PytorchLinearRegression(torch.nn.Module):
   169    def __init__(self, input_dim, output_dim):
   170      super().__init__()
   171      self.linear = torch.nn.Linear(input_dim, output_dim)
   172  
   173    def forward(self, x):
   174      out = self.linear(x)
   175      return out
   176  
   177    def generate(self, x):
   178      out = self.linear(x) + 0.5
   179      return out
   180  
   181  
   182  class PytorchLinearRegressionDict(torch.nn.Module):
   183    def __init__(self, input_dim, output_dim):
   184      super().__init__()
   185      self.linear = torch.nn.Linear(input_dim, output_dim)
   186  
   187    def forward(self, x):
   188      out = self.linear(x)
   189      return {'output1': out, 'output2': out}
   190  
   191  
   192  class PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(torch.nn.Module):
   193    """
   194    A linear model with batched keyed inputs and non-batchable extra args.
   195  
   196    Note: k1 and k2 are batchable examples passed in as a dict from str to tensor.
   197    prediction_param_array, prediction_param_bool are non-batchable extra args
   198    (typically model-related info) used to configure the model before its predict
   199    call is invoked
   200    """
   201    def __init__(self, input_dim, output_dim):
   202      super().__init__()
   203      self.linear = torch.nn.Linear(input_dim, output_dim)
   204  
   205    def forward(self, k1, k2, prediction_param_array, prediction_param_bool):
   206      if not prediction_param_bool:
   207        raise ValueError("Expected prediction_param_bool to be True")
   208      if not torch.all(prediction_param_array):
   209        raise ValueError("Expected prediction_param_array to be all True")
   210      out = self.linear(k1) + self.linear(k2)
   211      return out
   212  
   213  
   214  @pytest.mark.uses_pytorch
   215  class PytorchRunInferenceTest(unittest.TestCase):
   216    def test_run_inference_single_tensor_feature(self):
   217      examples = [
   218          torch.from_numpy(np.array([1], dtype="float32")),
   219          torch.from_numpy(np.array([5], dtype="float32")),
   220          torch.from_numpy(np.array([-3], dtype="float32")),
   221          torch.from_numpy(np.array([10.0], dtype="float32")),
   222      ]
   223      expected_predictions = [
   224          PredictionResult(ex, pred) for ex,
   225          pred in zip(
   226              examples,
   227              torch.Tensor([example * 2.0 + 0.5
   228                            for example in examples]).reshape(-1, 1))
   229      ]
   230  
   231      model = PytorchLinearRegression(input_dim=1, output_dim=1)
   232      model.load_state_dict(
   233          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   234                       ('linear.bias', torch.Tensor([0.5]))]))
   235      model.eval()
   236  
   237      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   238          torch.device('cpu'))
   239      predictions = inference_runner.run_inference(examples, model)
   240      for actual, expected in zip(predictions, expected_predictions):
   241        self.assertEqual(actual, expected)
   242  
   243    def test_run_inference_multiple_tensor_features(self):
   244      model = PytorchLinearRegression(input_dim=2, output_dim=1)
   245      model.load_state_dict(
   246          OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   247                       ('linear.bias', torch.Tensor([0.5]))]))
   248      model.eval()
   249  
   250      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   251          torch.device('cpu'))
   252      predictions = inference_runner.run_inference(TWO_FEATURES_EXAMPLES, model)
   253      for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS):
   254        self.assertEqual(actual, expected)
   255  
   256    def test_run_inference_multiple_tensor_features_dict_output(self):
   257      model = PytorchLinearRegressionDict(input_dim=2, output_dim=1)
   258      model.load_state_dict(
   259          OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   260                       ('linear.bias', torch.Tensor([0.5]))]))
   261      model.eval()
   262  
   263      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   264          torch.device('cpu'))
   265      predictions = inference_runner.run_inference(TWO_FEATURES_EXAMPLES, model)
   266      for actual, expected in zip(predictions, TWO_FEATURES_DICT_OUT_PREDICTIONS):
   267        self.assertEqual(actual, expected)
   268  
   269    def test_run_inference_custom(self):
   270      examples = [
   271          torch.from_numpy(np.array([1], dtype="float32")),
   272          torch.from_numpy(np.array([5], dtype="float32")),
   273          torch.from_numpy(np.array([-3], dtype="float32")),
   274          torch.from_numpy(np.array([10.0], dtype="float32")),
   275      ]
   276      expected_predictions = [
   277          PredictionResult(ex, pred) for ex,
   278          pred in zip(
   279              examples,
   280              torch.Tensor([example * 2.0 + 1.5
   281                            for example in examples]).reshape(-1, 1))
   282      ]
   283  
   284      model = PytorchLinearRegression(input_dim=1, output_dim=1)
   285      model.load_state_dict(
   286          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   287                       ('linear.bias', torch.Tensor([0.5]))]))
   288      model.eval()
   289  
   290      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   291          torch.device('cpu'), inference_fn=custom_tensor_inference_fn)
   292      predictions = inference_runner.run_inference(examples, model)
   293      for actual, expected in zip(predictions, expected_predictions):
   294        self.assertEqual(actual, expected)
   295  
   296    def test_run_inference_keyed(self):
   297      """
   298      This tests for inputs that are passed as a dictionary from key to tensor
   299      instead of a standard non-keyed tensor example.
   300  
   301      Example:
   302      Typical input format is
   303      input = torch.tensor([1, 2, 3])
   304  
   305      But Pytorch syntax allows inputs to have the form
   306      input = {
   307        'k1' : torch.tensor([1, 2, 3]),
   308        'k2' : torch.tensor([4, 5, 6])
   309      }
   310      """
   311      class PytorchLinearRegressionMultipleArgs(torch.nn.Module):
   312        def __init__(self, input_dim, output_dim):
   313          super().__init__()
   314          self.linear = torch.nn.Linear(input_dim, output_dim)
   315  
   316        def forward(self, k1, k2):
   317          out = self.linear(k1) + self.linear(k2)
   318          return out
   319  
   320      model = PytorchLinearRegressionMultipleArgs(input_dim=1, output_dim=1)
   321      model.load_state_dict(
   322          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   323                       ('linear.bias', torch.Tensor([0.5]))]))
   324      model.eval()
   325  
   326      inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
   327          torch.device('cpu'))
   328      predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model)
   329      for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS):
   330        self.assertTrue(_compare_prediction_result(actual, expected))
   331  
   332    def test_run_inference_keyed_dict_output(self):
   333      class PytorchLinearRegressionMultipleArgsDict(torch.nn.Module):
   334        def __init__(self, input_dim, output_dim):
   335          super().__init__()
   336          self.linear = torch.nn.Linear(input_dim, output_dim)
   337  
   338        def forward(self, k1, k2):
   339          out = self.linear(k1) + self.linear(k2)
   340          return {'output1': out, 'output2': out}
   341  
   342      model = PytorchLinearRegressionMultipleArgsDict(input_dim=1, output_dim=1)
   343      model.load_state_dict(
   344          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   345                       ('linear.bias', torch.Tensor([0.5]))]))
   346      model.eval()
   347  
   348      inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
   349          torch.device('cpu'))
   350      predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model)
   351      for actual, expected in zip(predictions, KEYED_TORCH_DICT_OUT_PREDICTIONS):
   352        self.assertTrue(_compare_prediction_result(actual, expected))
   353  
   354    def test_inference_runner_inference_args(self):
   355      """
   356      This tests for non-batchable input arguments. Since we do the batching
   357      for the user, we have to distinguish between the inputs that should be
   358      batched and the ones that should not be batched.
   359      """
   360      inference_args = {
   361          'prediction_param_array': torch.from_numpy(
   362              np.array([1, 2], dtype="float32")),
   363          'prediction_param_bool': True
   364      }
   365  
   366      model = PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(
   367          input_dim=1, output_dim=1)
   368      model.load_state_dict(
   369          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   370                       ('linear.bias', torch.Tensor([0.5]))]))
   371      model.eval()
   372  
   373      inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
   374          torch.device('cpu'))
   375      predictions = inference_runner.run_inference(
   376          batch=KEYED_TORCH_EXAMPLES, model=model, inference_args=inference_args)
   377      for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS):
   378        self.assertEqual(actual, expected)
   379  
   380    def test_run_inference_helper(self):
   381      examples = [
   382          torch.from_numpy(np.array([1], dtype="float32")),
   383          torch.from_numpy(np.array([5], dtype="float32")),
   384          torch.from_numpy(np.array([-3], dtype="float32")),
   385          torch.from_numpy(np.array([10.0], dtype="float32")),
   386      ]
   387      expected_predictions = [
   388          PredictionResult(ex, pred) for ex,
   389          pred in zip(
   390              examples,
   391              torch.Tensor([example * 2.0 + 1.0
   392                            for example in examples]).reshape(-1, 1))
   393      ]
   394  
   395      gen_fn = make_tensor_model_fn('generate')
   396  
   397      model = PytorchLinearRegression(input_dim=1, output_dim=1)
   398      model.load_state_dict(
   399          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   400                       ('linear.bias', torch.Tensor([0.5]))]))
   401      model.eval()
   402  
   403      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   404          torch.device('cpu'), inference_fn=gen_fn)
   405      predictions = inference_runner.run_inference(examples, model)
   406      for actual, expected in zip(predictions, expected_predictions):
   407        self.assertEqual(actual, expected)
   408  
   409    def test_run_inference_keyed_helper(self):
   410      """
   411      This tests for inputs that are passed as a dictionary from key to tensor
   412      instead of a standard non-keyed tensor example.
   413  
   414      Example:
   415      Typical input format is
   416      input = torch.tensor([1, 2, 3])
   417  
   418      But Pytorch syntax allows inputs to have the form
   419      input = {
   420        'k1' : torch.tensor([1, 2, 3]),
   421        'k2' : torch.tensor([4, 5, 6])
   422      }
   423      """
   424      class PytorchLinearRegressionMultipleArgs(torch.nn.Module):
   425        def __init__(self, input_dim, output_dim):
   426          super().__init__()
   427          self.linear = torch.nn.Linear(input_dim, output_dim)
   428  
   429        def forward(self, k1, k2):
   430          out = self.linear(k1) + self.linear(k2)
   431          return out
   432  
   433        def generate(self, k1, k2):
   434          out = self.linear(k1) + self.linear(k2) + 0.5
   435          return out
   436  
   437      model = PytorchLinearRegressionMultipleArgs(input_dim=1, output_dim=1)
   438      model.load_state_dict(
   439          OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   440                       ('linear.bias', torch.Tensor([0.5]))]))
   441      model.eval()
   442  
   443      gen_fn = make_keyed_tensor_model_fn('generate')
   444  
   445      inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
   446          torch.device('cpu'), inference_fn=gen_fn)
   447      predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model)
   448      for actual, expected in zip(predictions, KEYED_TORCH_HELPER_PREDICTIONS):
   449        self.assertTrue(_compare_prediction_result(actual, expected))
   450  
   451    def test_num_bytes(self):
   452      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   453          torch.device('cpu'))
   454      examples = torch.from_numpy(
   455          np.array([1, 5, 3, 10, -14, 0, 0.5, 0.5],
   456                   dtype="float32")).reshape(-1, 2)
   457      self.assertEqual((examples[0].element_size()) * 8,
   458                       inference_runner.get_num_bytes(examples))
   459  
   460    def test_namespace(self):
   461      inference_runner = TestPytorchModelHandlerForInferenceOnly(
   462          torch.device('cpu'))
   463      self.assertEqual('BeamML_PyTorch', inference_runner.get_metrics_namespace())
   464  
   465  
   466  @pytest.mark.uses_pytorch
   467  class PytorchRunInferencePipelineTest(unittest.TestCase):
   468    def setUp(self):
   469      self.tmpdir = tempfile.mkdtemp()
   470  
   471    def tearDown(self):
   472      shutil.rmtree(self.tmpdir)
   473  
   474    def test_pipeline_local_model_simple(self):
   475      with TestPipeline() as pipeline:
   476        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   477                                  ('linear.bias', torch.Tensor([0.5]))])
   478        path = os.path.join(self.tmpdir, 'my_state_dict_path')
   479        torch.save(state_dict, path)
   480  
   481        model_handler = PytorchModelHandlerTensor(
   482            state_dict_path=path,
   483            model_class=PytorchLinearRegression,
   484            model_params={
   485                'input_dim': 2, 'output_dim': 1
   486            })
   487  
   488        pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
   489        predictions = pcoll | RunInference(model_handler)
   490        assert_that(
   491            predictions,
   492            equal_to(
   493                TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))
   494  
   495    def test_pipeline_local_model_extra_inference_args(self):
   496      with TestPipeline() as pipeline:
   497        inference_args = {
   498            'prediction_param_array': torch.from_numpy(
   499                np.array([1, 2], dtype="float32")),
   500            'prediction_param_bool': True
   501        }
   502  
   503        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   504                                  ('linear.bias', torch.Tensor([0.5]))])
   505        path = os.path.join(self.tmpdir, 'my_state_dict_path')
   506        torch.save(state_dict, path)
   507  
   508        model_handler = PytorchModelHandlerKeyedTensor(
   509            state_dict_path=path,
   510            model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs,
   511            model_params={
   512                'input_dim': 1, 'output_dim': 1
   513            })
   514  
   515        pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES)
   516        inference_args_side_input = (
   517            pipeline | 'create side' >> beam.Create(inference_args))
   518        predictions = pcoll | RunInference(
   519            model_handler=model_handler,
   520            inference_args=beam.pvalue.AsDict(inference_args_side_input))
   521        assert_that(
   522            predictions,
   523            equal_to(
   524                KEYED_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result))
   525  
   526    def test_pipeline_local_model_extra_inference_args_batching_args(self):
   527      with TestPipeline() as pipeline:
   528        inference_args = {
   529            'prediction_param_array': torch.from_numpy(
   530                np.array([1, 2], dtype="float32")),
   531            'prediction_param_bool': True
   532        }
   533  
   534        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   535                                  ('linear.bias', torch.Tensor([0.5]))])
   536        path = os.path.join(self.tmpdir, 'my_state_dict_path')
   537        torch.save(state_dict, path)
   538  
   539        def batch_validator_keyed_tensor_inference_fn(
   540            batch,
   541            model,
   542            device,
   543            inference_args,
   544            model_id,
   545        ):
   546          if len(batch) != 2:
   547            raise Exception(
   548                f'Expected batch of size 2, received batch of size {len(batch)}')
   549          return default_keyed_tensor_inference_fn(
   550              batch, model, device, inference_args, model_id)
   551  
   552        model_handler = PytorchModelHandlerKeyedTensor(
   553            state_dict_path=path,
   554            model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs,
   555            model_params={
   556                'input_dim': 1, 'output_dim': 1
   557            },
   558            inference_fn=batch_validator_keyed_tensor_inference_fn,
   559            min_batch_size=2,
   560            max_batch_size=2)
   561  
   562        pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES)
   563        inference_args_side_input = (
   564            pipeline | 'create side' >> beam.Create(inference_args))
   565        predictions = pcoll | RunInference(
   566            model_handler=model_handler,
   567            inference_args=beam.pvalue.AsDict(inference_args_side_input))
   568        assert_that(
   569            predictions,
   570            equal_to(
   571                KEYED_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result))
   572  
   573    @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
   574    def test_pipeline_gcs_model(self):
   575      with TestPipeline() as pipeline:
   576        examples = torch.from_numpy(
   577            np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1))
   578        expected_predictions = [
   579            PredictionResult(ex, pred) for ex,
   580            pred in zip(
   581                examples,
   582                torch.Tensor([example * 2.0 + 0.5
   583                              for example in examples]).reshape(-1, 1))
   584        ]
   585  
   586        gs_pth = 'gs://apache-beam-ml/models/' \
   587            'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
   588        model_handler = PytorchModelHandlerTensor(
   589            state_dict_path=gs_pth,
   590            model_class=PytorchLinearRegression,
   591            model_params={
   592                'input_dim': 1, 'output_dim': 1
   593            })
   594  
   595        pcoll = pipeline | 'start' >> beam.Create(examples)
   596        predictions = pcoll | RunInference(model_handler)
   597        assert_that(
   598            predictions,
   599            equal_to(expected_predictions, equals_fn=_compare_prediction_result))
   600  
   601    @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
   602    def test_pipeline_gcs_model_control_batching(self):
   603      with TestPipeline() as pipeline:
   604        examples = torch.from_numpy(
   605            np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1))
   606        expected_predictions = [
   607            PredictionResult(ex, pred) for ex,
   608            pred in zip(
   609                examples,
   610                torch.Tensor([example * 2.0 + 0.5
   611                              for example in examples]).reshape(-1, 1))
   612        ]
   613  
   614        def batch_validator_tensor_inference_fn(
   615            batch,
   616            model,
   617            device,
   618            inference_args,
   619            model_id,
   620        ):
   621          if len(batch) != 2:
   622            raise Exception(
   623                f'Expected batch of size 2, received batch of size {len(batch)}')
   624          return default_tensor_inference_fn(
   625              batch, model, device, inference_args, model_id)
   626  
   627  
   628        gs_pth = 'gs://apache-beam-ml/models/' \
   629            'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
   630        model_handler = PytorchModelHandlerTensor(
   631            state_dict_path=gs_pth,
   632            model_class=PytorchLinearRegression,
   633            model_params={
   634                'input_dim': 1, 'output_dim': 1
   635            },
   636            inference_fn=batch_validator_tensor_inference_fn,
   637            min_batch_size=2,
   638            max_batch_size=2)
   639  
   640        pcoll = pipeline | 'start' >> beam.Create(examples)
   641        predictions = pcoll | RunInference(model_handler)
   642        assert_that(
   643            predictions,
   644            equal_to(expected_predictions, equals_fn=_compare_prediction_result))
   645  
   646    def test_invalid_input_type(self):
   647      with self.assertRaisesRegex(TypeError, "expected Tensor as element"):
   648        with TestPipeline() as pipeline:
   649          examples = np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1)
   650  
   651          state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   652                                    ('linear.bias', torch.Tensor([0.5]))])
   653          path = os.path.join(self.tmpdir, 'my_state_dict_path')
   654          torch.save(state_dict, path)
   655  
   656          model_handler = PytorchModelHandlerTensor(
   657              state_dict_path=path,
   658              model_class=PytorchLinearRegression,
   659              model_params={
   660                  'input_dim': 1, 'output_dim': 1
   661              })
   662  
   663          pcoll = pipeline | 'start' >> beam.Create(examples)
   664          # pylint: disable=expression-not-assigned
   665          pcoll | RunInference(model_handler)
   666  
   667    def test_gpu_auto_convert_to_cpu(self):
   668      """
   669      This tests the scenario in which the user defines `device='GPU'` for the
   670      PytorchModelHandlerX, but runs the pipeline on a machine without GPU, we
   671      automatically detect this discrepancy and do automatic conversion to CPU.
   672      A warning is also logged to inform the user.
   673      """
   674      with self.assertLogs() as log:
   675        with TestPipeline() as pipeline:
   676          examples = torch.from_numpy(
   677              np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1))
   678  
   679          state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   680                                    ('linear.bias', torch.Tensor([0.5]))])
   681          path = os.path.join(self.tmpdir, 'my_state_dict_path')
   682          torch.save(state_dict, path)
   683  
   684          model_handler = PytorchModelHandlerTensor(
   685              state_dict_path=path,
   686              model_class=PytorchLinearRegression,
   687              model_params={
   688                  'input_dim': 1, 'output_dim': 1
   689              },
   690              device='GPU')
   691          # Upon initialization, device is cuda
   692          self.assertEqual(model_handler._device, torch.device('cuda'))
   693  
   694          pcoll = pipeline | 'start' >> beam.Create(examples)
   695          # pylint: disable=expression-not-assigned
   696          pcoll | RunInference(model_handler)
   697  
   698          # During model loading, device converted to cuda
   699          self.assertEqual(model_handler._device, torch.device('cuda'))
   700  
   701        self.assertIn("INFO:root:Device is set to CUDA", log.output)
   702        self.assertIn(
   703            "WARNING:root:Model handler specified a 'GPU' device, but GPUs " \
   704            "are not available. Switching to CPU.",
   705            log.output)
   706  
   707    def test_load_torch_script_model(self):
   708      torch_model = PytorchLinearRegression(2, 1)
   709      torch_script_model = torch.jit.script(torch_model)
   710  
   711      torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
   712  
   713      torch.jit.save(torch_script_model, torch_script_path)
   714  
   715      model_handler = PytorchModelHandlerTensor(
   716          torch_script_model_path=torch_script_path)
   717  
   718      torch_script_model = model_handler.load_model()
   719  
   720      self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule))
   721  
   722    def test_inference_torch_script_model(self):
   723      torch_model = PytorchLinearRegression(2, 1)
   724      torch_model.load_state_dict(
   725          OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   726                       ('linear.bias', torch.Tensor([0.5]))]))
   727  
   728      torch_script_model = torch.jit.script(torch_model)
   729  
   730      torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
   731  
   732      torch.jit.save(torch_script_model, torch_script_path)
   733  
   734      model_handler = PytorchModelHandlerTensor(
   735          torch_script_model_path=torch_script_path)
   736  
   737      with TestPipeline() as pipeline:
   738        pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
   739        predictions = pcoll | RunInference(model_handler)
   740        assert_that(
   741            predictions,
   742            equal_to(
   743                TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))
   744  
   745    def test_torch_model_class_none(self):
   746      torch_model = PytorchLinearRegression(2, 1)
   747      torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
   748  
   749      torch.save(torch_model, torch_path)
   750  
   751      with self.assertRaisesRegex(
   752          RuntimeError,
   753          "A state_dict_path has been supplied to the model "
   754          "handler, but the required model_class is missing. "
   755          "Please provide the model_class in order to"):
   756        _ = PytorchModelHandlerTensor(state_dict_path=torch_path)
   757  
   758      with self.assertRaisesRegex(
   759          RuntimeError,
   760          "A state_dict_path has been supplied to the model "
   761          "handler, but the required model_class is missing. "
   762          "Please provide the model_class in order to"):
   763        _ = (PytorchModelHandlerKeyedTensor(state_dict_path=torch_path))
   764  
   765    def test_torch_model_state_dict_none(self):
   766      with self.assertRaisesRegex(
   767          RuntimeError,
   768          "A model_class has been supplied to the model "
   769          "handler, but the required state_dict_path is missing. "
   770          "Please provide the state_dict_path in order to"):
   771        _ = PytorchModelHandlerTensor(model_class=PytorchLinearRegression)
   772  
   773      with self.assertRaisesRegex(
   774          RuntimeError,
   775          "A model_class has been supplied to the model "
   776          "handler, but the required state_dict_path is missing. "
   777          "Please provide the state_dict_path in order to"):
   778        _ = PytorchModelHandlerKeyedTensor(model_class=PytorchLinearRegression)
   779  
   780    def test_specify_torch_script_path_and_state_dict_path(self):
   781      torch_model = PytorchLinearRegression(2, 1)
   782      torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
   783  
   784      torch.save(torch_model, torch_path)
   785      torch_script_model = torch.jit.script(torch_model)
   786  
   787      torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
   788  
   789      torch.jit.save(torch_script_model, torch_script_path)
   790      with self.assertRaisesRegex(
   791          RuntimeError, "Please specify either torch_script_model_path or "):
   792        _ = PytorchModelHandlerTensor(
   793            state_dict_path=torch_path,
   794            model_class=PytorchLinearRegression,
   795            torch_script_model_path=torch_script_path)
   796  
   797    def test_prediction_result_model_id_with_torch_script_model(self):
   798      torch_model = PytorchLinearRegression(2, 1)
   799      torch_script_model = torch.jit.script(torch_model)
   800      torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
   801      torch.jit.save(torch_script_model, torch_script_path)
   802  
   803      model_handler = PytorchModelHandlerTensor(
   804          torch_script_model_path=torch_script_path)
   805  
   806      def check_torch_script_model_id(element):
   807        assert ('torch_script_model.pt' in element.model_id) is True
   808  
   809      with TestPipeline() as pipeline:
   810        pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
   811        predictions = pcoll | RunInference(model_handler)
   812        _ = predictions | beam.Map(check_torch_script_model_id)
   813  
   814    def test_prediction_result_model_id_with_torch_model(self):
   815      # weights associated with PytorchLinearRegression class
   816      state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   817                                ('linear.bias', torch.Tensor([0.5]))])
   818      torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
   819      torch.save(state_dict, torch_path)
   820  
   821      model_handler = PytorchModelHandlerTensor(
   822          state_dict_path=torch_path,
   823          model_class=PytorchLinearRegression,
   824          model_params={
   825              'input_dim': 2, 'output_dim': 1
   826          })
   827  
   828      def check_torch_script_model_id(element):
   829        assert ('torch_model.pt' in element.model_id) is True
   830  
   831      with TestPipeline() as pipeline:
   832        pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
   833        predictions = pcoll | RunInference(model_handler)
   834        _ = predictions | beam.Map(check_torch_script_model_id)
   835  
   836    def test_env_vars_set_correctly_tensor_handler(self):
   837      torch_model = PytorchLinearRegression(2, 1)
   838      torch_model.load_state_dict(
   839          OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   840                       ('linear.bias', torch.Tensor([0.5]))]))
   841  
   842      torch_script_model = torch.jit.script(torch_model)
   843  
   844      torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
   845  
   846      torch.jit.save(torch_script_model, torch_script_path)
   847  
   848      handler_with_vars = PytorchModelHandlerTensor(
   849          torch_script_model_path=torch_script_path, env_vars={'FOO': 'bar'})
   850      os.environ.pop('FOO', None)
   851      self.assertFalse('FOO' in os.environ)
   852      with TestPipeline() as pipeline:
   853        _ = (
   854            pipeline
   855            | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
   856            | RunInference(handler_with_vars))
   857        pipeline.run()
   858        self.assertTrue('FOO' in os.environ)
   859        self.assertTrue((os.environ['FOO']) == 'bar')
   860  
   861    def test_env_vars_set_correctly_keyed_tensor_handler(self):
   862      os.environ.pop('FOO', None)
   863      self.assertFalse('FOO' in os.environ)
   864      with TestPipeline() as pipeline:
   865        inference_args = {
   866            'prediction_param_array': torch.from_numpy(
   867                np.array([1, 2], dtype="float32")),
   868            'prediction_param_bool': True
   869        }
   870  
   871        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
   872                                  ('linear.bias', torch.Tensor([0.5]))])
   873        path = os.path.join(self.tmpdir, 'my_state_dict_path')
   874        torch.save(state_dict, path)
   875  
   876        handler_with_vars = PytorchModelHandlerKeyedTensor(
   877            env_vars={'FOO': 'bar'},
   878            state_dict_path=path,
   879            model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs,
   880            model_params={
   881                'input_dim': 1, 'output_dim': 1
   882            })
   883        inference_args_side_input = (
   884            pipeline | 'create side' >> beam.Create(inference_args))
   885  
   886        _ = (
   887            pipeline
   888            | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES)
   889            | RunInference(
   890                model_handler=handler_with_vars,
   891                inference_args=beam.pvalue.AsDict(inference_args_side_input)))
   892        pipeline.run()
   893        self.assertTrue('FOO' in os.environ)
   894        self.assertTrue((os.environ['FOO']) == 'bar')
   895  
   896  
   897  @pytest.mark.uses_pytorch
   898  class PytorchInferenceTestWithMocks(unittest.TestCase):
   899    def setUp(self):
   900      self._load_model = pytorch_inference._load_model
   901      pytorch_inference._load_model = unittest.mock.MagicMock(
   902          return_value=("model", "device"))
   903      self.tmpdir = tempfile.mkdtemp()
   904      self.state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
   905                                     ('linear.bias', torch.Tensor([0.5]))])
   906      self.torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
   907      torch.save(self.state_dict, self.torch_path)
   908      self.model_params = {'input_dim': 2, 'output_dim': 1}
   909  
   910    def tearDown(self):
   911      pytorch_inference._load_model = self._load_model
   912      shutil.rmtree(self.tmpdir)
   913  
   914    def test_load_model_args_tensor(self):
   915      load_model_args = {'weights_only': True}
   916      model_handler = PytorchModelHandlerTensor(
   917          state_dict_path=self.torch_path,
   918          model_class=PytorchLinearRegression,
   919          model_params=self.model_params,
   920          load_model_args=load_model_args)
   921      model_handler.load_model()
   922      pytorch_inference._load_model.assert_called_with(
   923          model_class=PytorchLinearRegression,
   924          state_dict_path=self.torch_path,
   925          device=torch.device('cpu'),
   926          model_params=self.model_params,
   927          torch_script_model_path=None,
   928          load_model_args=load_model_args)
   929  
   930    def test_load_model_args_keyed_tensor(self):
   931      load_model_args = {'weights_only': True}
   932      model_handler = PytorchModelHandlerKeyedTensor(
   933          state_dict_path=self.torch_path,
   934          model_class=PytorchLinearRegression,
   935          model_params=self.model_params,
   936          load_model_args=load_model_args)
   937      model_handler.load_model()
   938      pytorch_inference._load_model.assert_called_with(
   939          model_class=PytorchLinearRegression,
   940          state_dict_path=self.torch_path,
   941          device=torch.device('cpu'),
   942          model_params=self.model_params,
   943          torch_script_model_path=None,
   944          load_model_args=load_model_args)
   945  
   946  
if __name__ == '__main__':
  # Run the tests in this module when it is executed directly as a script.
  unittest.main()