github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/onnx_inference_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """End-to-End test for Onnx Inference"""
    19  
    20  import logging
    21  import os
    22  import unittest
    23  import uuid
    24  
    25  import pytest
    26  
    27  from apache_beam.io.filesystems import FileSystems
    28  from apache_beam.testing.test_pipeline import TestPipeline
    29  
    30  # pylint: disable=ungrouped-imports
    31  try:
    32    import onnx
    33    from apache_beam.examples.inference import onnx_sentiment_classification
    34  except ImportError as e:
    35    onnx = None
    36  
    37  
    38  def process_outputs(filepath):
    39    with FileSystems().open(filepath) as f:
    40      lines = f.readlines()
    41    lines = [l.decode('utf-8').strip('\n') for l in lines]
    42    return lines
    43  
    44  
    45  @unittest.skipIf(
    46      os.getenv('FORCE_ONNX_IT') is None and onnx is None,
    47      'Missing dependencies. '
    48      'Test depends on onnx and transformers')
    49  class OnnxInference(unittest.TestCase):
    50    @pytest.mark.uses_onnx
    51    @pytest.mark.it_postcommit
    52    def test_onnx_run_inference_roberta_sentiment_classification(self):
    53      test_pipeline = TestPipeline(is_integration_test=True)
    54      # Path to text file containing some sentences
    55      file_of_sentences = (
    56          'gs://apache-beam-ml/testing/inputs/onnx/'
    57          'sentiment_classification_input.txt')
    58      output_file_dir = 'local/sentiment_classification/output'
    59      output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
    60  
    61      model_uri = (
    62          'gs://apache-beam-ml/models/'
    63          'roberta_sentiment_classification.onnx')
    64      extra_opts = {
    65          'input': file_of_sentences,
    66          'output': output_file,
    67          'model_uri': model_uri,
    68      }
    69      onnx_sentiment_classification.run(
    70          test_pipeline.get_full_options_as_args(**extra_opts),
    71          save_main_session=False)
    72  
    73      self.assertEqual(FileSystems().exists(output_file), True)
    74      predictions = process_outputs(filepath=output_file)
    75      actuals_file = (
    76          'gs://apache-beam-ml/testing/expected_outputs/'
    77          'test_onnx_run_inference_roberta_sentiment'
    78          '_classification_actuals.txt')
    79      actuals = process_outputs(filepath=actuals_file)
    80  
    81      predictions_dict = {}
    82      for prediction in predictions:
    83        text, predicted_text = prediction.split(';')
    84        predictions_dict[text] = predicted_text
    85  
    86      for actual in actuals:
    87        text, actual_predicted_text = actual.split(';')
    88        predicted_predicted_text = predictions_dict[text]
    89        self.assertEqual(actual_predicted_text, predicted_predicted_text)
    90  
    91  
    92  if __name__ == '__main__':
    93    logging.getLogger().setLevel(logging.DEBUG)
    94    unittest.main()