github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/naturallanguageml_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  """Unit tests for Google Cloud Natural Language API transform."""
    20  
    21  import unittest
    22  
    23  import mock
    24  
    25  import apache_beam as beam
    26  from apache_beam.metrics import MetricsFilter
    27  from apache_beam.testing.test_pipeline import TestPipeline
    28  
    29  # Protect against environments where Google Cloud Natural Language client
    30  # is not available.
    31  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    32  try:
    33    from google.cloud import language
    34  except ImportError:
    35    language = None
    36  else:
    37    from apache_beam.ml.gcp import naturallanguageml
    38  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    39  
    40  
    41  @unittest.skipIf(language is None, 'GCP dependencies are not installed')
    42  class NaturalLanguageMlTest(unittest.TestCase):
    43    def assertCounterEqual(self, pipeline_result, counter_name, expected):
    44      metrics = pipeline_result.metrics().query(
    45          MetricsFilter().with_name(counter_name))
    46      try:
    47        counter = metrics['counters'][0]
    48        self.assertEqual(expected, counter.result)
    49      except IndexError:
    50        raise AssertionError('Counter "{}" was not found'.format(counter_name))
    51  
    52    def test_document_source(self):
    53      document = naturallanguageml.Document('Hello, world!')
    54      dict_ = naturallanguageml.Document.to_dict(document)
    55      self.assertTrue('content' in dict_)
    56      self.assertFalse('gcs_content_uri' in dict_)
    57  
    58      document = naturallanguageml.Document('gs://sample/location', from_gcs=True)
    59      dict_ = naturallanguageml.Document.to_dict(document)
    60      self.assertFalse('content' in dict_)
    61      self.assertTrue('gcs_content_uri' in dict_)
    62  
    63    def test_annotate_test_called(self):
    64      with mock.patch('apache_beam.ml.gcp.naturallanguageml._AnnotateTextFn'
    65                      '._get_api_client'):
    66        p = TestPipeline()
    67        features = [
    68            naturallanguageml.language_v1.AnnotateTextRequest.Features(
    69                extract_syntax=True)
    70        ]
    71        _ = (
    72            p | beam.Create([naturallanguageml.Document('Hello, world!')])
    73            | naturallanguageml.AnnotateText(features))
    74        result = p.run()
    75        result.wait_until_finish()
    76        self.assertCounterEqual(result, 'api_calls', 1)
    77  
    78  
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()