github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/visionml_test.py (about) 1 # pylint: skip-file 2 # 3 # Licensed to the Apache Software Foundation (ASF) under one or more 4 # contributor license agreements. See the NOTICE file distributed with 5 # this work for additional information regarding copyright ownership. 6 # The ASF licenses this file to You under the Apache License, Version 2.0 7 # (the "License"); you may not use this file except in compliance with 8 # the License. You may obtain a copy of the License at 9 # 10 # http://www.apache.org/licenses/LICENSE-2.0 11 # 12 # Unless required by applicable law or agreed to in writing, software 13 # distributed under the License is distributed on an "AS IS" BASIS, 14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 # See the License for the specific language governing permissions and 16 # limitations under the License. 17 # 18 19 """Unit tests for visionml.""" 20 21 # pytype: skip-file 22 23 import logging 24 import unittest 25 26 import mock 27 28 import apache_beam as beam 29 from apache_beam.metrics import MetricsFilter 30 from apache_beam.typehints.decorators import TypeCheckError 31 32 # Protect against environments where vision lib is not available. 
try:
  from google.cloud.vision import ImageAnnotatorClient
  from google.cloud import vision
  from apache_beam.ml.gcp import visionml
except ImportError:
  ImageAnnotatorClient = None

# Sample image URI used as pipeline input. The Vision client is replaced by a
# mock in every test, so this URI is only passed to the mock.
_SAMPLE_IMAGE_URI = 'gs://cloud-samples-data/vision/ocr/sign.jpg'


@unittest.skipIf(
    ImageAnnotatorClient is None, 'Vision dependencies are not installed')
class VisionTest(unittest.TestCase):
  """Tests for visionml.AnnotateImage and visionml.AnnotateImageWithContext.

  Each pipeline-running test patches ``visionml.get_vision_client`` to return
  a ``mock.Mock`` whose ``batch_annotate_images`` returns ``None``. Tests then
  check the pipeline's 'API Calls' counter against the expected number of
  batches.
  """
  def setUp(self):
    self._mock_client = mock.Mock()
    self._mock_client.batch_annotate_images.return_value = None

    feature_type = vision.Feature.Type.TEXT_DETECTION
    self.features = [
        vision.Feature({
            'type': feature_type, 'max_results': 3, 'model': "builtin/stable"
        })
    ]
    self.img_ctx = vision.ImageContext()
    self.min_batch_size = 1
    self.max_batch_size = 1

  def _patch_vision_client(self):
    """Return a context manager making visionml use ``self._mock_client``."""
    return mock.patch.object(
        visionml, 'get_vision_client', return_value=self._mock_client)

  def _assert_api_calls(self, result, expected_calls):
    """Check the 'API Calls' counter of a finished pipeline ``result``.

    The check is best-effort: when the runner reports no counter with that
    name, nothing is asserted. When it does, ``assertEqual`` is used so a
    failure shows both the actual and the expected call count.

    Args:
      result: the PipelineResult returned by ``Pipeline.run()``.
      expected_calls: expected value of the first 'API Calls' counter.
    """
    read_filter = MetricsFilter().with_name('API Calls')
    counters = result.metrics().query(read_filter)['counters']
    if counters:
      self.assertEqual(counters[0].result, expected_calls)

  def test_AnnotateImage_URIs(self):
    images_to_annotate = [_SAMPLE_IMAGE_URI] * 2

    # min/max batch size are both 1, so each image is expected to cost one
    # API call.
    expected_counter = len(images_to_annotate)
    with self._patch_vision_client():
      p = beam.Pipeline()
      _ = (
          p
          | "Create data" >> beam.Create(images_to_annotate)
          | "Annotate image" >> visionml.AnnotateImage(
              self.features,
              min_batch_size=self.min_batch_size,
              max_batch_size=self.max_batch_size))
      result = p.run()
      result.wait_until_finish()

    self._assert_api_calls(result, expected_counter)

  def test_AnnotateImage_URI_with_side_input_context(self):
    images_to_annotate = [_SAMPLE_IMAGE_URI] * 2
    image_contexts = [(_SAMPLE_IMAGE_URI, self.img_ctx)] * 2

    expected_counter = len(images_to_annotate)
    with self._patch_vision_client():
      p = beam.Pipeline()
      context_side_input = (p | "Image contexts" >> beam.Create(image_contexts))

      _ = (
          p
          | "Create data" >> beam.Create(images_to_annotate)
          | "Annotate image" >> visionml.AnnotateImage(
              self.features,
              min_batch_size=self.min_batch_size,
              max_batch_size=self.max_batch_size,
              context_side_input=beam.pvalue.AsDict(context_side_input)))
      result = p.run()
      result.wait_until_finish()

    self._assert_api_calls(result, expected_counter)

  def test_AnnotateImage_b64_content(self):
    base_64_encoded_image = \
      b'YmVnaW4gNjQ0IGNhdC12aWRlby5tcDRNICAgICgmOVQ+NyFNPCMwUi4uZmFrZV92aWRlb'
    images_to_annotate = [base_64_encoded_image] * 3

    expected_counter = len(images_to_annotate)
    with self._patch_vision_client():
      p = beam.Pipeline()
      _ = (
          p
          | "Create data" >> beam.Create(images_to_annotate)
          | "Annotate image" >> visionml.AnnotateImage(
              self.features,
              min_batch_size=self.min_batch_size,
              max_batch_size=self.max_batch_size))
      result = p.run()
      result.wait_until_finish()

    self._assert_api_calls(result, expected_counter)

  def test_AnnotateImageWithContext_URIs(self):
    images_to_annotate = [
        (_SAMPLE_IMAGE_URI, self.img_ctx),
        (_SAMPLE_IMAGE_URI, None),
        (_SAMPLE_IMAGE_URI, self.img_ctx),
    ]
    batch_size = 5
    expected_counter = 1  # All images should fit in the same batch
    with self._patch_vision_client():
      p = beam.Pipeline()
      _ = (
          p
          | "Create data" >> beam.Create(images_to_annotate)
          | "Annotate image" >> visionml.AnnotateImageWithContext(
              self.features,
              min_batch_size=batch_size,
              max_batch_size=batch_size))
      result = p.run()
      result.wait_until_finish()

    self._assert_api_calls(result, expected_counter)

  def test_AnnotateImageWithContext_bad_input(self):
    """AnnotateImageWithContext should not accept images without context"""
    images_to_annotate = [_SAMPLE_IMAGE_URI] * 2
    with self._patch_vision_client():
      with self.assertRaises(TypeCheckError):
        p = beam.Pipeline()
        _ = (
            p
            | "Create data" >> beam.Create(images_to_annotate)
            | "Annotate image" >> visionml.AnnotateImageWithContext(
                self.features))
        result = p.run()
        result.wait_until_finish()

  def test_AnnotateImage_bad_input(self):
    """AnnotateImage should reject elements that are neither str nor bytes."""
    images_to_annotate = [123456789, 123456789, 123456789]
    with self._patch_vision_client():
      with self.assertRaises(TypeCheckError):
        p = beam.Pipeline()
        _ = (
            p
            | "Create data" >> beam.Create(images_to_annotate)
            | "Annotate image" >> visionml.AnnotateImage(self.features))
        result = p.run()
        result.wait_until_finish()

  def test_AnnotateImage_URIs_large_batch(self):
    images_to_annotate = [_SAMPLE_IMAGE_URI] * 11

    batch_size = 5
    expected_counter = 3  # All 11 images should fit in 3 batches (5 + 5 + 1)
    with self._patch_vision_client():
      p = beam.Pipeline()
      _ = (
          p
          | "Create data" >> beam.Create(images_to_annotate)
          | "Annotate image" >> visionml.AnnotateImage(
              self.features,
              max_batch_size=batch_size,
              min_batch_size=batch_size))
      result = p.run()
      result.wait_until_finish()

    self._assert_api_calls(result, expected_counter)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()