github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/cloud_dlp_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Integration tests for Google Cloud DLP (Data Loss Prevention) transforms."""
    19  
    20  import logging
    21  import unittest
    22  
    23  import pytest
    24  
    25  import apache_beam as beam
    26  from apache_beam.testing.test_pipeline import TestPipeline
    27  from apache_beam.testing.util import assert_that
    28  from apache_beam.testing.util import equal_to
    29  
    30  # Protect against environments with google-cloud-dlp unavailable.
    31  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    32  try:
    33    from google.cloud import dlp_v2
    34  except ImportError:
    35    dlp_v2 = None
    36  else:
    37    from apache_beam.ml.gcp.cloud_dlp import InspectForDetails
    38    from apache_beam.ml.gcp.cloud_dlp import MaskDetectedDetails
    39  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    40  
    41  _LOGGER = logging.getLogger(__name__)
    42  
    43  INSPECT_CONFIG = {"info_types": [{"name": "EMAIL_ADDRESS"}]}
    44  
    45  DEIDENTIFY_CONFIG = {
    46      "info_type_transformations": {
    47          "transformations": [{
    48              "primitive_transformation": {
    49                  "character_mask_config": {
    50                      "masking_character": '#'
    51                  }
    52              }
    53          }]
    54      }
    55  }
    56  
    57  
    58  def extract_inspection_results(response):
    59    yield beam.pvalue.TaggedOutput('info_type', response[0].info_type.name)
    60  
    61  
    62  @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed')
    63  class CloudDLPIT(unittest.TestCase):
    64    def setUp(self):
    65      self.test_pipeline = TestPipeline(is_integration_test=True)
    66      self.runner_name = type(self.test_pipeline.runner).__name__
    67      self.project = self.test_pipeline.get_option('project')
    68  
    69    @pytest.mark.it_postcommit
    70    def test_deidentification(self):
    71      with TestPipeline(is_integration_test=True) as p:
    72        output = (
    73            p | beam.Create(["mary.sue@example.com"])
    74            | MaskDetectedDetails(
    75                project=self.project,
    76                deidentification_config=DEIDENTIFY_CONFIG,
    77                inspection_config=INSPECT_CONFIG))
    78        assert_that(output, equal_to(['####################']))
    79  
    80    @pytest.mark.it_postcommit
    81    def test_inspection(self):
    82      with TestPipeline(is_integration_test=True) as p:
    83        output = (
    84            p | beam.Create(["mary.sue@example.com"])
    85            | InspectForDetails(
    86                project=self.project, inspection_config=INSPECT_CONFIG)
    87            | beam.ParDo(extract_inspection_results).with_outputs(
    88                'quote', 'info_type'))
    89        assert_that(output.info_type, equal_to(['EMAIL_ADDRESS']), 'Type matches')
    90  
    91  
    92  if __name__ == '__main__':
    93    logging.getLogger().setLevel(logging.WARN)
    94    unittest.main()