github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/visionml.py

# pylint: skip-file
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A connector for sending API requests to the GCP Vision API.
"""

from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn
from apache_beam.transforms import FlatMap
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms import util
from cachetools.func import ttl_cache

try:
  from google.cloud import vision
except ImportError:
  raise ImportError(
      'Google Cloud Vision not supported for this execution environment '
      '(could not import google.cloud.vision).')

__all__ = [
    'AnnotateImage',
    'AnnotateImageWithContext',
]


@ttl_cache(maxsize=128, ttl=3600)
def get_vision_client(client_options=None):
  """Returns a Cloud Vision API client."""
  _client = vision.ImageAnnotatorClient(client_options=client_options)
  return _client
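
# A minimal usage sketch (illustrative only, not part of this module): the
# Vision API endpoint is meant to be set through ``client_options``, as the
# transforms below document. The regional endpoint shown here is an assumed
# example value.
#
#   from google.api_core.client_options import ClientOptions
#
#   client = get_vision_client(
#       ClientOptions(api_endpoint='eu-vision.googleapis.com'))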


class AnnotateImage(PTransform):
  """A ``PTransform`` for annotating images using the GCP Vision API.
  ref: https://cloud.google.com/vision/docs/

  Batches elements together using the ``util.BatchElements`` PTransform and
  sends each batch of elements to the GCP Vision API.
  Element is a Union[str, bytes] of either a URI (e.g. a GCS URI)
  or bytes base64-encoded image data.
  Accepts an ``AsDict`` side input that maps each image to an image context.
  """

  MAX_BATCH_SIZE = 5
  MIN_BATCH_SIZE = 1

  def __init__(
      self,
      features,
      retry=None,
      timeout=120,
      max_batch_size=None,
      min_batch_size=None,
      client_options=None,
      context_side_input=None,
      metadata=None):
    """
    Args:
      features: (List[``vision.Feature``]) Required.
        The Vision API features to detect.
      retry: (google.api_core.retry.Retry) Optional.
        A retry object used to retry requests.
        If None is specified (default), requests will not be retried.
      timeout: (float) Optional.
        The time in seconds to wait for the response from the Vision API.
        Default is 120.
      max_batch_size: (int) Optional.
        Maximum number of images to batch in the same request to the
        Vision API. Default is 5 (which is also the Vision API max).
        This parameter is primarily intended for testing.
      min_batch_size: (int) Optional.
        Minimum number of images to batch in the same request to the
        Vision API. Default is None.
        This parameter is primarily intended for testing.
      client_options:
        (Union[dict, google.api_core.client_options.ClientOptions]) Optional.
        Client options used to set user options on the client.
        API Endpoint should be set through client_options.
      context_side_input: (beam.pvalue.AsDict) Optional.
        An ``AsDict`` of a PCollection to be passed to the
        _ImageAnnotateFn as the image context mapping containing additional
        image context and/or feature-specific parameters.
        Example usage::

          image_contexts =
            [('gs://cloud-samples-data/vision/ocr/sign.jpg', Union[dict,
              ``vision.ImageContext()``]),
             ('gs://cloud-samples-data/vision/ocr/sign.jpg', Union[dict,
              ``vision.ImageContext()``]),]

          context_side_input =
            (
              p
              | "Image contexts" >> beam.Create(image_contexts)
            )

          visionml.AnnotateImage(features,
            context_side_input=beam.pvalue.AsDict(context_side_input))
      metadata: (Optional[Sequence[Tuple[str, str]]]) Optional.
        Additional metadata that is provided to the method.
    """
    super().__init__()
    self.features = features
    self.retry = retry
    self.timeout = timeout
    self.max_batch_size = max_batch_size or AnnotateImage.MAX_BATCH_SIZE
    if self.max_batch_size > AnnotateImage.MAX_BATCH_SIZE:
      raise ValueError(
          'Max batch_size exceeded. '
          'Batch size needs to be smaller than {}'.format(
              AnnotateImage.MAX_BATCH_SIZE))
    self.min_batch_size = min_batch_size or AnnotateImage.MIN_BATCH_SIZE
    self.client_options = client_options
    self.context_side_input = context_side_input
    self.metadata = metadata

  def expand(self, pvalue):
    return (
        pvalue
        | FlatMap(self._create_image_annotation_pairs, self.context_side_input)
        | util.BatchElements(
            min_batch_size=self.min_batch_size,
            max_batch_size=self.max_batch_size)
        | ParDo(
            _ImageAnnotateFn(
                features=self.features,
                retry=self.retry,
                timeout=self.timeout,
                client_options=self.client_options,
                metadata=self.metadata)))

  @typehints.with_input_types(Union[str, bytes], Optional[vision.ImageContext])
  @typehints.with_output_types(List[vision.AnnotateImageRequest])
  def _create_image_annotation_pairs(self, element, context_side_input):
    if context_side_input:  # If we have a side input image context, use that.
      image_context = context_side_input.get(element)
    else:
      image_context = None

    if isinstance(element, str):
      # The element is a URI, so reference it as an image source.
      image = vision.Image(
          {'source': vision.ImageSource({'image_uri': element})})
    else:  # Typehint checks only allow str or bytes.
      # The element is raw image bytes, so embed it in the request.
      image = vision.Image(content=element)

    request = vision.AnnotateImageRequest({
        'image': image,
        'features': self.features,
        'image_context': image_context
    })
    yield request
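
# A minimal pipeline sketch for ``AnnotateImage`` (illustrative only, not part
# of this module): the feature choice is an assumed example value and the GCS
# URI is taken from the docstring above; the transform emits
# ``BatchAnnotateImagesResponse`` objects.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import visionml
#   from google.cloud import vision
#
#   features = [vision.Feature(type_=vision.Feature.Type.TEXT_DETECTION)]
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create(['gs://cloud-samples-data/vision/ocr/sign.jpg'])
#         | visionml.AnnotateImage(features)
#         | beam.Map(print))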


class AnnotateImageWithContext(AnnotateImage):
  """A ``PTransform`` for annotating images using the GCP Vision API.
  ref: https://cloud.google.com/vision/docs/

  Batches elements together using the ``util.BatchElements`` PTransform and
  sends each batch of elements to the GCP Vision API.

  Element is a tuple of::

    (Union[str, bytes],
    Optional[``vision.ImageContext``])

  where the former is either a URI (e.g. a GCS URI) or bytes
  base64-encoded image data.
  """
  def __init__(
      self,
      features,
      retry=None,
      timeout=120,
      max_batch_size=None,
      min_batch_size=None,
      client_options=None,
      metadata=None):
    """
    Args:
      features: (List[``vision.Feature``]) Required.
        The Vision API features to detect.
      retry: (google.api_core.retry.Retry) Optional.
        A retry object used to retry requests.
        If None is specified (default), requests will not be retried.
      timeout: (float) Optional.
        The time in seconds to wait for the response from the Vision API.
        Default is 120.
      max_batch_size: (int) Optional.
        Maximum number of images to batch in the same request to the
        Vision API. Default is 5 (which is also the Vision API max).
        This parameter is primarily intended for testing.
      min_batch_size: (int) Optional.
        Minimum number of images to batch in the same request to the
        Vision API. Default is None.
        This parameter is primarily intended for testing.
      client_options:
        (Union[dict, google.api_core.client_options.ClientOptions]) Optional.
        Client options used to set user options on the client.
        API Endpoint should be set through client_options.
      metadata: (Optional[Sequence[Tuple[str, str]]]) Optional.
        Additional metadata that is provided to the method.
    """
    super().__init__(
        features=features,
        retry=retry,
        timeout=timeout,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
        client_options=client_options,
        metadata=metadata)

  def expand(self, pvalue):
    return (
        pvalue
        | FlatMap(self._create_image_annotation_pairs)
        | util.BatchElements(
            min_batch_size=self.min_batch_size,
            max_batch_size=self.max_batch_size)
        | ParDo(
            _ImageAnnotateFn(
                features=self.features,
                retry=self.retry,
                timeout=self.timeout,
                client_options=self.client_options,
                metadata=self.metadata)))

  @typehints.with_input_types(
      Tuple[Union[str, bytes], Optional[vision.ImageContext]])
  @typehints.with_output_types(List[vision.AnnotateImageRequest])
  def _create_image_annotation_pairs(self, element, **kwargs):
    element, image_context = element  # Unpack the (image, image_context) tuple.
    if isinstance(element, str):
      image = vision.Image(
          {'source': vision.ImageSource({'image_uri': element})})
    else:  # Typehint checks only allow str or bytes.
      image = vision.Image({'content': element})

    request = vision.AnnotateImageRequest({
        'image': image,
        'features': self.features,
        'image_context': image_context
    })
    yield request


@typehints.with_input_types(List[vision.AnnotateImageRequest])
class _ImageAnnotateFn(DoFn):
  """A ``DoFn`` that sends each input element (a batch of annotation requests)
  to the GCP Vision API.
  Returns a ``google.cloud.vision.BatchAnnotateImagesResponse``.
  """
  def __init__(self, features, retry, timeout, client_options, metadata):
    super().__init__()
    self._client = None
    self.features = features
    self.retry = retry
    self.timeout = timeout
    self.client_options = client_options
    self.metadata = metadata
    self.counter = Metrics.counter(self.__class__, "API Calls")

  def setup(self):
    self._client = get_vision_client(self.client_options)

  def process(self, element, *args, **kwargs):
    response = self._client.batch_annotate_images(
        requests=element,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata)
    self.counter.inc()
    yield response
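
# A minimal pipeline sketch for ``AnnotateImageWithContext`` (illustrative
# only, not part of this module): each element pairs an image with its own
# ``vision.ImageContext``. The feature choice and language hint are assumed
# example values; the GCS URI is taken from the docstring above.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import visionml
#   from google.cloud import vision
#
#   features = [vision.Feature(type_=vision.Feature.Type.TEXT_DETECTION)]
#   context = vision.ImageContext(language_hints=['en'])
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create(
#             [('gs://cloud-samples-data/vision/ocr/sign.jpg', context)])
#         | visionml.AnnotateImageWithContext(features)
#         | beam.Map(print))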