github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/videointelligenceml_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for videointelligenceml."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import unittest
    24  
    25  import mock
    26  
    27  import apache_beam as beam
    28  from apache_beam.metrics import MetricsFilter
    29  from apache_beam.typehints.decorators import TypeCheckError
    30  
    31  # Protect against environments where video intelligence lib is not available.
    32  # pylint: disable=ungrouped-imports
    33  try:
    34    from google.cloud.videointelligence import VideoIntelligenceServiceClient
    35    from google.cloud import videointelligence
    36    from apache_beam.ml.gcp import videointelligenceml
    37  except ImportError:
    38    VideoIntelligenceServiceClient = None
    39  
    40  
    41  @unittest.skipIf(
    42      VideoIntelligenceServiceClient is None,
    43      'Video intelligence dependencies are not installed')
    44  class VideoIntelligenceTest(unittest.TestCase):
    45    def setUp(self):
    46      self._mock_client = mock.Mock()
    47      self.m2 = mock.Mock()
    48      self.m2.result.return_value = None
    49      self._mock_client.annotate_video.return_value = self.m2
    50      self.features = [videointelligence.Feature.LABEL_DETECTION]
    51      self.location_id = 'us-west1'
    52      config = videointelligence.SpeechTranscriptionConfig(
    53          language_code='en-US', enable_automatic_punctuation=True)
    54      self.video_ctx = videointelligence.VideoContext(
    55          speech_transcription_config=config)
    56  
    57    def test_AnnotateVideo_with_side_input_context(self):
    58      videos_to_annotate = [
    59          'gs://cloud-samples-data/video/cat.mp4',
    60          'gs://some-other-video/sample.mp4',
    61          'gs://some-other-video/sample_2.mp4'
    62      ]
    63      video_contexts = [
    64          ('gs://cloud-samples-data/video/cat.mp4', self.video_ctx),
    65          ('gs://some-other-video/sample.mp4', self.video_ctx),
    66      ]
    67  
    68      expected_counter = len(videos_to_annotate)
    69      with mock.patch.object(videointelligenceml,
    70                             'get_videointelligence_client',
    71                             return_value=self._mock_client):
    72        p = beam.Pipeline()
    73        context_side_input = (p | "Video contexts" >> beam.Create(video_contexts))
    74  
    75        _ = (
    76            p
    77            | "Create data" >> beam.Create(videos_to_annotate)
    78            | "Annotate video" >> videointelligenceml.AnnotateVideo(
    79                self.features,
    80                context_side_input=beam.pvalue.AsDict(context_side_input)))
    81        result = p.run()
    82        result.wait_until_finish()
    83  
    84        read_filter = MetricsFilter().with_name('API Calls')
    85        query_result = result.metrics().query(read_filter)
    86        if query_result['counters']:
    87          read_counter = query_result['counters'][0]
    88          self.assertTrue(read_counter.committed == expected_counter)
    89  
    90    def test_AnnotateVideo_URIs(self):
    91      videos_to_annotate = [
    92          'gs://cloud-samples-data/video/cat.mp4',
    93          'gs://cloud-samples-data/video/cat.mp4'
    94      ]
    95      expected_counter = len(videos_to_annotate)
    96      with mock.patch.object(videointelligenceml,
    97                             'get_videointelligence_client',
    98                             return_value=self._mock_client):
    99        p = beam.Pipeline()
   100        _ = (
   101            p
   102            | "Create data" >> beam.Create(videos_to_annotate)
   103            |
   104            "Annotate video" >> videointelligenceml.AnnotateVideo(self.features))
   105        result = p.run()
   106        result.wait_until_finish()
   107  
   108        read_filter = MetricsFilter().with_name('API Calls')
   109        query_result = result.metrics().query(read_filter)
   110        if query_result['counters']:
   111          read_counter = query_result['counters'][0]
   112          self.assertTrue(read_counter.committed == expected_counter)
   113  
   114    def test_AnnotateVideoWithContext_b64_content(self):
   115      base_64_encoded_video = \
   116        b'YmVnaW4gNjQ0IGNhdC12aWRlby5tcDRNICAgICgmOVQ+NyFNPCMwUi4uZmFrZV92aWRlb'
   117      videos_to_annotate = [
   118          (base_64_encoded_video, self.video_ctx),
   119          (base_64_encoded_video, None),
   120          (base_64_encoded_video, self.video_ctx),
   121      ]
   122      expected_counter = len(videos_to_annotate)
   123      with mock.patch.object(videointelligenceml,
   124                             'get_videointelligence_client',
   125                             return_value=self._mock_client):
   126        p = beam.Pipeline()
   127        _ = (
   128            p
   129            | "Create data" >> beam.Create(videos_to_annotate)
   130            | "Annotate video" >> videointelligenceml.AnnotateVideoWithContext(
   131                self.features))
   132        result = p.run()
   133        result.wait_until_finish()
   134  
   135      read_filter = MetricsFilter().with_name('API Calls')
   136      query_result = result.metrics().query(read_filter)
   137      if query_result['counters']:
   138        read_counter = query_result['counters'][0]
   139        self.assertTrue(read_counter.committed == expected_counter)
   140  
   141    def test_AnnotateVideo_b64_content(self):
   142      base_64_encoded_video = \
   143        b'YmVnaW4gNjQ0IGNhdC12aWRlby5tcDRNICAgICgmOVQ+NyFNPCMwUi4uZmFrZV92aWRlb'
   144      videos_to_annotate = [
   145          base_64_encoded_video,
   146          base_64_encoded_video,
   147          base_64_encoded_video,
   148      ]
   149      expected_counter = len(videos_to_annotate)
   150      with mock.patch.object(videointelligenceml,
   151                             'get_videointelligence_client',
   152                             return_value=self._mock_client):
   153        p = beam.Pipeline()
   154        _ = (
   155            p
   156            | "Create data" >> beam.Create(videos_to_annotate)
   157            |
   158            "Annotate video" >> videointelligenceml.AnnotateVideo(self.features))
   159        result = p.run()
   160        result.wait_until_finish()
   161  
   162        read_filter = MetricsFilter().with_name('API Calls')
   163        query_result = result.metrics().query(read_filter)
   164        if query_result['counters']:
   165          read_counter = query_result['counters'][0]
   166          self.assertTrue(read_counter.committed == expected_counter)
   167  
   168    def test_AnnotateVideoWithContext_bad_input(self):
   169      """AnnotateVideoWithContext should not accept videos without context"""
   170      videos_to_annotate = [
   171          'gs://cloud-samples-data/video/cat.mp4',
   172          'gs://cloud-samples-data/video/cat.mp4'
   173      ]
   174      with mock.patch.object(videointelligenceml,
   175                             'get_videointelligence_client',
   176                             return_value=self._mock_client):
   177        with self.assertRaises(TypeCheckError):
   178          p = beam.Pipeline()
   179          _ = (
   180              p
   181              | "Create data" >> beam.Create(videos_to_annotate)
   182              | "Annotate video" >> videointelligenceml.AnnotateVideoWithContext(
   183                  self.features))
   184          result = p.run()
   185          result.wait_until_finish()
   186  
   187    def test_AnnotateVideo_bad_input(self):
   188      videos_to_annotate = [123456789, 123456789, 123456789]
   189      with mock.patch.object(videointelligenceml,
   190                             'get_videointelligence_client',
   191                             return_value=self._mock_client):
   192        with self.assertRaises(TypeCheckError):
   193          p = beam.Pipeline()
   194          _ = (
   195              p
   196              | "Create data" >> beam.Create(videos_to_annotate)
   197              | "Annotate video" >> videointelligenceml.AnnotateVideo(
   198                  self.features))
   199          result = p.run()
   200          result.wait_until_finish()
   201  
   202  
   203  if __name__ == '__main__':
   204    logging.getLogger().setLevel(logging.INFO)
   205    unittest.main()