github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/experimental/spannerio.py (about)
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one or more
3 # contributor license agreements. See the NOTICE file distributed with
4 # this work for additional information regarding copyright ownership.
5 # The ASF licenses this file to You under the Apache License, Version 2.0
6 # (the "License"); you may not use this file except in compliance with
7 # the License. You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 #
17
18 """Google Cloud Spanner IO
19
20 Experimental; no backwards-compatibility guarantees.
21
22 This is an experimental module for reading and writing data from Google Cloud
23 Spanner. Visit: https://cloud.google.com/spanner for more details.
24
25 Reading Data from Cloud Spanner.
26
27 To read from Cloud Spanner, apply the ReadFromSpanner transform. It will
28 return a PCollection, where each element represents an individual row returned
29 from the read operation. Both Query and Read APIs are supported.
30
31 ReadFromSpanner relies on the ReadOperation objects exposed by the
32 SpannerIO API. A ReadOperation holds the immutable data needed to
33 execute batch and naive reads on Cloud Spanner. This is done for more
34 convenient programming.
35
36 ReadFromSpanner reads from Cloud Spanner given either an 'sql' param
37 in the constructor or a 'table' name with 'columns' as a list. For example:::
38
39   records = (pipeline
40              | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
41                                sql='Select * from users'))
42
43   records = (pipeline
44              | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
45                                table='users', columns=['id', 'name', 'email']))
46
47 You can also perform multiple reads by providing a list of ReadOperations
48 to the ReadFromSpanner transform constructor. ReadOperation exposes two static
49 methods: use 'query' to perform SQL-based reads and 'table' to read from a
50 table. For example:::
51
52   read_operations = [
53       ReadOperation.table(
54           table='customers', columns=['name', 'email']),
55       ReadOperation.table(
56           table='vendors', columns=['name', 'email']),
57   ]
58   all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
59                                          read_operations=read_operations)
60
61   ...OR...
62
63   read_operations = [
64       ReadOperation.query(
65           sql='Select name, email from customers'),
66       ReadOperation.query(
67           sql='Select * from users where id <= @user_id',
68           params={'user_id': 100},
69           param_types={'user_id': param_types.INT64}
70       ),
71   ]
72   # `param_types` values are instances of `google.cloud.spanner.param_types`
73   all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
74                                          read_operations=read_operations)
75
76 For more information, please review the docs on class ReadOperation.
77
78 Users can also provide the ReadOperations in the form of a PCollection within
79 the pipeline. For example:::
80
81   users = (pipeline
82            | beam.Create([ReadOperation...])
83            | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME))
84
85 Users may also create a Cloud Spanner transaction using the transform called
86 `create_transaction`, which is available in the SpannerIO API.
87
88 The transform is guaranteed to be executed on a consistent snapshot of data,
89 utilizing the power of read-only transactions. Staleness of data can be
90 controlled by providing the `read_timestamp` or `exact_staleness` param values
91 in the constructor.
92
93 This transform requires the root of the pipeline (PBegin) and returns a
94 PTransform which is later passed to the `ReadFromSpanner` constructor.
95 `ReadFromSpanner` passes this transaction PTransform as a singleton side input
96 to the `_NaiveSpannerReadDoFn` containing 'session_id' and 'transaction_id'.
97 For example:::
98
99   transaction = (pipeline | create_transaction(TEST_PROJECT_ID,
100                                                TEST_INSTANCE_ID,
101                                                DB_NAME))
102
103   users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
104         sql='Select * from users', transaction=transaction)
105
106   tweets = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
107         sql='Select * from tweets', transaction=transaction)
108
109 For further details of this transform, please review the docs on the
110 :meth:`create_transaction` method available in the SpannerIO API.
111
112 ReadFromSpanner takes this transform in the constructor and passes it to the
113 read pipeline as a singleton side input.
114
115 Writing Data to Cloud Spanner.
116
117 The WriteToSpanner transform writes to Cloud Spanner by executing a
118 collection of input rows (WriteMutation). The mutations are grouped into
119 batches for efficiency.
120
121 The WriteToSpanner transform relies on the WriteMutation objects exposed
122 by the SpannerIO API. WriteMutation has five static methods (insert, update,
123 insert_or_update, replace, delete). These methods return an instance of the
124 _Mutator object which contains the mutation type and the Spanner Mutation
125 object. For more details, review the docs of the class SpannerIO.WriteMutation.
126 For example:::
127
128   mutations = [
129       WriteMutation.insert(table='user', columns=('name', 'email'),
130                            values=[('sara', 'sara@dev.com')])
131   ]
132   _ = (p
133        | beam.Create(mutations)
134        | WriteToSpanner(
135            project_id=SPANNER_PROJECT_ID,
136            instance_id=SPANNER_INSTANCE_ID,
137            database_id=SPANNER_DATABASE_NAME)
138       )
139
140 You can also create a WriteMutation by calling its constructor. For example:::
141
142   mutations = [
143       WriteMutation(insert='users', columns=('name', 'email'),
144                     values=[('sara', 'sara@example.com')])
145   ]
146
147 For more information, review the docs available on the WriteMutation class.
148
149 The WriteToSpanner transform also takes three batching parameters
150 (max_number_rows, max_number_cells and max_batch_size_bytes). By default,
151 max_number_rows is set to 50 rows, max_number_cells is set to 500 cells and
152 max_batch_size_bytes is set to 1MB (1048576 bytes). These parameters are used
153 to reduce the number of transactions sent to Spanner by grouping the mutations
154 into batches. Setting these params to smaller values, or to zero, reduces or
155 disables batching.
156 Unlike the Java connector, this connector does not create batches of
157 transactions sorted by table and primary key.
158
159 The WriteToSpanner transform starts with the grouping into batches. The first
160 step in this process is to form mutation groups from the WriteMutation
161 objects and then filter them into batchable and unbatchable mutation
162 groups, based on the three batching parameters (max_number_cells,
max_number_rows & max_batch_size_bytes). The mutation byte size is calculated
163 using the method available on
164 `google.cloud.spanner_v1.proto.mutation_pb2.Mutation.ByteSize`. If a
165 mutation's rows, cells or byte size exceed the value of any of the batching
166 parameters, it is tagged as an "unbatchable" mutation. After this, the
167 batchable mutations are merged into mutation groups no larger than
168 "max_batch_size_bytes"; the batches and the unbatchable mutation groups are
169 then flattened together for writing. If a Mutation references a table or
170 column that does not exist, it causes an exception and fails the entire
171 pipeline.
172 """
173 import typing
174 from collections import deque
175 from collections import namedtuple
176
177 from apache_beam import Create
178 from apache_beam import DoFn
179 from apache_beam import Flatten
180 from apache_beam import ParDo
181 from apache_beam import Reshuffle
182 from apache_beam.internal.metrics.metric import ServiceCallMetric
183 from apache_beam.io.gcp import resource_identifiers
184 from apache_beam.metrics import Metrics
185 from apache_beam.metrics import monitoring_infos
186 from apache_beam.pvalue import AsSingleton
187 from apache_beam.pvalue import PBegin
188 from apache_beam.pvalue import TaggedOutput
189 from apache_beam.transforms import PTransform
190 from apache_beam.transforms import ptransform_fn
191 from apache_beam.transforms import window
192 from apache_beam.transforms.display import DisplayDataItem
193 from apache_beam.typehints import with_input_types
194 from apache_beam.typehints import with_output_types
195
196 # Protect against environments where spanner library is not available.
197 # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
198 # pylint: disable=unused-import
199 try:
200   from google.cloud.spanner import Client
201   from google.cloud.spanner import KeySet
202   from google.cloud.spanner_v1 import batch
203   from google.cloud.spanner_v1.database import BatchSnapshot
204   from google.api_core.exceptions import ClientError, GoogleAPICallError
205   from apitools.base.py.exceptions import HttpError
206 except ImportError:
207   Client = None
208   KeySet = None
209   BatchSnapshot = None
210
211 try:
212   from google.cloud.spanner_v1 import Mutation
213 except ImportError:
214   try:
215     # Remove this and the try clause when we upgrade to google-cloud-spanner
216     # 3.x.x.
217     from google.cloud.spanner_v1.proto.mutation_pb2 import Mutation
218   except ImportError:
219     # Ignoring for environments where the Spanner library is not available.
220     pass
221
222 __all__ = [
223     'create_transaction',
224     'ReadFromSpanner',
225     'ReadOperation',
226     'WriteToSpanner',
227     'WriteMutation',
228     'MutationGroup'
229 ]
230
231
232 class _SPANNER_TRANSACTION(namedtuple("SPANNER_TRANSACTION", ["transaction"])):
233   """
234   Holds the spanner transaction details.
235   """
236
237   __slots__ = ()
238
239
240 class ReadOperation(namedtuple(
241     "ReadOperation", ["is_sql", "is_table", "read_operation", "kwargs"])):
242   """
243   Encapsulates a spanner read operation.
244   """
245
246   __slots__ = ()
247
248   @classmethod
249   def query(cls, sql, params=None, param_types=None):
250     """
251     A convenience method to construct a ReadOperation from a SQL query.
252
253     Args:
254       sql: SQL query statement.
255       params: (optional) values for parameter replacement. Keys must match the
256         names used in sql.
257       param_types: (optional) maps explicit types for one or more param values;
258         required if parameters are passed.
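    Example (an illustrative sketch; the table, parameter name and value are
    placeholders, and `param_types` refers to `google.cloud.spanner.param_types`):::

      ReadOperation.query(
          sql='SELECT * FROM users WHERE id <= @user_id',
          params={'user_id': 100},
          param_types={'user_id': param_types.INT64})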
257 """ 258 259 if params: 260 assert param_types is not None 261 262 return cls( 263 is_sql=True, 264 is_table=False, 265 read_operation="process_query_batch", 266 kwargs={ 267 'sql': sql, 'params': params, 'param_types': param_types 268 }) 269 270 @classmethod 271 def table(cls, table, columns, index="", keyset=None): 272 """ 273 A convenient method to construct ReadOperation from table. 274 275 Args: 276 table: name of the table from which to fetch data. 277 columns: names of columns to be retrieved. 278 index: (optional) name of index to use, rather than the table's primary 279 key. 280 keyset: (optional) `KeySet` keys / ranges identifying rows to be 281 retrieved. 282 """ 283 keyset = keyset or KeySet(all_=True) 284 if not isinstance(keyset, KeySet): 285 raise ValueError( 286 "keyset must be an instance of class " 287 "google.cloud.spanner.KeySet") 288 return cls( 289 is_sql=False, 290 is_table=True, 291 read_operation="process_read_batch", 292 kwargs={ 293 'table': table, 294 'columns': columns, 295 'index': index, 296 'keyset': keyset 297 }) 298 299 300 class _BeamSpannerConfiguration(namedtuple("_BeamSpannerConfiguration", 301 ["project", 302 "instance", 303 "database", 304 "table", 305 "query_name", 306 "credentials", 307 "pool", 308 "snapshot_read_timestamp", 309 "snapshot_exact_staleness"])): 310 """ 311 A namedtuple holds the immutable data of the connection string to the cloud 312 spanner. 313 """ 314 @property 315 def snapshot_options(self): 316 snapshot_options = {} 317 if self.snapshot_exact_staleness: 318 snapshot_options['exact_staleness'] = self.snapshot_exact_staleness 319 if self.snapshot_read_timestamp: 320 snapshot_options['read_timestamp'] = self.snapshot_read_timestamp 321 return snapshot_options 322 323 324 @with_input_types(ReadOperation, _SPANNER_TRANSACTION) 325 @with_output_types(typing.List[typing.Any]) 326 class _NaiveSpannerReadDoFn(DoFn): 327 def __init__(self, spanner_configuration): 328 """ 329 A naive version of Spanner read which uses the transaction API of the 330 cloud spanner. 331 https://googleapis.dev/python/spanner/latest/transaction-api.html 332 In Naive reads, this transform performs single reads, where as the 333 Batch reads use the spanner partitioning query to create batches. 334 335 Args: 336 spanner_configuration: (_BeamSpannerConfiguration) Connection details to 337 connect with cloud spanner. 
338 """ 339 self._spanner_configuration = spanner_configuration 340 self._snapshot = None 341 self._session = None 342 self.base_labels = { 343 monitoring_infos.SERVICE_LABEL: 'Spanner', 344 monitoring_infos.METHOD_LABEL: 'Read', 345 monitoring_infos.SPANNER_PROJECT_ID: ( 346 self._spanner_configuration.project), 347 monitoring_infos.SPANNER_DATABASE_ID: ( 348 self._spanner_configuration.database), 349 } 350 351 def _table_metric(self, table_id, status): 352 database_id = self._spanner_configuration.database 353 project_id = self._spanner_configuration.project 354 resource = resource_identifiers.SpannerTable( 355 project_id, database_id, table_id) 356 labels = { 357 **self.base_labels, 358 monitoring_infos.RESOURCE_LABEL: resource, 359 monitoring_infos.SPANNER_TABLE_ID: table_id 360 } 361 service_call_metric = ServiceCallMetric( 362 request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, 363 base_labels=labels) 364 service_call_metric.call(str(status)) 365 366 def _query_metric(self, query_name, status): 367 project_id = self._spanner_configuration.project 368 resource = resource_identifiers.SpannerSqlQuery(project_id, query_name) 369 labels = { 370 **self.base_labels, 371 monitoring_infos.RESOURCE_LABEL: resource, 372 monitoring_infos.SPANNER_QUERY_NAME: query_name 373 } 374 service_call_metric = ServiceCallMetric( 375 request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, 376 base_labels=labels) 377 service_call_metric.call(str(status)) 378 379 def _get_session(self): 380 if self._session is None: 381 session = self._session = self._database.session() 382 session.create() 383 return self._session 384 385 def _close_session(self): 386 if self._session is not None: 387 self._session.delete() 388 389 def setup(self): 390 # setting up client to connect with cloud spanner 391 spanner_client = Client(self._spanner_configuration.project) 392 instance = spanner_client.instance(self._spanner_configuration.instance) 393 self._database = instance.database( 394 self._spanner_configuration.database, 395 pool=self._spanner_configuration.pool) 396 397 def process(self, element, spanner_transaction): 398 # `spanner_transaction` should be the instance of the _SPANNER_TRANSACTION 399 # object. 400 if not isinstance(spanner_transaction, _SPANNER_TRANSACTION): 401 raise ValueError( 402 "Invalid transaction object: %s. It should be instance " 403 "of SPANNER_TRANSACTION object created by " 404 "spannerio.create_transaction transform." % type(spanner_transaction)) 405 406 transaction_info = spanner_transaction.transaction 407 408 # We used batch snapshot to reuse the same transaction passed through the 409 # side input 410 self._snapshot = BatchSnapshot.from_dict(self._database, transaction_info) 411 412 # getting the transaction from the snapshot's session to run read operation. 
413     # with self._snapshot.session().transaction() as transaction:
414     with self._get_session().transaction() as transaction:
415       table_id = self._spanner_configuration.table
416       query_name = self._spanner_configuration.query_name or ''
417
418       if element.is_sql is True:
419         transaction_read = transaction.execute_sql
420         metric_action = self._query_metric
421         metric_id = query_name
422       elif element.is_table is True:
423         transaction_read = transaction.read
424         metric_action = self._table_metric
425         metric_id = table_id
426       else:
427         raise ValueError(
428             "ReadOperation is improperly configured: %s" % str(element))
429
430       try:
431         for row in transaction_read(**element.kwargs):
432           yield row
433
434         metric_action(metric_id, 'ok')
435       except (ClientError, GoogleAPICallError) as e:
436         metric_action(metric_id, e.code.value)
437         raise
438       except HttpError as e:
439         metric_action(metric_id, e)
440         raise
441
442
443 @with_input_types(ReadOperation)
444 @with_output_types(typing.Dict[typing.Any, typing.Any])
445 class _CreateReadPartitions(DoFn):
446   """
447   A DoFn to create partitions. Uses the Partitioning API (PartitionRead /
448   PartitionQuery) request to start a partitioned query operation. Returns a
449   list of batch information needed to perform the actual queries.
450
451   If the element is an instance of :class:`ReadOperation` that performs a sql
452   query, the `PartitionQuery` API is used to create partitions, and mappings
453   of the information needed to perform the actual partitioned reads via
454   :meth:`process_query_batch` are returned.
455
456   If the element is an instance of :class:`ReadOperation` that reads from a
457   table, the `PartitionRead` API is used to create partitions, and mappings
458   of the information needed to perform the actual partitioned reads via
459   :meth:`process_read_batch` are returned.
460   """
461   def __init__(self, spanner_configuration):
462     self._spanner_configuration = spanner_configuration
463
464   def setup(self):
465     spanner_client = Client(
466         project=self._spanner_configuration.project,
467         credentials=self._spanner_configuration.credentials)
468     instance = spanner_client.instance(self._spanner_configuration.instance)
469     self._database = instance.database(
470         self._spanner_configuration.database,
471         pool=self._spanner_configuration.pool)
472     self._snapshot = self._database.batch_snapshot(
473         **self._spanner_configuration.snapshot_options)
474     self._snapshot_dict = self._snapshot.to_dict()
475
476   def process(self, element):
477     if element.is_sql is True:
478       partitioning_action = self._snapshot.generate_query_batches
479     elif element.is_table is True:
480       partitioning_action = self._snapshot.generate_read_batches
481     else:
482       raise ValueError(
483           "ReadOperation is improperly configured: %s" % str(element))
484
485     for p in partitioning_action(**element.kwargs):
486       yield {
487           "is_sql": element.is_sql,
488           "is_table": element.is_table,
489           "read_operation": element.read_operation,
490           "partitions": p,
491           "transaction_info": self._snapshot_dict
492       }
493
494
495 @with_input_types(int)
496 @with_output_types(_SPANNER_TRANSACTION)
497 class _CreateTransactionFn(DoFn):
498   """
499   A DoFn to create the Cloud Spanner transaction.
500   It connects to the database and returns the transaction_id and session_id
501   by using the batch_snapshot.to_dict() method available in the Google Cloud
502   Spanner SDK.
503 504 https://googleapis.dev/python/spanner/latest/database-api.html?highlight= 505 batch_snapshot#google.cloud.spanner_v1.database.BatchSnapshot.to_dict 506 """ 507 def __init__( 508 self, 509 project_id, 510 instance_id, 511 database_id, 512 credentials, 513 pool, 514 read_timestamp, 515 exact_staleness): 516 self._project_id = project_id 517 self._instance_id = instance_id 518 self._database_id = database_id 519 self._credentials = credentials 520 self._pool = pool 521 522 self._snapshot_options = {} 523 if read_timestamp: 524 self._snapshot_options['read_timestamp'] = read_timestamp 525 if exact_staleness: 526 self._snapshot_options['exact_staleness'] = exact_staleness 527 self._snapshot = None 528 529 def setup(self): 530 self._spanner_client = Client( 531 project=self._project_id, credentials=self._credentials) 532 self._instance = self._spanner_client.instance(self._instance_id) 533 self._database = self._instance.database(self._database_id, pool=self._pool) 534 535 def process(self, element, *args, **kwargs): 536 self._snapshot = self._database.batch_snapshot(**self._snapshot_options) 537 return [_SPANNER_TRANSACTION(self._snapshot.to_dict())] 538 539 540 @ptransform_fn 541 def create_transaction( 542 pbegin, 543 project_id, 544 instance_id, 545 database_id, 546 credentials=None, 547 pool=None, 548 read_timestamp=None, 549 exact_staleness=None): 550 """ 551 A PTransform method to create a batch transaction. 552 553 Args: 554 pbegin: Root of the pipeline 555 project_id: Cloud spanner project id. Be sure to use the Project ID, 556 not the Project Number. 557 instance_id: Cloud spanner instance id. 558 database_id: Cloud spanner database id. 559 credentials: (optional) The authorization credentials to attach to requests. 560 These credentials identify this application to the service. 561 If none are specified, the client will attempt to ascertain 562 the credentials from the environment. 563 pool: (optional) session pool to be used by database. If not passed, 564 Spanner Cloud SDK uses the BurstyPool by default. 565 `google.cloud.spanner.BurstyPool`. Ref: 566 https://googleapis.dev/python/spanner/latest/database-api.html?#google. 567 cloud.spanner_v1.database.Database 568 read_timestamp: (optional) An instance of the `datetime.datetime` object to 569 execute all reads at the given timestamp. 570 exact_staleness: (optional) An instance of the `datetime.timedelta` 571 object. These timestamp bounds execute reads at a user-specified 572 timestamp. 573 """ 574 575 assert isinstance(pbegin, PBegin) 576 577 return ( 578 pbegin | Create([1]) | ParDo( 579 _CreateTransactionFn( 580 project_id, 581 instance_id, 582 database_id, 583 credentials, 584 pool, 585 read_timestamp, 586 exact_staleness))) 587 588 589 @with_input_types(typing.Dict[typing.Any, typing.Any]) 590 @with_output_types(typing.List[typing.Any]) 591 class _ReadFromPartitionFn(DoFn): 592 """ 593 A DoFn to perform reads from the partition. 
594   """
595   def __init__(self, spanner_configuration):
596     self._spanner_configuration = spanner_configuration
597     self.base_labels = {
598         monitoring_infos.SERVICE_LABEL: 'Spanner',
599         monitoring_infos.METHOD_LABEL: 'Read',
600         monitoring_infos.SPANNER_PROJECT_ID: (
601             self._spanner_configuration.project),
602         monitoring_infos.SPANNER_DATABASE_ID: (
603             self._spanner_configuration.database),
604     }
605     self.service_metric = None
606
607   def _table_metric(self, table_id):
608     database_id = self._spanner_configuration.database
609     project_id = self._spanner_configuration.project
610     resource = resource_identifiers.SpannerTable(
611         project_id, database_id, table_id)
612     labels = {
613         **self.base_labels,
614         monitoring_infos.RESOURCE_LABEL: resource,
615         monitoring_infos.SPANNER_TABLE_ID: table_id
616     }
617     service_call_metric = ServiceCallMetric(
618         request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
619         base_labels=labels)
620     return service_call_metric
621
622   def _query_metric(self, query_name):
623     project_id = self._spanner_configuration.project
624     resource = resource_identifiers.SpannerSqlQuery(project_id, query_name)
625     labels = {
626         **self.base_labels,
627         monitoring_infos.RESOURCE_LABEL: resource,
628         monitoring_infos.SPANNER_QUERY_NAME: query_name
629     }
630     service_call_metric = ServiceCallMetric(
631         request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
632         base_labels=labels)
633     return service_call_metric
634
635   def setup(self):
636     spanner_client = Client(self._spanner_configuration.project)
637     instance = spanner_client.instance(self._spanner_configuration.instance)
638     self._database = instance.database(
639         self._spanner_configuration.database,
640         pool=self._spanner_configuration.pool)
641     self._snapshot = self._database.batch_snapshot(
642         **self._spanner_configuration.snapshot_options)
643
644   def process(self, element):
645     self._snapshot = BatchSnapshot.from_dict(
646         self._database, element['transaction_info'])
647
648     table_id = self._spanner_configuration.table
649     query_name = self._spanner_configuration.query_name or ''
650
651     if element['is_sql'] is True:
652       read_action = self._snapshot.process_query_batch
653       self.service_metric = self._query_metric(query_name)
654     elif element['is_table'] is True:
655       read_action = self._snapshot.process_read_batch
656       self.service_metric = self._table_metric(table_id)
657     else:
658       raise ValueError(
659           "ReadOperation is improperly configured: %s" % str(element))
660
661     try:
662       for row in read_action(element['partitions']):
663         yield row
664
665       self.service_metric.call('ok')
666     except (ClientError, GoogleAPICallError) as e:
667       self.service_metric.call(str(e.code.value))
668       raise
669     except HttpError as e:
670       self.service_metric.call(str(e))
671       raise
672
673   def teardown(self):
674     if self._snapshot:
675       self._snapshot.close()
676
677
678 class ReadFromSpanner(PTransform):
679   """
680   A PTransform to perform reads from cloud spanner.
681   ReadFromSpanner uses the Batch API to perform all read operations.
682   """
683
684   def __init__(self, project_id, instance_id, database_id, pool=None,
685                read_timestamp=None, exact_staleness=None, credentials=None,
686                sql=None, params=None, param_types=None,  # with_query
687                table=None, query_name=None, columns=None, index="",
688                keyset=None,  # with_table
689                read_operations=None,  # for read all
690                transaction=None
691                ):
692     """
693     A PTransform that uses Spanner Batch API to perform reads.
694
695     Args:
696       project_id: Cloud spanner project id.
Be sure to use the Project ID, 697 not the Project Number. 698 instance_id: Cloud spanner instance id. 699 database_id: Cloud spanner database id. 700 pool: (optional) session pool to be used by database. If not passed, 701 Spanner Cloud SDK uses the BurstyPool by default. 702 `google.cloud.spanner.BurstyPool`. Ref: 703 https://googleapis.dev/python/spanner/latest/database-api.html?#google. 704 cloud.spanner_v1.database.Database 705 read_timestamp: (optional) An instance of the `datetime.datetime` object 706 to execute all reads at the given timestamp. By default, set to `None`. 707 exact_staleness: (optional) An instance of the `datetime.timedelta` 708 object. These timestamp bounds execute reads at a user-specified 709 timestamp. By default, set to `None`. 710 credentials: (optional) The authorization credentials to attach to 711 requests. These credentials identify this application to the service. 712 If none are specified, the client will attempt to ascertain 713 the credentials from the environment. By default, set to `None`. 714 sql: (optional) SQL query statement. 715 params: (optional) Values for parameter replacement. Keys must match the 716 names used in sql. By default, set to `None`. 717 param_types: (optional) maps explicit types for one or more param values; 718 required if params are passed. By default, set to `None`. 719 table: (optional) Name of the table from which to fetch data. By 720 default, set to `None`. 721 columns: (optional) List of names of columns to be retrieved; required if 722 the table is passed. By default, set to `None`. 723 index: (optional) name of index to use, rather than the table's primary 724 key. By default, set to `None`. 725 keyset: (optional) keys / ranges identifying rows to be retrieved. By 726 default, set to `None`. 727 read_operations: (optional) List of the objects of :class:`ReadOperation` 728 to perform read all. By default, set to `None`. 729 transaction: (optional) PTransform of the :meth:`create_transaction` to 730 perform naive read on cloud spanner. By default, set to `None`. 731 """ 732 self._configuration = _BeamSpannerConfiguration( 733 project=project_id, 734 instance=instance_id, 735 database=database_id, 736 table=table, 737 query_name=query_name, 738 credentials=credentials, 739 pool=pool, 740 snapshot_read_timestamp=read_timestamp, 741 snapshot_exact_staleness=exact_staleness) 742 743 self._read_operations = read_operations 744 self._transaction = transaction 745 746 if self._read_operations is None: 747 if table is not None: 748 if columns is None: 749 raise ValueError("Columns are required with the table name.") 750 self._read_operations = [ 751 ReadOperation.table( 752 table=table, columns=columns, index=index, keyset=keyset) 753 ] 754 elif sql is not None: 755 self._read_operations = [ 756 ReadOperation.query( 757 sql=sql, params=params, param_types=param_types) 758 ] 759 760 def expand(self, pbegin): 761 if self._read_operations is not None and isinstance(pbegin, PBegin): 762 pcoll = pbegin.pipeline | Create(self._read_operations) 763 elif not isinstance(pbegin, PBegin): 764 if self._read_operations is not None: 765 raise ValueError( 766 "Read operation in the constructor only works with " 767 "the root of the pipeline.") 768 pcoll = pbegin 769 else: 770 raise ValueError( 771 "Spanner required read operation, sql or table " 772 "with columns.") 773 774 if self._transaction is None: 775 # reading as batch read using the spanner partitioning query to create 776 # batches. 
777 p = ( 778 pcoll 779 | 'Generate Partitions' >> ParDo( 780 _CreateReadPartitions(spanner_configuration=self._configuration)) 781 | 'Reshuffle' >> Reshuffle() 782 | 'Read From Partitions' >> ParDo( 783 _ReadFromPartitionFn(spanner_configuration=self._configuration))) 784 else: 785 # reading as naive read, in which we don't make batches and execute the 786 # queries as a single read. 787 p = ( 788 pcoll 789 | 'Reshuffle' >> Reshuffle().with_input_types(ReadOperation) 790 | 'Perform Read' >> ParDo( 791 _NaiveSpannerReadDoFn(spanner_configuration=self._configuration), 792 AsSingleton(self._transaction))) 793 return p 794 795 def display_data(self): 796 res = {} 797 sql = [] 798 table = [] 799 if self._read_operations is not None: 800 for ro in self._read_operations: 801 if ro.is_sql is True: 802 sql.append(ro.kwargs) 803 elif ro.is_table is True: 804 table.append(ro.kwargs) 805 806 if sql: 807 res['sql'] = DisplayDataItem(str(sql), label='Sql') 808 if table: 809 res['table'] = DisplayDataItem(str(table), label='Table') 810 811 if self._transaction: 812 res['transaction'] = DisplayDataItem( 813 str(self._transaction), label='transaction') 814 815 return res 816 817 818 class WriteToSpanner(PTransform): 819 def __init__( 820 self, 821 project_id, 822 instance_id, 823 database_id, 824 pool=None, 825 credentials=None, 826 max_batch_size_bytes=1048576, 827 max_number_rows=50, 828 max_number_cells=500): 829 """ 830 A PTransform to write onto Google Cloud Spanner. 831 832 Args: 833 project_id: Cloud spanner project id. Be sure to use the Project ID, 834 not the Project Number. 835 instance_id: Cloud spanner instance id. 836 database_id: Cloud spanner database id. 837 max_batch_size_bytes: (optional) Split the mutations into batches to 838 reduce the number of transaction sent to Spanner. By default it is 839 set to 1 MB (1048576 Bytes). 840 max_number_rows: (optional) Split the mutations into batches to 841 reduce the number of transaction sent to Spanner. By default it is 842 set to 50 rows per batch. 843 max_number_cells: (optional) Split the mutations into batches to 844 reduce the number of transaction sent to Spanner. By default it is 845 set to 500 cells per batch. 
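    Example (an illustrative sketch; the identifiers are placeholders and the
    limits shown are arbitrary):::

      _ = (pipeline
           | beam.Create(mutations)
           | WriteToSpanner(
               project_id=SPANNER_PROJECT_ID,
               instance_id=SPANNER_INSTANCE_ID,
               database_id=SPANNER_DATABASE_NAME,
               max_batch_size_bytes=524288,
               max_number_rows=100,
               max_number_cells=1000))

    Setting any of these limits to a smaller value, or to zero, reduces or
    disables batching.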
846 """ 847 self._configuration = _BeamSpannerConfiguration( 848 project=project_id, 849 instance=instance_id, 850 database=database_id, 851 table=None, 852 query_name=None, 853 credentials=credentials, 854 pool=pool, 855 snapshot_read_timestamp=None, 856 snapshot_exact_staleness=None) 857 self._max_batch_size_bytes = max_batch_size_bytes 858 self._max_number_rows = max_number_rows 859 self._max_number_cells = max_number_cells 860 self._database_id = database_id 861 self._project_id = project_id 862 self._instance_id = instance_id 863 self._pool = pool 864 865 def display_data(self): 866 res = { 867 'project_id': DisplayDataItem(self._project_id, label='Project Id'), 868 'instance_id': DisplayDataItem(self._instance_id, label='Instance Id'), 869 'pool': DisplayDataItem(str(self._pool), label='Pool'), 870 'database': DisplayDataItem(self._database_id, label='Database'), 871 'batch_size': DisplayDataItem( 872 self._max_batch_size_bytes, label="Batch Size"), 873 'max_number_rows': DisplayDataItem( 874 self._max_number_rows, label="Max Rows"), 875 'max_number_cells': DisplayDataItem( 876 self._max_number_cells, label="Max Cells"), 877 } 878 return res 879 880 def expand(self, pcoll): 881 return ( 882 pcoll 883 | "make batches" >> _WriteGroup( 884 max_batch_size_bytes=self._max_batch_size_bytes, 885 max_number_rows=self._max_number_rows, 886 max_number_cells=self._max_number_cells) 887 | 888 'Writing to spanner' >> ParDo(_WriteToSpannerDoFn(self._configuration))) 889 890 891 class _Mutator(namedtuple('_Mutator', 892 ["mutation", "operation", "kwargs", "rows", "cells"]) 893 ): 894 __slots__ = () 895 896 @property 897 def byte_size(self): 898 if hasattr(self.mutation, '_pb'): 899 # google-cloud-spanner 3.x 900 return self.mutation._pb.ByteSize() 901 else: 902 # google-cloud-spanner 1.x 903 return self.mutation.ByteSize() 904 905 906 class MutationGroup(deque): 907 """ 908 A Bundle of Spanner Mutations (_Mutator). 909 """ 910 @property 911 def info(self): 912 cells = 0 913 rows = 0 914 bytes = 0 915 for m in self.__iter__(): 916 bytes += m.byte_size 917 rows += m.rows 918 cells += m.cells 919 return {"rows": rows, "cells": cells, "byte_size": bytes} 920 921 def primary(self): 922 return next(self.__iter__()) 923 924 925 class WriteMutation(object): 926 927 _OPERATION_DELETE = "delete" 928 _OPERATION_INSERT = "insert" 929 _OPERATION_INSERT_OR_UPDATE = "insert_or_update" 930 _OPERATION_REPLACE = "replace" 931 _OPERATION_UPDATE = "update" 932 933 def __init__( 934 self, 935 insert=None, 936 update=None, 937 insert_or_update=None, 938 replace=None, 939 delete=None, 940 columns=None, 941 values=None, 942 keyset=None): 943 """ 944 A convenient class to create Spanner Mutations for Write. User can provide 945 the operation via constructor or via static methods. 946 947 Note: If a user passing the operation via construction, make sure that it 948 will only accept one operation at a time. For example, if a user passing 949 a table name in the `insert` parameter, and he also passes the `update` 950 parameter value, this will cause an error. 951 952 Args: 953 insert: (Optional) Name of the table in which rows will be inserted. 954 update: (Optional) Name of the table in which existing rows will be 955 updated. 956 insert_or_update: (Optional) Table name in which rows will be written. 957 Like insert, except that if the row already exists, then its column 958 values are overwritten with the ones provided. Any column values not 959 explicitly written are preserved. 
960 replace: (Optional) Table name in which rows will be replaced. Like 961 insert, except that if the row already exists, it is deleted, and the 962 column values provided are inserted instead. Unlike `insert_or_update`, 963 this means any values not explicitly written become `NULL`. 964 delete: (Optional) Table name from which rows will be deleted. Succeeds 965 whether or not the named rows were present. 966 columns: The names of the columns in table to be written. The list of 967 columns must contain enough columns to allow Cloud Spanner to derive 968 values for all primary key columns in the row(s) to be modified. 969 values: The values to be written. `values` can contain more than one 970 list of values. If it does, then multiple rows are written, one for 971 each entry in `values`. Each list in `values` must have exactly as 972 many entries as there are entries in columns above. Sending multiple 973 lists is equivalent to sending multiple Mutations, each containing one 974 `values` entry and repeating table and columns. 975 keyset: (Optional) The primary keys of the rows within table to delete. 976 Delete is idempotent. The transaction will succeed even if some or 977 all rows do not exist. 978 """ 979 self._columns = columns 980 self._values = values 981 self._keyset = keyset 982 983 self._insert = insert 984 self._update = update 985 self._insert_or_update = insert_or_update 986 self._replace = replace 987 self._delete = delete 988 989 if sum([1 for x in [self._insert, 990 self._update, 991 self._insert_or_update, 992 self._replace, 993 self._delete] if x is not None]) != 1: 994 raise ValueError( 995 "No or more than one write mutation operation " 996 "provided: <%s: %s>" % (self.__class__.__name__, str(self.__dict__))) 997 998 def __call__(self, *args, **kwargs): 999 if self._insert is not None: 1000 return WriteMutation.insert( 1001 table=self._insert, columns=self._columns, values=self._values) 1002 elif self._update is not None: 1003 return WriteMutation.update( 1004 table=self._update, columns=self._columns, values=self._values) 1005 elif self._insert_or_update is not None: 1006 return WriteMutation.insert_or_update( 1007 table=self._insert_or_update, 1008 columns=self._columns, 1009 values=self._values) 1010 elif self._replace is not None: 1011 return WriteMutation.replace( 1012 table=self._replace, columns=self._columns, values=self._values) 1013 elif self._delete is not None: 1014 return WriteMutation.delete(table=self._delete, keyset=self._keyset) 1015 1016 @staticmethod 1017 def insert(table, columns, values): 1018 """Insert one or more new table rows. 1019 1020 Args: 1021 table: Name of the table to be modified. 1022 columns: Name of the table columns to be modified. 1023 values: Values to be modified. 1024 """ 1025 rows = len(values) 1026 cells = len(columns) * len(values) 1027 return _Mutator( 1028 mutation=Mutation(insert=batch._make_write_pb(table, columns, values)), 1029 operation=WriteMutation._OPERATION_INSERT, 1030 rows=rows, 1031 cells=cells, 1032 kwargs={ 1033 "table": table, "columns": columns, "values": values 1034 }) 1035 1036 @staticmethod 1037 def update(table, columns, values): 1038 """Update one or more existing table rows. 1039 1040 Args: 1041 table: Name of the table to be modified. 1042 columns: Name of the table columns to be modified. 1043 values: Values to be modified. 
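    Example (an illustrative sketch; the table, columns and values are
    placeholders):::

      WriteMutation.update(
          table='users',
          columns=('id', 'email'),
          values=[(1, 'sara@example.com')])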
1044 """ 1045 rows = len(values) 1046 cells = len(columns) * len(values) 1047 return _Mutator( 1048 mutation=Mutation(update=batch._make_write_pb(table, columns, values)), 1049 operation=WriteMutation._OPERATION_UPDATE, 1050 rows=rows, 1051 cells=cells, 1052 kwargs={ 1053 "table": table, "columns": columns, "values": values 1054 }) 1055 1056 @staticmethod 1057 def insert_or_update(table, columns, values): 1058 """Insert/update one or more table rows. 1059 Args: 1060 table: Name of the table to be modified. 1061 columns: Name of the table columns to be modified. 1062 values: Values to be modified. 1063 """ 1064 rows = len(values) 1065 cells = len(columns) * len(values) 1066 return _Mutator( 1067 mutation=Mutation( 1068 insert_or_update=batch._make_write_pb(table, columns, values)), 1069 operation=WriteMutation._OPERATION_INSERT_OR_UPDATE, 1070 rows=rows, 1071 cells=cells, 1072 kwargs={ 1073 "table": table, "columns": columns, "values": values 1074 }) 1075 1076 @staticmethod 1077 def replace(table, columns, values): 1078 """Replace one or more table rows. 1079 1080 Args: 1081 table: Name of the table to be modified. 1082 columns: Name of the table columns to be modified. 1083 values: Values to be modified. 1084 """ 1085 rows = len(values) 1086 cells = len(columns) * len(values) 1087 return _Mutator( 1088 mutation=Mutation(replace=batch._make_write_pb(table, columns, values)), 1089 operation=WriteMutation._OPERATION_REPLACE, 1090 rows=rows, 1091 cells=cells, 1092 kwargs={ 1093 "table": table, "columns": columns, "values": values 1094 }) 1095 1096 @staticmethod 1097 def delete(table, keyset): 1098 """Delete one or more table rows. 1099 1100 Args: 1101 table: Name of the table to be modified. 1102 keyset: Keys/ranges identifying rows to delete. 1103 """ 1104 delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) 1105 return _Mutator( 1106 mutation=Mutation(delete=delete), 1107 rows=0, 1108 cells=0, 1109 operation=WriteMutation._OPERATION_DELETE, 1110 kwargs={ 1111 "table": table, "keyset": keyset 1112 }) 1113 1114 1115 @with_input_types(typing.Union[MutationGroup, TaggedOutput]) 1116 @with_output_types(MutationGroup) 1117 class _BatchFn(DoFn): 1118 """ 1119 Batches mutations together. 1120 """ 1121 def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): 1122 self._max_batch_size_bytes = max_batch_size_bytes 1123 self._max_number_rows = max_number_rows 1124 self._max_number_cells = max_number_cells 1125 1126 def start_bundle(self): 1127 self._batch = MutationGroup() 1128 self._size_in_bytes = 0 1129 self._rows = 0 1130 self._cells = 0 1131 1132 def _reset_count(self): 1133 self._batch = MutationGroup() 1134 self._size_in_bytes = 0 1135 self._rows = 0 1136 self._cells = 0 1137 1138 def process(self, element): 1139 mg_info = element.info 1140 1141 if mg_info['byte_size'] + self._size_in_bytes > self._max_batch_size_bytes \ 1142 or mg_info['cells'] + self._cells > self._max_number_cells \ 1143 or mg_info['rows'] + self._rows > self._max_number_rows: 1144 # Batch is full, output the batch and resetting the count. 1145 if self._batch: 1146 yield self._batch 1147 self._reset_count() 1148 1149 self._batch.extend(element) 1150 1151 # total byte size of the mutation group. 1152 self._size_in_bytes += mg_info['byte_size'] 1153 1154 # total rows in the mutation group. 1155 self._rows += mg_info['rows'] 1156 1157 # total cells in the mutation group. 
1158 self._cells += mg_info['cells'] 1159 1160 def finish_bundle(self): 1161 if self._batch is not None: 1162 yield window.GlobalWindows.windowed_value(self._batch) 1163 self._batch = None 1164 1165 1166 @with_input_types(MutationGroup) 1167 @with_output_types(MutationGroup) 1168 class _BatchableFilterFn(DoFn): 1169 """ 1170 Filters MutationGroups larger than the batch size to the output tagged with 1171 OUTPUT_TAG_UNBATCHABLE. 1172 """ 1173 OUTPUT_TAG_UNBATCHABLE = 'unbatchable' 1174 1175 def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): 1176 self._max_batch_size_bytes = max_batch_size_bytes 1177 self._max_number_rows = max_number_rows 1178 self._max_number_cells = max_number_cells 1179 self._batchable = None 1180 self._unbatchable = None 1181 1182 def process(self, element): 1183 if element.primary().operation == WriteMutation._OPERATION_DELETE: 1184 # As delete mutations are not batchable. 1185 yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element) 1186 else: 1187 mg_info = element.info 1188 if mg_info['byte_size'] > self._max_batch_size_bytes \ 1189 or mg_info['cells'] > self._max_number_cells \ 1190 or mg_info['rows'] > self._max_number_rows: 1191 yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element) 1192 else: 1193 yield element 1194 1195 1196 class _WriteToSpannerDoFn(DoFn): 1197 def __init__(self, spanner_configuration): 1198 self._spanner_configuration = spanner_configuration 1199 self._db_instance = None 1200 self.batches = Metrics.counter(self.__class__, 'SpannerBatches') 1201 self.base_labels = { 1202 monitoring_infos.SERVICE_LABEL: 'Spanner', 1203 monitoring_infos.METHOD_LABEL: 'Write', 1204 monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project, 1205 monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database, 1206 } 1207 # table_id to metrics 1208 self.service_metrics = {} 1209 1210 def _register_table_metric(self, table_id): 1211 if table_id in self.service_metrics: 1212 return 1213 database_id = self._spanner_configuration.database 1214 project_id = self._spanner_configuration.project 1215 resource = resource_identifiers.SpannerTable( 1216 project_id, database_id, table_id) 1217 labels = { 1218 **self.base_labels, 1219 monitoring_infos.RESOURCE_LABEL: resource, 1220 monitoring_infos.SPANNER_TABLE_ID: table_id 1221 } 1222 service_call_metric = ServiceCallMetric( 1223 request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, 1224 base_labels=labels) 1225 self.service_metrics[table_id] = service_call_metric 1226 1227 def setup(self): 1228 spanner_client = Client(self._spanner_configuration.project) 1229 instance = spanner_client.instance(self._spanner_configuration.instance) 1230 self._db_instance = instance.database( 1231 self._spanner_configuration.database, 1232 pool=self._spanner_configuration.pool) 1233 1234 def start_bundle(self): 1235 self.service_metrics = {} 1236 1237 def process(self, element): 1238 self.batches.inc() 1239 try: 1240 with self._db_instance.batch() as b: 1241 for m in element: 1242 table_id = m.kwargs['table'] 1243 self._register_table_metric(table_id) 1244 1245 if m.operation == WriteMutation._OPERATION_DELETE: 1246 batch_func = b.delete 1247 elif m.operation == WriteMutation._OPERATION_REPLACE: 1248 batch_func = b.replace 1249 elif m.operation == WriteMutation._OPERATION_INSERT_OR_UPDATE: 1250 batch_func = b.insert_or_update 1251 elif m.operation == WriteMutation._OPERATION_INSERT: 1252 batch_func = b.insert 1253 elif m.operation == WriteMutation._OPERATION_UPDATE: 1254 
batch_func = b.update 1255 else: 1256 raise ValueError("Unknown operation action: %s" % m.operation) 1257 batch_func(**m.kwargs) 1258 except (ClientError, GoogleAPICallError) as e: 1259 for service_metric in self.service_metrics.values(): 1260 service_metric.call(str(e.code.value)) 1261 raise 1262 except HttpError as e: 1263 for service_metric in self.service_metrics.values(): 1264 service_metric.call(str(e)) 1265 raise 1266 else: 1267 for service_metric in self.service_metrics.values(): 1268 service_metric.call('ok') 1269 1270 1271 @with_input_types(typing.Union[MutationGroup, _Mutator]) 1272 @with_output_types(MutationGroup) 1273 class _MakeMutationGroupsFn(DoFn): 1274 """ 1275 Make Mutation group object if the element is the instance of _Mutator. 1276 """ 1277 def process(self, element): 1278 if isinstance(element, MutationGroup): 1279 yield element 1280 elif isinstance(element, _Mutator): 1281 yield MutationGroup([element]) 1282 else: 1283 raise ValueError( 1284 "Invalid object type: %s. Object must be an instance of " 1285 "MutationGroup or WriteMutations" % str(element)) 1286 1287 1288 class _WriteGroup(PTransform): 1289 def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): 1290 self._max_batch_size_bytes = max_batch_size_bytes 1291 self._max_number_rows = max_number_rows 1292 self._max_number_cells = max_number_cells 1293 1294 def expand(self, pcoll): 1295 filter_batchable_mutations = ( 1296 pcoll 1297 | 'Making mutation groups' >> ParDo(_MakeMutationGroupsFn()) 1298 | 'Filtering Batchable Mutations' >> ParDo( 1299 _BatchableFilterFn( 1300 max_batch_size_bytes=self._max_batch_size_bytes, 1301 max_number_rows=self._max_number_rows, 1302 max_number_cells=self._max_number_cells)).with_outputs( 1303 _BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, main='batchable') 1304 ) 1305 1306 batching_batchables = ( 1307 filter_batchable_mutations['batchable'] 1308 | ParDo( 1309 _BatchFn( 1310 max_batch_size_bytes=self._max_batch_size_bytes, 1311 max_number_rows=self._max_number_rows, 1312 max_number_cells=self._max_number_cells))) 1313 1314 return (( 1315 batching_batchables, 1316 filter_batchable_mutations[_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE]) 1317 | 'Merging batchable and unbatchable' >> Flatten())