github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py

# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file
# pylint: disable=reimported

import torch


class LinearRegression(torch.nn.Module):
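  """A simple linear regression model: a single fully connected layer."""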
  def __init__(self, input_dim=1, output_dim=1):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    out = self.linear(x)
    return out


def torch_unkeyed_model_handler(test=None):
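  """Runs RunInference on unkeyed tensors using PytorchModelHandlerTensor."""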
  # [START torch_unkeyed_model_handler]
  import apache_beam as beam
  import numpy
  import torch
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

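  # Saved state_dict for the LinearRegression model above, hosted in the public
  # Beam samples bucket; per its name it was trained on a five-times table, so
  # each prediction is approximately 5 * input.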
  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
  model_class = LinearRegression
  model_params = {'input_dim': 1, 'output_dim': 1}
  model_handler = PytorchModelHandlerTensor(
      model_class=model_class,
      model_params=model_params,
      state_dict_path=model_state_dict_path)

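  # Four single-feature examples, shaped (4, 1); beam.Create emits one row
  # (a length-1 numpy array) per element.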
  unkeyed_data = numpy.array([10, 40, 60, 90],
                             dtype=numpy.float32).reshape(-1, 1)

  with beam.Pipeline() as p:
    predictions = (
        p
        | 'InputData' >> beam.Create(unkeyed_data)
        | 'ConvertNumpyToTensor' >> beam.Map(torch.Tensor)
        | 'PytorchRunInference' >> RunInference(model_handler=model_handler)
        | beam.Map(print))
    # [END torch_unkeyed_model_handler]
    if test:
      test(predictions)


def torch_keyed_model_handler(test=None):
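  """Runs RunInference on (key, tensor) pairs using a KeyedModelHandler that
  wraps PytorchModelHandlerTensor, keeping each prediction paired with its key.
  """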
  # [START torch_keyed_model_handler]
  import apache_beam as beam
  import torch
  from apache_beam.ml.inference.base import KeyedModelHandler
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
  model_class = LinearRegression
  model_params = {'input_dim': 1, 'output_dim': 1}
  keyed_model_handler = KeyedModelHandler(
      PytorchModelHandlerTensor(
          model_class=model_class,
          model_params=model_params,
          state_dict_path=model_state_dict_path))

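  # Keyed examples: KeyedModelHandler forwards each key unchanged and pairs it
  # with the model's PredictionResult in the output.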
  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
                ("third_question", 1000.00), ("fourth_question", 1013.00)]

  with beam.Pipeline() as p:
    predictions = (
        p
        | 'KeyedInputData' >> beam.Create(keyed_data)
        | "ConvertIntToTensor" >>
        beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))
        | 'PytorchRunInference' >>
        RunInference(model_handler=keyed_model_handler)
        | beam.Map(print))
    # [END torch_keyed_model_handler]
    if test:
      test(predictions)


def sklearn_unkeyed_model_handler(test=None):
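  """Runs RunInference on unkeyed numpy rows using SklearnModelHandlerNumpy."""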
  # [START sklearn_unkeyed_model_handler]
  import apache_beam as beam
  import numpy
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import ModelFileType
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
  sklearn_model_handler = SklearnModelHandlerNumpy(
      model_uri=sklearn_model_filename, model_file_type=ModelFileType.PICKLE)

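  # Each element is a single-row numpy array; the handler batches the rows and
  # calls the pickled scikit-learn model's predict method.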
  unkeyed_data = numpy.array([20, 40, 60, 90],
                             dtype=numpy.float32).reshape(-1, 1)
  with beam.Pipeline() as p:
    predictions = (
        p
        | "ReadInputs" >> beam.Create(unkeyed_data)
        | "RunInferenceSklearn" >>
        RunInference(model_handler=sklearn_model_handler)
        | beam.Map(print))
    # [END sklearn_unkeyed_model_handler]
    if test:
      test(predictions)


def sklearn_keyed_model_handler(test=None):
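  """Runs RunInference on keyed inputs using a KeyedModelHandler wrapped
  around SklearnModelHandlerNumpy."""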
  # [START sklearn_keyed_model_handler]
  import apache_beam as beam
  from apache_beam.ml.inference.base import KeyedModelHandler
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import ModelFileType
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
  sklearn_model_handler = KeyedModelHandler(
      SklearnModelHandlerNumpy(
          model_uri=sklearn_model_filename,
          model_file_type=ModelFileType.PICKLE))

  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
                ("third_question", 1000.00), ("fourth_question", 1013.00)]

  with beam.Pipeline() as p:
    predictions = (
        p
        | "ReadInputs" >> beam.Create(keyed_data)
        | "ConvertDataToList" >> beam.Map(lambda x: (x[0], [x[1]]))
        | "RunInferenceSklearn" >>
        RunInference(model_handler=sklearn_model_handler)
        | beam.Map(print))
    # [END sklearn_keyed_model_handler]
    if test:
      test(predictions)