github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/experimental/spannerio_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import logging
import random
import string
import typing
import unittest

import mock

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# Protect against environments where spanner library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: disable=unused-import
try:
  from google.cloud import spanner
  from apache_beam.io.gcp.experimental.spannerio import create_transaction
  from apache_beam.io.gcp.experimental.spannerio import ReadOperation
  from apache_beam.io.gcp.experimental.spannerio import ReadFromSpanner
  from apache_beam.io.gcp.experimental.spannerio import WriteMutation
  from apache_beam.io.gcp.experimental.spannerio import MutationGroup
  from apache_beam.io.gcp.experimental.spannerio import WriteToSpanner
  from apache_beam.io.gcp.experimental.spannerio import _BatchFn
  from apache_beam.io.gcp import resource_identifiers
  from apache_beam.metrics import monitoring_infos
  from apache_beam.metrics.execution import MetricsEnvironment
  from apache_beam.metrics.metricbase import MetricName
except ImportError:
  spanner = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: enable=unused-import

MAX_DB_NAME_LENGTH = 30
TEST_PROJECT_ID = 'apache-beam-testing'
TEST_INSTANCE_ID = 'beam-test'
TEST_DATABASE_PREFIX = 'spanner-testdb-'
FAKE_TRANSACTION_INFO = {"session_id": "qwerty", "transaction_id": "qwerty"}
FAKE_ROWS = [[1, 'Alice'], [2, 'Bob'], [3, 'Carl'], [4, 'Dan'], [5, 'Evan'],
             [6, 'Floyd']]


def _generate_database_name():
  mask = string.ascii_lowercase + string.digits
  length = MAX_DB_NAME_LENGTH - 1 - len(TEST_DATABASE_PREFIX)
  return TEST_DATABASE_PREFIX + ''.join(
      random.choice(mask) for i in range(length))


def _generate_test_data():
  mask = string.ascii_lowercase + string.digits
  length = 100
  return [(
      'users', ['Key', 'Value'],
      [(x, ''.join(random.choice(mask) for _ in range(length)))
       for x in range(1, 5)])]


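# How the read tests below stay offline: spannerio.Client and
# spannerio.BatchSnapshot are both patched, so
# Client().instance().database().batch_snapshot() yields a mock whose
# generate_query_batches()/generate_read_batches() fabricate partitions,
# while BatchSnapshot.from_dict() yields a mock whose
# process_query_batch()/process_read_batch() replay slices of FAKE_ROWS.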
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.BatchSnapshot')
class SpannerReadTest(unittest.TestCase):
  def test_read_with_query_batch(
      self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_query_batches.return_value = [{
        'query': {
            'sql': 'SELECT * FROM users'
        }, 'partition': 'test_partition'
    } for _ in range(3)]
    mock_snapshot_instance.to_dict.return_value = {}

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_query_batch return results for three pipelines.
    mock_batch_snapshot_instance.process_query_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3
    mock_client_class.return_value.instance.return_value.database.return_value \
      .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
      = mock_batch_snapshot_instance

    ro = [ReadOperation.query("Select * from users")]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              sql="SELECT * FROM users"))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # three pipelines
    self.assertEqual(
        mock_snapshot_instance.generate_query_batches.call_count, 3)
    # three pipelines, each called three times
    self.assertEqual(
        mock_batch_snapshot_instance.process_query_batch.call_count, 3 * 3)

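  # The three pipelines above exercise the three ways a read can be
  # specified, a pattern every read test in this class repeats: sql= (or
  # table= plus columns=) in the constructor, a read_operations= list in the
  # constructor, or ReadOperation elements fed in as a PCollection, e.g.
  # (a sketch equivalent to the 'read all' case):
  #
  #   ops = [ReadOperation.query("SELECT * FROM users")]
  #   p | ReadFromSpanner(project_id, instance_id, db, read_operations=ops)
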
  def test_read_with_table_batch(
      self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_read_batches.return_value = [{
        'read': {
            'table': 'users',
            'keyset': {
                'all': True
            },
            'columns': ['Key', 'Value'],
            'index': ''
        },
        'partition': 'test_partition'
    } for _ in range(3)]
    mock_snapshot_instance.to_dict.return_value = {}

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_read_batch return results for three pipelines.
    mock_batch_snapshot_instance.process_read_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3

    mock_client_class.return_value.instance.return_value.database.return_value \
      .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
      = mock_batch_snapshot_instance

    ro = [ReadOperation.table("users", ["Key", "Value"])]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              table="users",
              columns=["Key", "Value"]))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # three pipelines
    self.assertEqual(mock_snapshot_instance.generate_read_batches.call_count, 3)
    # three pipelines, each called three times
    self.assertEqual(
        mock_batch_snapshot_instance.process_read_batch.call_count, 3 * 3)

    with TestPipeline() as pipeline, self.assertRaises(ValueError):
      # Test the exception raised at pipeline-construction time when the user
      # passes a table name without the required columns.
      _ = (
          pipeline | 'reads error' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              table="users"))

  def test_read_with_index(self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.generate_read_batches.return_value = [{
        'read': {
            'table': 'users',
            'keyset': {
                'all': True
            },
            'columns': ['Key', 'Value'],
            'index': ''
        },
        'partition': 'test_partition'
    } for _ in range(3)]

    mock_batch_snapshot_instance = mock.MagicMock()
    # Prepare process_read_batch return results for three pipelines.
    mock_batch_snapshot_instance.process_read_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ] * 3

    mock_snapshot_instance.to_dict.return_value = {}

    mock_client_class.return_value.instance.return_value.database.return_value \
      .batch_snapshot.return_value = mock_snapshot_instance
    mock_batch_snapshot_class.from_dict.return_value \
      = mock_batch_snapshot_instance

    ro = [ReadOperation.table("users", ["Key", "Value"], index="Key")]
    with TestPipeline() as pipeline:
      read = (
          pipeline
          | 'read' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              table="users",
              columns=["Key", "Value"]))
      assert_that(read, equal_to(FAKE_ROWS), label='checkRead')

    with TestPipeline() as pipeline:
      readall = (
          pipeline
          | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              read_operations=ro))
      assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')

    with TestPipeline() as pipeline:
      readpipeline = (
          pipeline
          | 'create reads' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
      assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # three pipelines
    self.assertEqual(mock_snapshot_instance.generate_read_batches.call_count, 3)
    # three pipelines, each called three times
    self.assertEqual(
        mock_batch_snapshot_instance.process_read_batch.call_count, 3 * 3)

    with TestPipeline() as pipeline, self.assertRaises(ValueError):
      # Test the exception raised at pipeline-construction time when the user
      # passes a table name without the required columns.
      _ = (
          pipeline | 'reads error' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              table="users"))

  def test_read_with_transaction(
      self, mock_batch_snapshot_class, mock_client_class):
    mock_snapshot_instance = mock.MagicMock()
    mock_snapshot_instance.to_dict.return_value = FAKE_TRANSACTION_INFO

    mock_transaction_instance = mock.MagicMock()
    mock_transaction_instance.execute_sql.return_value = FAKE_ROWS
    mock_transaction_instance.read.return_value = FAKE_ROWS

    mock_client_class.return_value.instance.return_value.database.return_value \
      .batch_snapshot.return_value = mock_snapshot_instance
    mock_client_class.return_value.instance.return_value.database.return_value \
      .session.return_value.transaction.return_value.__enter__.return_value \
      = mock_transaction_instance

    ro = [ReadOperation.query("Select * from users")]

    with TestPipeline() as p:
      transaction = (
          p | create_transaction(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              exact_staleness=datetime.timedelta(seconds=10)))

      read_query = (
          p | 'with query' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              sql="Select * from users"))
      assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')

      read_table = (
          p | 'with table' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              table="users",
              columns=["Key", "Value"]))
      assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')

      read_indexed_table = (
          p | 'with index' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              table="users",
              index="Key",
              columns=["Key", "Value"]))
      assert_that(
          read_indexed_table, equal_to(FAKE_ROWS), label='checkTableIndex')

      read = (
          p | 'read all' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction,
              read_operations=ro))
      assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')

      read_pipeline = (
          p
          | 'create read operations' >> beam.Create(ro)
          | 'reads' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction))
      assert_that(read_pipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    # The transaction is set up once.
    self.assertEqual(mock_snapshot_instance.to_dict.call_count, 1)
    # Three of the reads went through execute_sql.
    self.assertEqual(mock_transaction_instance.execute_sql.call_count, 3)
    # Two of the reads went through read.
    self.assertEqual(mock_transaction_instance.read.call_count, 2)

    with TestPipeline() as p, self.assertRaises(ValueError):
      # Test the exception raised at pipeline-construction time when the user
      # passes the read operations in the constructor and also in the
      # pipeline.
      transaction = (
          p | create_transaction(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              exact_staleness=datetime.timedelta(seconds=10)))
      _ = (
          p
          | 'create read operations2' >> beam.Create(ro)
          | 'reads with error' >> ReadFromSpanner(
              TEST_PROJECT_ID,
              TEST_INSTANCE_ID,
              _generate_database_name(),
              transaction=transaction,
              read_operations=ro))

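  # Note on the transaction path above: with transaction= supplied, all five
  # reads are served by the mocked transaction itself (three via
  # execute_sql(), two via read()), and the {session_id, transaction_id}
  # info from to_dict() is produced only once and shared, which is exactly
  # what the three call-count assertions pin down.
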
  def test_invalid_transaction(
      self, mock_batch_snapshot_class, mock_client_class):
    # Test the exception raised at pipeline-execution time when the
    # transaction side input is malformed.
    with self.assertRaises(ValueError), TestPipeline() as p:
      transaction = (
          p | beam.Create([{
              "invalid": "transaction"
          }]).with_output_types(typing.Any))
      _ = (
          p | 'with query' >> ReadFromSpanner(
              project_id=TEST_PROJECT_ID,
              instance_id=TEST_INSTANCE_ID,
              database_id=_generate_database_name(),
              transaction=transaction,
              sql="Select * from users"))

  def test_display_data(self, *args):
    dd_sql = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        sql="Select * from users").display_data()

    dd_table = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name']).display_data()

    dd_transaction = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name'],
        transaction={
            "transaction_id": "test123", "session_id": "test456"
        }).display_data()

    self.assertTrue("sql" in dd_sql)
    self.assertTrue("table" in dd_table)
    self.assertTrue("table" in dd_transaction)
    self.assertTrue("transaction" in dd_transaction)


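# The write tests below patch spannerio.Client plus
# google.cloud.spanner_v1.database.BatchCheckout, presumably so that
# database.batch() commits never reach a real backend; the assertions then
# rely on the transform's 'SpannerBatches' counter to check how many commit
# batches were formed.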
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('google.cloud.spanner_v1.database.BatchCheckout')
class SpannerWriteTest(unittest.TestCase):
  # Note: mock.patch applies bottom-up, so each test receives the
  # BatchCheckout mock first and the Client mock second.
  def test_spanner_write(self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])

    mutations = [
        WriteMutation.delete("roles", ks),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1233', "mutations-inset-1233")]),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")]),
        WriteMutation.update(
            "roles", ("key", "rolename"),
            [('1234', "mutations-inset-1233-updated")]),
    ]

    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutations)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=1024))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name("SpannerBatches"))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 2)
    self.assertEqual(batches_counter.attempted, 2)

  def test_spanner_bundles_size(
      self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])
    mutations = [
        WriteMutation.delete("roles", ks),
        WriteMutation.insert(
            "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")])
    ] * 50
    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutations)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=1024))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name('SpannerBatches'))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 53)
    self.assertEqual(batches_counter.attempted, 53)

  def test_spanner_write_mutation_groups(
      self, mock_batch_checkout, mock_client_class):
    ks = spanner.KeySet(keys=[[1233], [1234]])
    mutation_groups = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [('9001233', "mutations-inset-1233")]),
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [('9001234', "mutations-inset-1234")])
        ]),
        MutationGroup([
            WriteMutation.update(
                "roles", ("key", "rolename"),
                [('9001234', "mutations-inset-9001233-updated")])
        ]),
        MutationGroup([WriteMutation.delete("roles", ks)])
    ]

    p = TestPipeline()
    _ = (
        p
        | beam.Create(mutation_groups)
        | WriteToSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            max_batch_size_bytes=100))
    res = p.run()
    res.wait_until_finish()

    metric_results = res.metrics().query(
        MetricsFilter().with_name('SpannerBatches'))
    batches_counter = metric_results['counters'][0]

    self.assertEqual(batches_counter.committed, 3)
    self.assertEqual(batches_counter.attempted, 3)

  def test_batch_byte_size(self, mock_batch_checkout, mock_client_class):
    # Each mutation group is 58 bytes.
    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles",
                ("key", "rolename"), [('1234', "mutations-inset-1234")])
        ])
    ] * 50

    with TestPipeline() as p:
      # The 50 mutation groups total 2900 bytes (58 * 50). To split them into
      # two batches, the batch size should be 1450 bytes (2900 / 2), and each
      # batch should then contain 25 mutation groups.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1450,
                  max_number_rows=50,
                  max_number_cells=500))
          | beam.Map(lambda x: len(x)))
      assert_that(res, equal_to([25] * 2))

  def test_batch_disable(self, mock_batch_checkout, mock_client_class):
    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles",
                ("key", "rolename"), [('1234', "mutations-inset-1234")])
        ])
    ] * 4

    with TestPipeline() as p:
      # To disable batching, set any of the batching parameters to a low
      # value or zero.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1450,
                  max_number_rows=0,
                  max_number_cells=500))
          | beam.Map(lambda x: len(x)))
      assert_that(res, equal_to([1] * 4))

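  # A batch appears to be flushed as soon as adding the next group would
  # cross any of the three limits (bytes, rows, cells), which is presumably
  # why max_number_rows=0 above forces every group into its own singleton
  # batch.
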
  def test_batch_max_rows(self, mock_batch_checkout, mock_client_class):
    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [
                    ('1234', "mutations-inset-1234"),
                    ('1235', "mutations-inset-1235"),
                ])
        ])
    ] * 50

    with TestPipeline() as p:
      # There are 50 mutation groups in total, each containing two rows, so
      # there are 100 rows (50 * 2) overall. With at most 10 rows per batch,
      # there should be 10 batches of 5 mutation groups each.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1048576,
                  max_number_rows=10,
                  max_number_cells=500))
          | beam.Map(lambda x: len(x)))
      assert_that(res, equal_to([5] * 10))

  def test_batch_max_cells(self, mock_batch_checkout, mock_client_class):
    mutation_group = [
        MutationGroup([
            WriteMutation.insert(
                "roles", ("key", "rolename"),
                [
                    ('1234', "mutations-inset-1234"),
                    ('1235', "mutations-inset-1235"),
                ])
        ])
    ] * 50

    with TestPipeline() as p:
      # There are 50 mutation groups in total, each containing two rows of
      # two columns, i.e. 4 cells, so there are 200 cells (50 * 4) overall.
      # With at most 50 cells per batch, a batch holds 12 mutation groups
      # (max cells / cells per group = 50 / 4, rounded down), so the 50
      # groups split into four batches of 12 and a fifth batch with the
      # remaining 2.
      res = (
          p | beam.Create(mutation_group)
          | beam.ParDo(
              _BatchFn(
                  max_batch_size_bytes=1048576,
                  max_number_rows=500,
                  max_number_cells=50))
          | beam.Map(lambda x: len(x)))
      assert_that(res, equal_to([12, 12, 12, 12, 2]))

  def test_write_mutation_error(self, *args):
    with self.assertRaises(ValueError):
      # `WriteMutation` accepts only one operation.
      WriteMutation(insert="table-name", update="table-name")

  def test_display_data(self, *args):
    data = WriteToSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        max_batch_size_bytes=1024).display_data()
    self.assertTrue("project_id" in data)
    self.assertTrue("instance_id" in data)
    self.assertTrue("pool" in data)
    self.assertTrue("database" in data)
    self.assertTrue("batch_size" in data)
    self.assertTrue("max_number_rows" in data)
    self.assertTrue("max_number_cells" in data)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()