github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/visionml.py (about)

     1  # pylint: skip-file
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """
    20  A connector for sending API requests to the GCP Vision API.
    21  """
    22  
    23  from typing import List
    24  from typing import Optional
    25  from typing import Tuple
    26  from typing import Union
    27  
    28  from apache_beam import typehints
    29  from apache_beam.metrics import Metrics
    30  from apache_beam.transforms import DoFn
    31  from apache_beam.transforms import FlatMap
    32  from apache_beam.transforms import ParDo
    33  from apache_beam.transforms import PTransform
    34  from apache_beam.transforms import util
    35  from cachetools.func import ttl_cache
    36  
    37  try:
    38    from google.cloud import vision
    39  except ImportError:
    40    raise ImportError(
    41        'Google Cloud Vision not supported for this execution environment '
    42        '(could not import google.cloud.vision).')
    43  
    44  __all__ = [
    45      'AnnotateImage',
    46      'AnnotateImageWithContext',
    47  ]
    48  
    49  
    50  @ttl_cache(maxsize=128, ttl=3600)
    51  def get_vision_client(client_options=None):
    52    """Returns a Cloud Vision API client."""
    53    _client = vision.ImageAnnotatorClient(client_options=client_options)
    54    return _client
    55  
    56  
    57  class AnnotateImage(PTransform):
    58    """A ``PTransform`` for annotating images using the GCP Vision API.
    59    ref: https://cloud.google.com/vision/docs/
    60  
    61    Batches elements together using ``util.BatchElements`` PTransform and sends
    62    each batch of elements to the GCP Vision API.
    63    Element is a Union[str, bytes] of either an URI (e.g. a GCS URI)
    64    or bytes base64-encoded image data.
    65    Accepts an `AsDict` side input that maps each image to an image context.
    66    """
    67  
    68    MAX_BATCH_SIZE = 5
    69    MIN_BATCH_SIZE = 1
    70  
    71    def __init__(
    72        self,
    73        features,
    74        retry=None,
    75        timeout=120,
    76        max_batch_size=None,
    77        min_batch_size=None,
    78        client_options=None,
    79        context_side_input=None,
    80        metadata=None):
    81      """
    82      Args:
    83        features: (List[``vision.Feature``]) Required.
    84          The Vision API features to detect
    85        retry: (google.api_core.retry.Retry) Optional.
    86          A retry object used to retry requests.
    87          If None is specified (default), requests will not be retried.
    88        timeout: (float) Optional.
    89          The time in seconds to wait for the response from the Vision API.
    90          Default is 120.
    91        max_batch_size: (int) Optional.
    92          Maximum number of images to batch in the same request to the Vision API.
    93          Default is 5 (which is also the Vision API max).
    94          This parameter is primarily intended for testing.
    95        min_batch_size: (int) Optional.
    96          Minimum number of images to batch in the same request to the Vision API.
    97          Default is None. This parameter is primarily intended for testing.
    98        client_options:
    99          (Union[dict, google.api_core.client_options.ClientOptions]) Optional.
   100          Client options used to set user options on the client.
   101          API Endpoint should be set through client_options.
   102        context_side_input: (beam.pvalue.AsDict) Optional.
   103          An ``AsDict`` of a PCollection to be passed to the
   104          _ImageAnnotateFn as the image context mapping containing additional
   105          image context and/or feature-specific parameters.
   106          Example usage::
   107  
   108            image_contexts =
   109              [(''gs://cloud-samples-data/vision/ocr/sign.jpg'', Union[dict,
   110              ``vision.ImageContext()``]),
   111              (''gs://cloud-samples-data/vision/ocr/sign.jpg'', Union[dict,
   112              ``vision.ImageContext()``]),]
   113  
   114            context_side_input =
   115              (
   116                p
   117                | "Image contexts" >> beam.Create(image_contexts)
   118              )
   119  
   120            visionml.AnnotateImage(features,
   121              context_side_input=beam.pvalue.AsDict(context_side_input)))
   122        metadata: (Optional[Sequence[Tuple[str, str]]]): Optional.
   123          Additional metadata that is provided to the method.
   124      """
   125      super().__init__()
   126      self.features = features
   127      self.retry = retry
   128      self.timeout = timeout
   129      self.max_batch_size = max_batch_size or AnnotateImage.MAX_BATCH_SIZE
   130      if self.max_batch_size > AnnotateImage.MAX_BATCH_SIZE:
   131        raise ValueError(
   132            'Max batch_size exceeded. '
   133            'Batch size needs to be smaller than {}'.format(
   134                AnnotateImage.MAX_BATCH_SIZE))
   135      self.min_batch_size = min_batch_size or AnnotateImage.MIN_BATCH_SIZE
   136      self.client_options = client_options
   137      self.context_side_input = context_side_input
   138      self.metadata = metadata
   139  
   140    def expand(self, pvalue):
   141      return (
   142          pvalue
   143          | FlatMap(self._create_image_annotation_pairs, self.context_side_input)
   144          | util.BatchElements(
   145              min_batch_size=self.min_batch_size,
   146              max_batch_size=self.max_batch_size)
   147          | ParDo(
   148              _ImageAnnotateFn(
   149                  features=self.features,
   150                  retry=self.retry,
   151                  timeout=self.timeout,
   152                  client_options=self.client_options,
   153                  metadata=self.metadata)))
   154  
   155    @typehints.with_input_types(Union[str, bytes], Optional[vision.ImageContext])
   156    @typehints.with_output_types(List[vision.AnnotateImageRequest])
   157    def _create_image_annotation_pairs(self, element, context_side_input):
   158      if context_side_input:  # If we have a side input image context, use that
   159        image_context = context_side_input.get(element)
   160      else:
   161        image_context = None
   162  
   163      if isinstance(element, str):
   164  
   165        image = vision.Image(
   166            {'source': vision.ImageSource({'image_uri': element})})
   167  
   168      else:  # Typehint checks only allows str or bytes
   169        image = vision.Image(content=element)
   170  
   171      request = vision.AnnotateImageRequest({
   172          'image': image,
   173          'features': self.features,
   174          'image_context': image_context
   175      })
   176      yield request
   177  
   178  
   179  class AnnotateImageWithContext(AnnotateImage):
   180    """A ``PTransform`` for annotating images using the GCP Vision API.
   181    ref: https://cloud.google.com/vision/docs/
   182    Batches elements together using ``util.BatchElements`` PTransform and sends
   183    each batch of elements to the GCP Vision API.
   184  
   185    Element is a tuple of::
   186  
   187      (Union[str, bytes],
   188      Optional[``vision.ImageContext``])
   189  
   190    where the former is either an URI (e.g. a GCS URI) or bytes
   191    base64-encoded image data.
   192    """
   193    def __init__(
   194        self,
   195        features,
   196        retry=None,
   197        timeout=120,
   198        max_batch_size=None,
   199        min_batch_size=None,
   200        client_options=None,
   201        metadata=None):
   202      """
   203      Args:
   204        features: (List[``vision.Feature``]) Required.
   205          The Vision API features to detect
   206        retry: (google.api_core.retry.Retry) Optional.
   207          A retry object used to retry requests.
   208          If None is specified (default), requests will not be retried.
   209        timeout: (float) Optional.
   210          The time in seconds to wait for the response from the Vision API.
   211          Default is 120.
   212        max_batch_size: (int) Optional.
   213          Maximum number of images to batch in the same request to the Vision API.
   214          Default is 5 (which is also the Vision API max).
   215          This parameter is primarily intended for testing.
   216        min_batch_size: (int) Optional.
   217          Minimum number of images to batch in the same request to the Vision API.
   218          Default is None. This parameter is primarily intended for testing.
   219        client_options:
   220          (Union[dict, google.api_core.client_options.ClientOptions]) Optional.
   221          Client options used to set user options on the client.
   222          API Endpoint should be set through client_options.
   223        metadata: (Optional[Sequence[Tuple[str, str]]]): Optional.
   224          Additional metadata that is provided to the method.
   225      """
   226      super().__init__(
   227          features=features,
   228          retry=retry,
   229          timeout=timeout,
   230          max_batch_size=max_batch_size,
   231          min_batch_size=min_batch_size,
   232          client_options=client_options,
   233          metadata=metadata)
   234  
   235    def expand(self, pvalue):
   236      return (
   237          pvalue
   238          | FlatMap(self._create_image_annotation_pairs)
   239          | util.BatchElements(
   240              min_batch_size=self.min_batch_size,
   241              max_batch_size=self.max_batch_size)
   242          | ParDo(
   243              _ImageAnnotateFn(
   244                  features=self.features,
   245                  retry=self.retry,
   246                  timeout=self.timeout,
   247                  client_options=self.client_options,
   248                  metadata=self.metadata)))
   249  
   250    @typehints.with_input_types(
   251        Tuple[Union[str, bytes], Optional[vision.ImageContext]])
   252    @typehints.with_output_types(List[vision.AnnotateImageRequest])
   253    def _create_image_annotation_pairs(self, element, **kwargs):
   254      element, image_context = element  # Unpack (image, image_context) tuple
   255      if isinstance(element, str):
   256        image = vision.Image(
   257            {'source': vision.ImageSource({'image_uri': element})})
   258      else:  # Typehint checks only allows str or bytes
   259        image = vision.Image({"content": element})
   260  
   261      request = vision.AnnotateImageRequest({
   262          'image': image,
   263          'features': self.features,
   264          'image_context': image_context
   265      })
   266      yield request
   267  
   268  
   269  @typehints.with_input_types(List[vision.AnnotateImageRequest])
   270  class _ImageAnnotateFn(DoFn):
   271    """A DoFn that sends each input element to the GCP Vision API.
   272    Returns ``google.cloud.vision.BatchAnnotateImagesResponse``.
   273    """
   274    def __init__(self, features, retry, timeout, client_options, metadata):
   275      super().__init__()
   276      self._client = None
   277      self.features = features
   278      self.retry = retry
   279      self.timeout = timeout
   280      self.client_options = client_options
   281      self.metadata = metadata
   282      self.counter = Metrics.counter(self.__class__, "API Calls")
   283  
   284    def setup(self):
   285      self._client = get_vision_client(self.client_options)
   286  
   287    def process(self, element, *args, **kwargs):
   288      response = self._client.batch_annotate_images(
   289          requests=element,
   290          retry=self.retry,
   291          timeout=self.timeout,
   292          metadata=self.metadata)
   293      self.counter.inc()
   294      yield response