github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/cloud_dlp_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""Unit tests for Google Cloud DLP API transforms."""
    19  
import logging
import unittest

import mock

import apache_beam as beam
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.test_pipeline import TestPipeline
    28  
    29  # Protect against environments with google-cloud-dlp unavailable.
    30  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    31  try:
    32    from google.cloud import dlp_v2
    33  except ImportError:
    34    dlp_v2 = None
    35  else:
    36    from apache_beam.ml.gcp.cloud_dlp import InspectForDetails
    37    from apache_beam.ml.gcp.cloud_dlp import MaskDetectedDetails
    38    from apache_beam.ml.gcp.cloud_dlp import _DeidentifyFn
    39    from apache_beam.ml.gcp.cloud_dlp import _InspectFn
    40    from google.cloud.dlp_v2.types import dlp
    41  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    42  
    43  _LOGGER = logging.getLogger(__name__)
    44  
    45  
    46  @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed')
    47  class TestDeidentifyText(unittest.TestCase):
    48    def test_exception_raised_when_no_config_is_provided(self):
    49      with self.assertRaises(ValueError):
    50        with TestPipeline() as p:
    51          # pylint: disable=expression-not-assigned
    52          p | MaskDetectedDetails()
    53  
    54  
    55  @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed')
    56  class TestDeidentifyFn(unittest.TestCase):
    57    def test_deidentify_called(self):
    58      class ClientMock(object):
    59        def deidentify_content(self, *args, **kwargs):
    60          # Check that we can marshal a valid request.
    61          dlp.DeidentifyContentRequest(kwargs['request'])
    62  
    63          called = Metrics.counter('test_deidentify_text', 'called')
    64          called.inc()
    65          operation = mock.Mock()
    66          item = mock.Mock()
    67          item.value = [None]
    68          operation.item = item
    69          return operation
    70  
    71        def common_project_path(self, *args):
    72          return 'test'
    73  
    74      with mock.patch('google.cloud.dlp_v2.DlpServiceClient', ClientMock):
    75        p = TestPipeline()
    76        config = {
    77            "deidentify_config": {
    78                "info_type_transformations": {
    79                    "transformations": [{
    80                        "primitive_transformation": {
    81                            "character_mask_config": {
    82                                "masking_character": '#'
    83                            }
    84                        }
    85                    }]
    86                }
    87            }
    88        }
    89        # pylint: disable=expression-not-assigned
    90        (
    91            p
    92            | beam.Create(['mary.sue@example.com', 'john.doe@example.com'])
    93            | beam.ParDo(_DeidentifyFn(config=config)))
    94        result = p.run()
    95        result.wait_until_finish()
    96      called = result.metrics().query()['counters'][0]
    97      self.assertEqual(called.result, 2)
    98  
    99  
   100  @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed')
   101  class TestInspectText(unittest.TestCase):
   102    def test_exception_raised_then_no_config_provided(self):
   103      with self.assertRaises(ValueError):
   104        with TestPipeline() as p:
   105          #pylint: disable=expression-not-assigned
   106          p | InspectForDetails()
   107  
   108  
   109  @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed')
   110  class TestInspectFn(unittest.TestCase):
   111    def test_inspect_called(self):
   112      class ClientMock(object):
   113        def inspect_content(self, *args, **kwargs):
   114          # Check that we can marshal a valid request.
   115          dlp.InspectContentRequest(kwargs['request'])
   116  
   117          called = Metrics.counter('test_inspect_text', 'called')
   118          called.inc()
   119          operation = mock.Mock()
   120          operation.result = mock.Mock()
   121          operation.result.findings = [None]
   122          return operation
   123  
   124        def common_project_path(self, *args):
   125          return 'test'
   126  
   127      with mock.patch('google.cloud.dlp_v2.DlpServiceClient', ClientMock):
   128        p = TestPipeline()
   129        config = {"inspect_config": {"info_types": [{"name": "EMAIL_ADDRESS"}]}}
   130        # pylint: disable=expression-not-assigned
   131        (
   132            p
   133            | beam.Create(['mary.sue@example.com', 'john.doe@example.com'])
   134            | beam.ParDo(_InspectFn(config=config)))
   135        result = p.run()
   136        result.wait_until_finish()
   137        called = result.metrics().query()['counters'][0]
   138        self.assertEqual(called.result, 2)
   139  
   140  
   141  if __name__ == '__main__':
   142    logging.getLogger().setLevel(logging.INFO)
   143    unittest.main()