github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/cookbook/bigtableio_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Integration tests for bigtableio."""
    19  # pytype: skip-file
    20  
    21  import datetime
    22  import logging
    23  import random
    24  import string
    25  import unittest
    26  import uuid
    27  from typing import TYPE_CHECKING
    28  from typing import List
    29  
    30  import pytest
    31  import pytz
    32  
    33  import apache_beam as beam
    34  from apache_beam.io.gcp.bigtableio import WriteToBigTable
    35  from apache_beam.metrics.metric import MetricsFilter
    36  from apache_beam.options.pipeline_options import PipelineOptions
    37  from apache_beam.runners.runner import PipelineState
    38  from apache_beam.testing.test_pipeline import TestPipeline
    39  
# Protect against environments where bigtable library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from google.cloud._helpers import _datetime_from_microseconds
  from google.cloud._helpers import _microseconds_from_datetime
  from google.cloud._helpers import UTC
  from google.cloud.bigtable import row, column_family, Client
except ImportError:
  # Fallbacks keep this module importable without the Bigtable client.
  # `Client = None` is the condition the skipIf decorator on
  # BigtableIOWriteTest checks.
  Client = None
  UTC = pytz.utc
  # Identity stand-ins: each returns its argument unchanged and performs no
  # datetime <-> microseconds conversion.
  _microseconds_from_datetime = lambda label_stamp: label_stamp
  _datetime_from_microseconds = lambda micro: micro
    52  
    53  if TYPE_CHECKING:
    54    import google.cloud.bigtable.instance
    55  
    56  EXISTING_INSTANCES = []  # type: List[google.cloud.bigtable.instance.Instance]
    57  LABEL_KEY = u'python-bigtable-beam'
    58  label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC)
    59  label_stamp_micros = _microseconds_from_datetime(label_stamp)
    60  LABELS = {LABEL_KEY: str(label_stamp_micros)}
    61  
    62  
    63  class GenerateTestRows(beam.PTransform):
    64    """ A transform test to run write to the Bigtable Table.
    65  
    66    A PTransform that generate a list of `DirectRow` to write it in
    67    Bigtable Table.
    68  
    69    """
    70    def __init__(self, number, project_id=None, instance_id=None, table_id=None):
    71      # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3.
    72      # super().__init__()
    73      beam.PTransform.__init__(self)
    74      self.number = number
    75      self.rand = random.choice(string.ascii_letters + string.digits)
    76      self.column_family_id = 'cf1'
    77      self.beam_options = {
    78          'project_id': project_id,
    79          'instance_id': instance_id,
    80          'table_id': table_id
    81      }
    82  
    83    def _generate(self):
    84      value = ''.join(self.rand for i in range(100))
    85  
    86      for index in range(self.number):
    87        key = "beam_key%s" % ('{0:07}'.format(index))
    88        direct_row = row.DirectRow(row_key=key)
    89        for column_id in range(10):
    90          direct_row.set_cell(
    91              self.column_family_id, ('field%s' % column_id).encode('utf-8'),
    92              value,
    93              datetime.datetime.now())
    94        yield direct_row
    95  
    96    def expand(self, pvalue):
    97      beam_options = self.beam_options
    98      return (
    99          pvalue
   100          | beam.Create(self._generate())
   101          | WriteToBigTable(
   102              beam_options['project_id'],
   103              beam_options['instance_id'],
   104              beam_options['table_id']))
   105  
   106  
   107  @unittest.skipIf(Client is None, 'GCP Bigtable dependencies are not installed')
   108  class BigtableIOWriteTest(unittest.TestCase):
   109    """ Bigtable Write Connector Test
   110  
   111    """
   112    DEFAULT_TABLE_PREFIX = "python-test"
   113    instance_id = DEFAULT_TABLE_PREFIX + "-" + str(uuid.uuid4())[:8]
   114    cluster_id = DEFAULT_TABLE_PREFIX + "-" + str(uuid.uuid4())[:8]
   115    table_id = DEFAULT_TABLE_PREFIX + "-" + str(uuid.uuid4())[:8]
   116    number = 500
   117    LOCATION_ID = "us-east1-b"
   118  
   119    def setUp(self):
   120      try:
   121        from google.cloud.bigtable import enums
   122        self.STORAGE_TYPE = enums.StorageType.HDD
   123        self.INSTANCE_TYPE = enums.Instance.Type.DEVELOPMENT
   124      except ImportError:
   125        self.STORAGE_TYPE = 2
   126        self.INSTANCE_TYPE = 2
   127  
   128      self.test_pipeline = TestPipeline(is_integration_test=True)
   129      self.runner_name = type(self.test_pipeline.runner).__name__
   130      self.project = self.test_pipeline.get_option('project')
   131      self.client = Client(project=self.project, admin=True)
   132  
   133      self._delete_old_instances()
   134  
   135      self.instance = self.client.instance(
   136          self.instance_id, instance_type=self.INSTANCE_TYPE, labels=LABELS)
   137  
   138      if not self.instance.exists():
   139        cluster = self.instance.cluster(
   140            self.cluster_id,
   141            self.LOCATION_ID,
   142            default_storage_type=self.STORAGE_TYPE)
   143        operation = self.instance.create(clusters=[cluster])
   144        operation.result(timeout=300)  # Wait up to 5 min.
   145  
   146      self.table = self.instance.table(self.table_id)
   147  
   148      if not self.table.exists():
   149        max_versions_rule = column_family.MaxVersionsGCRule(2)
   150        column_family_id = 'cf1'
   151        column_families = {column_family_id: max_versions_rule}
   152        self.table.create(column_families=column_families)
   153  
   154    def _delete_old_instances(self):
   155      instances = self.client.list_instances()
   156      EXISTING_INSTANCES[:] = instances
   157  
   158      def age_in_hours(micros):
   159        return (
   160            datetime.datetime.utcnow().replace(tzinfo=UTC) -
   161            (_datetime_from_microseconds(micros))).total_seconds() // 3600
   162  
   163      CLEAN_INSTANCE = [
   164          i for instance in EXISTING_INSTANCES for i in instance if (
   165              LABEL_KEY in i.labels.keys() and
   166              (age_in_hours(int(i.labels[LABEL_KEY])) >= 2))
   167      ]
   168  
   169      if CLEAN_INSTANCE:
   170        for instance in CLEAN_INSTANCE:
   171          instance.delete()
   172  
   173    def tearDown(self):
   174      if self.instance.exists():
   175        self.instance.delete()
   176  
   177    @pytest.mark.it_postcommit
   178    def test_bigtable_write(self):
   179      number = self.number
   180      pipeline_args = self.test_pipeline.options_list
   181      pipeline_options = PipelineOptions(pipeline_args)
   182  
   183      with beam.Pipeline(options=pipeline_options) as pipeline:
   184        config_data = {
   185            'project_id': self.project,
   186            'instance_id': self.instance_id,
   187            'table_id': self.table_id
   188        }
   189        _ = (
   190            pipeline
   191            | 'Generate Direct Rows' >> GenerateTestRows(number, **config_data))
   192  
   193      assert pipeline.result.state == PipelineState.DONE
   194  
   195      read_rows = self.table.read_rows()
   196      assert len([_ for _ in read_rows]) == number
   197  
   198      if not hasattr(pipeline.result, 'has_job') or pipeline.result.has_job:
   199        read_filter = MetricsFilter().with_name('Written Row')
   200        query_result = pipeline.result.metrics().query(read_filter)
   201        if query_result['counters']:
   202          read_counter = query_result['counters'][0]
   203  
   204          logging.info('Number of Rows: %d', read_counter.committed)
   205          assert read_counter.committed == number
   206  
   207  
   208  if __name__ == '__main__':
   209    logging.getLogger().setLevel(logging.INFO)
   210    unittest.main()