github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/visionml_test.py (about)

     1  # pylint: skip-file
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """Unit tests for visionml."""
    20  
    21  # pytype: skip-file
    22  
    23  import logging
    24  import unittest
    25  
    26  import mock
    27  
    28  import apache_beam as beam
    29  from apache_beam.metrics import MetricsFilter
    30  from apache_beam.typehints.decorators import TypeCheckError
    31  
    32  # Protect against environments where vision lib is not available.
    33  try:
    34    from google.cloud.vision import ImageAnnotatorClient
    35    from google.cloud import vision
    36    from apache_beam.ml.gcp import visionml
    37  except ImportError:
    38    ImageAnnotatorClient = None
    39  
    40  
    41  @unittest.skipIf(
    42      ImageAnnotatorClient is None, 'Vision dependencies are not installed')
    43  class VisionTest(unittest.TestCase):
    44    def setUp(self):
    45      self._mock_client = mock.Mock()
    46      self._mock_client.batch_annotate_images.return_value = None
    47  
    48      feature_type = vision.Feature.Type.TEXT_DETECTION
    49      self.features = [
    50          vision.Feature({
    51              'type': feature_type, 'max_results': 3, 'model': "builtin/stable"
    52          })
    53      ]
    54      self.img_ctx = vision.ImageContext()
    55      self.min_batch_size = 1
    56      self.max_batch_size = 1
    57  
    58    def test_AnnotateImage_URIs(self):
    59      images_to_annotate = [
    60          'gs://cloud-samples-data/vision/ocr/sign.jpg',
    61          'gs://cloud-samples-data/vision/ocr/sign.jpg'
    62      ]
    63  
    64      expected_counter = len(images_to_annotate)
    65      with mock.patch.object(visionml,
    66                             'get_vision_client',
    67                             return_value=self._mock_client):
    68        p = beam.Pipeline()
    69        _ = (
    70            p
    71            | "Create data" >> beam.Create(images_to_annotate)
    72            | "Annotate image" >> visionml.AnnotateImage(
    73                self.features,
    74                min_batch_size=self.min_batch_size,
    75                max_batch_size=self.max_batch_size))
    76        result = p.run()
    77        result.wait_until_finish()
    78  
    79        read_filter = MetricsFilter().with_name('API Calls')
    80        query_result = result.metrics().query(read_filter)
    81        if query_result['counters']:
    82          read_counter = query_result['counters'][0]
    83          self.assertTrue(read_counter.result == expected_counter)
    84  
    85    def test_AnnotateImage_URI_with_side_input_context(self):
    86      images_to_annotate = [
    87          'gs://cloud-samples-data/vision/ocr/sign.jpg',
    88          'gs://cloud-samples-data/vision/ocr/sign.jpg'
    89      ]
    90      image_contexts = [
    91          ('gs://cloud-samples-data/vision/ocr/sign.jpg', self.img_ctx),
    92          ('gs://cloud-samples-data/vision/ocr/sign.jpg', self.img_ctx),
    93      ]
    94  
    95      expected_counter = len(images_to_annotate)
    96      with mock.patch.object(visionml,
    97                             'get_vision_client',
    98                             return_value=self._mock_client):
    99        p = beam.Pipeline()
   100        context_side_input = (p | "Image contexts" >> beam.Create(image_contexts))
   101  
   102        _ = (
   103            p
   104            | "Create data" >> beam.Create(images_to_annotate)
   105            | "Annotate image" >> visionml.AnnotateImage(
   106                self.features,
   107                min_batch_size=self.min_batch_size,
   108                max_batch_size=self.max_batch_size,
   109                context_side_input=beam.pvalue.AsDict(context_side_input)))
   110        result = p.run()
   111        result.wait_until_finish()
   112  
   113        read_filter = MetricsFilter().with_name('API Calls')
   114        query_result = result.metrics().query(read_filter)
   115        if query_result['counters']:
   116          read_counter = query_result['counters'][0]
   117          self.assertTrue(read_counter.result == expected_counter)
   118  
   119    def test_AnnotateImage_b64_content(self):
   120      base_64_encoded_image = \
   121        b'YmVnaW4gNjQ0IGNhdC12aWRlby5tcDRNICAgICgmOVQ+NyFNPCMwUi4uZmFrZV92aWRlb'
   122      images_to_annotate = [
   123          base_64_encoded_image,
   124          base_64_encoded_image,
   125          base_64_encoded_image,
   126      ]
   127      expected_counter = len(images_to_annotate)
   128      with mock.patch.object(visionml,
   129                             'get_vision_client',
   130                             return_value=self._mock_client):
   131        p = beam.Pipeline()
   132        _ = (
   133            p
   134            | "Create data" >> beam.Create(images_to_annotate)
   135            | "Annotate image" >> visionml.AnnotateImage(
   136                self.features,
   137                min_batch_size=self.min_batch_size,
   138                max_batch_size=self.max_batch_size))
   139        result = p.run()
   140        result.wait_until_finish()
   141  
   142        read_filter = MetricsFilter().with_name('API Calls')
   143        query_result = result.metrics().query(read_filter)
   144        if query_result['counters']:
   145          read_counter = query_result['counters'][0]
   146          self.assertTrue(read_counter.result == expected_counter)
   147  
   148    def test_AnnotateImageWithContext_URIs(self):
   149      images_to_annotate = [
   150          ('gs://cloud-samples-data/vision/ocr/sign.jpg', self.img_ctx),
   151          ('gs://cloud-samples-data/vision/ocr/sign.jpg', None),
   152          ('gs://cloud-samples-data/vision/ocr/sign.jpg', self.img_ctx),
   153      ]
   154      batch_size = 5
   155      expected_counter = 1  # All images should fit in the same batch
   156      with mock.patch.object(visionml,
   157                             'get_vision_client',
   158                             return_value=self._mock_client):
   159        p = beam.Pipeline()
   160        _ = (
   161            p
   162            | "Create data" >> beam.Create(images_to_annotate)
   163            | "Annotate image" >> visionml.AnnotateImageWithContext(
   164                self.features,
   165                min_batch_size=batch_size,
   166                max_batch_size=batch_size))
   167        result = p.run()
   168        result.wait_until_finish()
   169  
   170        read_filter = MetricsFilter().with_name('API Calls')
   171        query_result = result.metrics().query(read_filter)
   172        if query_result['counters']:
   173          read_counter = query_result['counters'][0]
   174          self.assertTrue(read_counter.result == expected_counter)
   175  
   176    def test_AnnotateImageWithContext_bad_input(self):
   177      """AnnotateImageWithContext should not accept images without context"""
   178      images_to_annotate = [
   179          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   180          'gs://cloud-samples-data/vision/ocr/sign.jpg'
   181      ]
   182      with mock.patch.object(visionml,
   183                             'get_vision_client',
   184                             return_value=self._mock_client):
   185        with self.assertRaises(TypeCheckError):
   186          p = beam.Pipeline()
   187          _ = (
   188              p
   189              | "Create data" >> beam.Create(images_to_annotate)
   190              | "Annotate image" >> visionml.AnnotateImageWithContext(
   191                  self.features))
   192          result = p.run()
   193          result.wait_until_finish()
   194  
   195    def test_AnnotateImage_bad_input(self):
   196      images_to_annotate = [123456789, 123456789, 123456789]
   197      with mock.patch.object(visionml,
   198                             'get_vision_client',
   199                             return_value=self._mock_client):
   200        with self.assertRaises(TypeCheckError):
   201          p = beam.Pipeline()
   202          _ = (
   203              p
   204              | "Create data" >> beam.Create(images_to_annotate)
   205              | "Annotate image" >> visionml.AnnotateImage(self.features))
   206          result = p.run()
   207          result.wait_until_finish()
   208  
   209    def test_AnnotateImage_URIs_large_batch(self):
   210      images_to_annotate = [
   211          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   212          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   213          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   214          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   215          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   216          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   217          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   218          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   219          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   220          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   221          'gs://cloud-samples-data/vision/ocr/sign.jpg',
   222      ]
   223  
   224      batch_size = 5
   225      expected_counter = 3  # All 11 images should fit in 3 batches
   226      with mock.patch.object(visionml,
   227                             'get_vision_client',
   228                             return_value=self._mock_client):
   229        p = beam.Pipeline()
   230        _ = (
   231            p
   232            | "Create data" >> beam.Create(images_to_annotate)
   233            | "Annotate image" >> visionml.AnnotateImage(
   234                self.features,
   235                max_batch_size=batch_size,
   236                min_batch_size=batch_size))
   237        result = p.run()
   238        result.wait_until_finish()
   239  
   240        read_filter = MetricsFilter().with_name('API Calls')
   241        query_result = result.metrics().query(read_filter)
   242        if query_result['counters']:
   243          read_counter = query_result['counters'][0]
   244          self.assertTrue(read_counter.result == expected_counter)
   245  
   246  
   247  if __name__ == '__main__':
   248    logging.getLogger().setLevel(logging.INFO)
   249    unittest.main()