github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/experimental/spannerio_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import logging
import random
import string
import typing
import unittest

import mock

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# Protect against environments where the spanner library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: disable=unused-import
try:
  from google.cloud import spanner
  from apache_beam.io.gcp.experimental.spannerio import create_transaction
  from apache_beam.io.gcp.experimental.spannerio import ReadOperation
  from apache_beam.io.gcp.experimental.spannerio import ReadFromSpanner
  from apache_beam.io.gcp.experimental.spannerio import WriteMutation
  from apache_beam.io.gcp.experimental.spannerio import MutationGroup
  from apache_beam.io.gcp.experimental.spannerio import WriteToSpanner
  from apache_beam.io.gcp.experimental.spannerio import _BatchFn
  from apache_beam.io.gcp import resource_identifiers
  from apache_beam.metrics import monitoring_infos
  from apache_beam.metrics.execution import MetricsEnvironment
  from apache_beam.metrics.metricbase import MetricName
except ImportError:
  spanner = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: enable=unused-import

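# Constants for the fake Spanner resources used throughout these tests. The
# Spanner Client is mocked in every test case, so no real project, instance,
# or database is ever contacted.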
MAX_DB_NAME_LENGTH = 30
TEST_PROJECT_ID = 'apache-beam-testing'
TEST_INSTANCE_ID = 'beam-test'
TEST_DATABASE_PREFIX = 'spanner-testdb-'
FAKE_TRANSACTION_INFO = {"session_id": "qwerty", "transaction_id": "qwerty"}
FAKE_ROWS = [[1, 'Alice'], [2, 'Bob'], [3, 'Carl'], [4, 'Dan'], [5, 'Evan'],
             [6, 'Floyd']]


def _generate_database_name():
  mask = string.ascii_lowercase + string.digits
  length = MAX_DB_NAME_LENGTH - 1 - len(TEST_DATABASE_PREFIX)
  return TEST_DATABASE_PREFIX + ''.join(
      random.choice(mask) for _ in range(length))


def _generate_test_data():
  mask = string.ascii_lowercase + string.digits
  length = 100
  return [(
      'users', ['Key', 'Value'],
      [(x, ''.join(random.choice(mask) for _ in range(length)))
       for x in range(1, 5)])]


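# Client and BatchSnapshot are patched at the class level so no test touches a
# real Spanner backend. Note that mock.patch decorators are applied bottom-up,
# so the BatchSnapshot mock is injected first into each test method's
# signature, followed by the Client mock.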
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.BatchSnapshot')
class SpannerReadTest(unittest.TestCase):
  def test_read_with_query_batch(
      self, mock_batch_snapshot_class, mock_client_class):

    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_query_batches.return_value = [{
        'query': {
            'sql': 'SELECT * FROM users'
        }, 'partition': 'test_partition'
    } for _ in range(3)]
    mock_snapshot_instance.to_dict.return_value = {}

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_query_batch return values for three pipelines, each of
    # which consumes three batches.
    mock_batch_snapshot_instance.process_query_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3
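    # Wire the patched Client so instance().database().batch_snapshot()
    # returns the snapshot mock, and patch BatchSnapshot.from_dict() to return
    # the mock whose process_query_batch() supplies the fake rows.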
    mock_client_class.return_value.instance.return_value.database.return_value \
        .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
        = mock_batch_snapshot_instance

    ro = [ReadOperation.query("Select * from users")]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              sql="SELECT * FROM users"))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # Three pipelines ran, so query batches were generated three times.
    self.assertEqual(
        mock_snapshot_instance.generate_query_batches.call_count, 3)
    # Three pipelines, each processing three batches.
    self.assertEqual(
        mock_batch_snapshot_instance.process_query_batch.call_count, 3 * 3)

  def test_read_with_table_batch(
      self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_read_batches.return_value = [{
        'read': {
            'table': 'users',
            'keyset': {
                'all': True
            },
            'columns': ['Key', 'Value'],
            'index': ''
        },
        'partition': 'test_partition'
    } for _ in range(3)]
    mock_snapshot_instance.to_dict.return_value = {}

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_read_batch return values for three pipelines, each of
    # which consumes three batches.
    mock_batch_snapshot_instance.process_read_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3

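    # As in the query test: route batch_snapshot() to the snapshot mock and
    # BatchSnapshot.from_dict() to the mock that supplies the fake rows.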
    mock_client_class.return_value.instance.return_value.database.return_value \
        .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
        = mock_batch_snapshot_instance

    ro = [ReadOperation.table("users", ["Key", "Value"])]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              table="users",
              columns=["Key", "Value"]))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # Three pipelines ran, so read batches were generated three times.
    self.assertEqual(mock_snapshot_instance.generate_read_batches.call_count, 3)
    # Three pipelines, each processing three batches.
    self.assertEqual(
        mock_batch_snapshot_instance.process_read_batch.call_count, 3 * 3)

    with TestPipeline() as pipeline, self.assertRaises(ValueError):
      # Test the exception raised at pipeline construction time when the user
      # passes a table without specifying its columns.
      _ = (
          pipeline | 'reads error' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              table="users"))

  def test_read_with_index(self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_read_batches.return_value = [{
        'read': {
            'table': 'users',
            'keyset': {
                'all': True
            },
            'columns': ['Key', 'Value'],
            'index': ''
        },
        'partition': 'test_partition'
    } for _ in range(3)]

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_read_batch return values for three pipelines, each of
    # which consumes three batches.
    mock_batch_snapshot_instance.process_read_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3

    mock_snapshot_instance.to_dict.return_value = {}

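    # Same wiring as the other read tests: the patched Client hands back the
    # snapshot mock, and from_dict() hands back the fake batch processor.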
    mock_client_class.return_value.instance.return_value.database.return_value \
        .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
        = mock_batch_snapshot_instance

    ro = [ReadOperation.table("users", ["Key", "Value"], index="Key")]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              table="users",
              columns=["Key", "Value"]))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # Three pipelines ran, so read batches were generated three times.
    self.assertEqual(mock_snapshot_instance.generate_read_batches.call_count, 3)
    # Three pipelines, each processing three batches.
    self.assertEqual(
        mock_batch_snapshot_instance.process_read_batch.call_count, 3 * 3)

    with TestPipeline() as pipeline, self.assertRaises(ValueError):
      # Test the exception raised at pipeline construction time when the user
      # passes a table without specifying its columns.
      _ = (
          pipeline | 'reads error' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              table="users"))

  def test_read_with_transaction(
      self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.to_dict.return_value = FAKE_TRANSACTION_INFO

    mock_transaction_instance = mock.MagicMock()
    mock_transaction_instance.execute_sql.return_value = FAKE_ROWS
    mock_transaction_instance.read.return_value = FAKE_ROWS

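    # Route database.batch_snapshot() to the snapshot mock (whose to_dict()
    # supplies the fake transaction info for create_transaction), and the
    # session's transaction context manager to the transaction mock that
    # serves the fake rows.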
    mock_client_class.return_value.instance.return_value.database.return_value \
        .batch_snapshot.return_value = mock_snapshot_instance
    mock_client_class.return_value.instance.return_value.database.return_value \
        .session.return_value.transaction.return_value.__enter__.return_value \
            = mock_transaction_instance

    ro = [ReadOperation.query("Select * from users")]

    with TestPipeline() as p:
      transaction = (
          p | create_transaction(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              exact_staleness=datetime.timedelta(seconds=10)))

      read_query = (
          p | 'with query' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              sql="Select * from users"))
      assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')

      read_table = (
          p | 'with table' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              table="users",
              columns=["Key", "Value"]))
      assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')

      read_indexed_table = (
          p | 'with index' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              table="users",
              index="Key",
              columns=["Key", "Value"]))
      assert_that(
          read_indexed_table, equal_to(FAKE_ROWS), label='checkTableIndex')

      read = (
          p | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction,
              read_operations=ro))
      assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')

      read_pipeline = (
          p
          | 'create read operations' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction))
      assert_that(read_pipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # The transaction is set up exactly once.
    self.assertEqual(mock_snapshot_instance.to_dict.call_count, 1)
    # Three of the five reads are query-based and call execute_sql.
    self.assertEqual(mock_transaction_instance.execute_sql.call_count, 3)
    # The remaining two reads are table-based and call read.
    self.assertEqual(mock_transaction_instance.read.call_count, 2)

    with TestPipeline() as p, self.assertRaises(ValueError):
      # Test the exception raised at pipeline construction time when the user
      # passes read operations both to the constructor and in the pipeline.
      transaction = (
          p | create_transaction(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              exact_staleness=datetime.timedelta(seconds=10)))
      _ = (
          p
          | 'create read operations2' >> beam.Create(ro)
          | 'reads with error' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction,
              read_operations=ro))

  def test_invalid_transaction(
      self, mock_batch_snapshot_class, mock_client_class):
    # Test that the exception is raised at pipeline execution time when the
    # transaction side input is not a valid transaction mapping.
    with self.assertRaises(ValueError), TestPipeline() as p:
      transaction = (
          p | beam.Create([{
              "invalid": "transaction"
          }]).with_output_types(typing.Any))
      _ = (
          p | 'with query' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              sql="Select * from users"))

  def test_display_data(self, *args):
    dd_sql = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        sql="Select * from users").display_data()

    dd_table = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name']).display_data()

    dd_transaction = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name'],
        transaction={
            "transaction_id": "test123", "session_id": "test456"
        }).display_data()

    self.assertIn("sql", dd_sql)
    self.assertIn("table", dd_table)
    self.assertIn("table", dd_transaction)
    self.assertIn("transaction", dd_transaction)


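# Client is patched so no real Spanner service is reached, and BatchCheckout
# (the context manager returned by database.batch()) is patched so commits are
# mocked. Since mock.patch decorators are applied bottom-up, the BatchCheckout
# mock is the first argument injected into each test method, followed by the
# Client mock.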
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('google.cloud.spanner_v1.database.BatchCheckout')
class SpannerWriteTest(unittest.TestCase):
  def test_spanner_write(self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])

    mutations = [
        WriteMutation.delete("roles", ks),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1233', "mutations-inset-1233")]),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")]),
        WriteMutation.update(
            "roles", ("key", "rolename"),
            [('1234', "mutations-inset-1233-updated")]),
    ]

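    # With a 1 KB batch-size cap, these four mutations are expected to end up
    # in two commit batches, which is what the SpannerBatches counter
    # assertions below verify.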
    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutations)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=1024))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name("SpannerBatches"))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 2)
    self.assertEqual(batches_counter.attempted, 2)

  def test_spanner_bundles_size(
      self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])
    mutations = [
        WriteMutation.delete("roles", ks),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")])
    ] * 50
    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutations)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=1024))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name('SpannerBatches'))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 53)
    self.assertEqual(batches_counter.attempted, 53)

  def test_spanner_write_mutation_groups(
      self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])
    mutation_groups = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [('9001233', "mutations-inset-1233")]),
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [('9001234', "mutations-inset-1234")])
        ]),
        MutationGroup([
            WriteMutation.update(
                "roles", ("key", "rolename"),
                [('9001234', "mutations-inset-9001233-updated")])
        ]),
        MutationGroup([WriteMutation.delete("roles", ks)])
    ]

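    # With max_batch_size_bytes=100, each mutation group is expected to land
    # in its own batch, giving the three batches asserted below.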
    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutation_groups)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=100))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name('SpannerBatches'))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 3)
    self.assertEqual(batches_counter.attempted, 3)

  def test_batch_byte_size(
      self, mock_batch_checkout, mock_client_class):

    # Each mutation group is 58 bytes.
    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles",
                ("key", "rolename"), [('1234', "mutations-inset-1234")])
        ])
    ] * 50

    with TestPipeline() as p:
      # The 50 mutation groups total 2900 bytes (58 * 50). To split them into
      # two batches, the batch size is set to 1450 bytes (2900 / 2), so each
      # batch should contain 25 mutation groups.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1450,
                  max_number_rows=50,
                  max_number_cells=500))
          | beam.Map(len))
      assert_that(res, equal_to([25] * 2))

  def test_batch_disable(self, mock_batch_checkout, mock_client_class):

    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles",
                ("key", "rolename"), [('1234', "mutations-inset-1234")])
        ])
    ] * 4

    with TestPipeline() as p:
      # To disable batching, set any of the batching parameters to a low
      # value or zero; here max_number_rows=0 forces each mutation group into
      # its own batch.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1450,
                  max_number_rows=0,
                  max_number_cells=500))
          | beam.Map(len))
      assert_that(res, equal_to([1] * 4))

  def test_batch_max_rows(self, mock_batch_checkout, mock_client_class):

    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [
                    ('1234', "mutations-inset-1234"),
                    ('1235', "mutations-inset-1235"),
                ])
        ])
    ] * 50

    with TestPipeline() as p:
      # There are 50 mutation groups in total, each containing two rows, for
      # 100 rows overall (50 * 2). With at most 10 rows per batch, there
      # should be 10 batches of 5 mutation groups each.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1048576,
                  max_number_rows=10,
                  max_number_cells=500))
          | beam.Map(len))
      assert_that(res, equal_to([5] * 10))

  def test_batch_max_cells(
      self, mock_batch_checkout, mock_client_class):

    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [
                    ('1234', "mutations-inset-1234"),
                    ('1235', "mutations-inset-1235"),
                ])
        ])
    ] * 50

    with TestPipeline() as p:
      # There are 50 mutation groups in total, each containing two rows of
      # two columns, i.e. 4 cells per group and 200 cells overall (50 * 4).
      # With at most 50 cells per batch, each batch holds 12 mutation groups
      # (50 // 4 = 12, i.e. 48 cells), so the 50 groups split into four
      # batches of 12 plus a fifth batch with the remaining 2 groups.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1048576,
                  max_number_rows=500,
                  max_number_cells=50))
          | beam.Map(len))
      assert_that(res, equal_to([12, 12, 12, 12, 2]))

  def test_write_mutation_error(self, *args):
    with self.assertRaises(ValueError):
      # `WriteMutation` accepts exactly one operation, so passing two raises.
      WriteMutation(insert="table-name", update="table-name")

  def test_display_data(self, *args):
    data = WriteToSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        max_batch_size_bytes=1024).display_data()
    self.assertIn("project_id", data)
    self.assertIn("instance_id", data)
    self.assertIn("pool", data)
    self.assertIn("database", data)
    self.assertIn("batch_size", data)
    self.assertIn("max_number_rows", data)
    self.assertIn("max_number_cells", data)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()