github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/experimental/spannerio.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Google Cloud Spanner IO
    19  
    20  Experimental; no backwards-compatibility guarantees.
    21  
    22  This is an experimental module for reading and writing data from Google Cloud
    23  Spanner. Visit: https://cloud.google.com/spanner for more details.
    24  
    25  Reading Data from Cloud Spanner.
    26  
     27  To read from Cloud Spanner, apply the ReadFromSpanner transformation. It
     28  returns a PCollection, where each element represents an individual row
     29  returned from the read operation. Both the Query and Read APIs are supported.
     30
     31  ReadFromSpanner relies on the ReadOperation objects exposed by the
     32  SpannerIO API. A ReadOperation holds the immutable data needed to execute
     33  batch and naive reads on Cloud Spanner, which makes programming more
     34  convenient.
    35  
     36  ReadFromSpanner reads from Cloud Spanner by providing either an 'sql' param
     37  in the constructor or a 'table' name with a 'columns' list. For example:::
    38  
    39    records = (pipeline
    40              | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
    41              sql='Select * from users'))
    42  
    43    records = (pipeline
    44              | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
    45              table='users', columns=['id', 'name', 'email']))
    46  
     47  You can also perform multiple reads by providing a list of ReadOperations
     48  to the ReadFromSpanner transform constructor. ReadOperation exposes two
     49  static methods: use 'query' to perform SQL-based reads and 'table' to read
     50  from a table. For example:::
    51  
    52    read_operations = [
    53                        ReadOperation.table(table='customers', columns=['name',
    54                        'email']),
    55                        ReadOperation.table(table='vendors', columns=['name',
    56                        'email']),
    57                      ]
    58    all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
    59          read_operations=read_operations)
    60  
    61    ...OR...
    62  
    63    read_operations = [
     64                        ReadOperation.query(
     65                            sql='Select name, email from customers'),
    66                        ReadOperation.query(
    67                          sql='Select * from users where id <= @user_id',
    68                          params={'user_id': 100},
     69                          param_types={'user_id': param_types.INT64}
    70                        ),
    71                      ]
     72    # `param_types` values are instances of `google.cloud.spanner.param_types`
    73    all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
    74          read_operations=read_operations)
    75  
    76  For more information, please review the docs on class ReadOperation.
    77  
     78  Users can also provide the ReadOperations in the form of a PCollection via
     79  the pipeline. For example:::
    80  
    81    users = (pipeline
    82             | beam.Create([ReadOperation...])
    83             | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME))
    84  
     85  Users may also create a Cloud Spanner transaction using the transform called
     86  `create_transaction`, which is available in the SpannerIO API.
    87  
     88  The transform is guaranteed to be executed on a consistent snapshot of data,
     89  utilizing the power of read-only transactions. Staleness of data can be
     90  controlled by providing the `read_timestamp` or `exact_staleness` param values
     91  in the constructor.
    92  
     93  This transform requires the root of the pipeline (PBegin) and returns a
     94  PTransform which is later passed to the `ReadFromSpanner` constructor.
     95  `ReadFromSpanner` passes this transaction PTransform as a singleton side
     96  input to the `_NaiveSpannerReadDoFn`, containing the 'session_id' and
     97  'transaction_id'. For example:::
    98  
    99    transaction = (pipeline | create_transaction(TEST_PROJECT_ID,
   100                                                TEST_INSTANCE_ID,
   101                                                DB_NAME))
   102  
   103    users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
   104          sql='Select * from users', transaction=transaction)
   105  
   106    tweets = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
   107          sql='Select * from tweets', transaction=transaction)
   108  
   109  For further details of this transform, please review the docs on the
   110  :meth:`create_transaction` method available in the SpannerIO API.
   111  
    112  ReadFromSpanner takes this transform in the constructor and passes it to the
    113  read pipeline as a singleton side input.
   114  
   115  Writing Data to Cloud Spanner.
   116  
    117  The WriteToSpanner transform writes to Cloud Spanner by executing a
    118  collection of input mutations (WriteMutation). The mutations are grouped into
    119  batches for efficiency.
    120
    121  The WriteToSpanner transform relies on the WriteMutation objects exposed by
    122  the SpannerIO API. WriteMutation has five static methods (insert, update,
    123  insert_or_update, replace, delete). These methods return an instance of the
    124  _Mutator object, which contains the mutation type and the Spanner Mutation
    125  object. For more details, review the docs of the class SpannerIO.WriteMutation.
   126  For example:::
   127  
   128    mutations = [
   129                  WriteMutation.insert(table='user', columns=('name', 'email'),
   130                  values=[('sara', 'sara@dev.com')])
   131                ]
   132    _ = (p
   133         | beam.Create(mutations)
   134         | WriteToSpanner(
   135            project_id=SPANNER_PROJECT_ID,
   136            instance_id=SPANNER_INSTANCE_ID,
   137            database_id=SPANNER_DATABASE_NAME)
   138          )
   139  
    140  You can also create a WriteMutation by calling its constructor. For example:::
   141  
   142    mutations = [
   143        WriteMutation(insert='users', columns=('name', 'email'),
    144                    values=[('sara', 'sara@example.com')])
   145    ]
   146  
   147  For more information, review the docs available on WriteMutation class.
   148  
    149  The WriteToSpanner transform also takes three batching parameters
    150  (max_number_rows, max_number_cells and max_batch_size_bytes). By default,
    151  max_number_rows is set to 50 rows, max_number_cells is set to 500 cells and
    152  max_batch_size_bytes is set to 1MB (1048576 bytes). These parameters are used
    153  to reduce the number of transactions sent to Spanner by grouping mutations
    154  into batches. Set these parameters to smaller values, or to zero, to disable
    155  batching (see the sketch below). Unlike the Java connector, this connector
    156  does not create batches of transactions sorted by table and primary key.
   157  
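         For example, batching can be effectively disabled by setting the batching
         parameters to zero (a sketch, reusing the `mutations` list from the example
         above; the project, instance and database ids are placeholders):::

           _ = (p
                | beam.Create(mutations)
                | WriteToSpanner(
                   project_id=SPANNER_PROJECT_ID,
                   instance_id=SPANNER_INSTANCE_ID,
                   database_id=SPANNER_DATABASE_NAME,
                   max_number_rows=0,
                   max_number_cells=0,
                   max_batch_size_bytes=0))
         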
    158  The WriteToSpanner transform starts by grouping the mutations into batches.
    159  The first step in this process is to wrap the WriteMutation objects into
    160  mutation groups and then filter them into batchable and unbatchable mutation
    161  groups, based on the three batching parameters (max_number_cells,
    162  max_number_rows & max_batch_size_bytes). The mutation byte size is calculated
    163  using `google.cloud.spanner_v1.proto.mutation_pb2.Mutation.ByteSize`.
    164  If a mutation group's rows, cells or byte size exceed the value of any
    165  batching parameter, it is tagged as an "unbatchable" mutation group. The
    166  batchable mutations are then merged into mutation groups whose sizes do not
    167  exceed these limits, and finally the batchable and unbatchable groups are
    168  flattened together for writing. If a Mutation references a table or column
    169  that does not exist, it causes an exception and fails the entire pipeline.
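         
         Related mutations can be grouped into a single MutationGroup so that they
         are kept in the same batch (a sketch; table and column names are
         placeholders):::
         
           mutation_group = MutationGroup([
               WriteMutation.insert(table='users', columns=('name', 'email'),
                                    values=[('sara', 'sara@example.com')]),
               WriteMutation.insert(table='user_emails', columns=('email', ),
                                    values=[('sara@example.com', )]),
           ])
         
           _ = (p
                | beam.Create([mutation_group])
                | WriteToSpanner(
                   project_id=SPANNER_PROJECT_ID,
                   instance_id=SPANNER_INSTANCE_ID,
                   database_id=SPANNER_DATABASE_NAME))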
   170  """
   171  import typing
   172  from collections import deque
   173  from collections import namedtuple
   174  
   175  from apache_beam import Create
   176  from apache_beam import DoFn
   177  from apache_beam import Flatten
   178  from apache_beam import ParDo
   179  from apache_beam import Reshuffle
   180  from apache_beam.internal.metrics.metric import ServiceCallMetric
   181  from apache_beam.io.gcp import resource_identifiers
   182  from apache_beam.metrics import Metrics
   183  from apache_beam.metrics import monitoring_infos
   184  from apache_beam.pvalue import AsSingleton
   185  from apache_beam.pvalue import PBegin
   186  from apache_beam.pvalue import TaggedOutput
   187  from apache_beam.transforms import PTransform
   188  from apache_beam.transforms import ptransform_fn
   189  from apache_beam.transforms import window
   190  from apache_beam.transforms.display import DisplayDataItem
   191  from apache_beam.typehints import with_input_types
   192  from apache_beam.typehints import with_output_types
   193  
   194  # Protect against environments where spanner library is not available.
   195  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
   196  # pylint: disable=unused-import
   197  try:
   198    from google.cloud.spanner import Client
   199    from google.cloud.spanner import KeySet
   200    from google.cloud.spanner_v1 import batch
   201    from google.cloud.spanner_v1.database import BatchSnapshot
   202    from google.api_core.exceptions import ClientError, GoogleAPICallError
   203    from apitools.base.py.exceptions import HttpError
   204  except ImportError:
   205    Client = None
   206    KeySet = None
   207    BatchSnapshot = None
   208  
   209  try:
   210    from google.cloud.spanner_v1 import Mutation
   211  except ImportError:
   212    try:
   213      # Remove this and the try clause when we upgrade to google-cloud-spanner
   214      # 3.x.x.
   215      from google.cloud.spanner_v1.proto.mutation_pb2 import Mutation
   216    except ImportError:
   217      # Ignoring for environments where the Spanner library is not available.
   218      pass
   219  
   220  __all__ = [
   221      'create_transaction',
   222      'ReadFromSpanner',
   223      'ReadOperation',
   224      'WriteToSpanner',
   225      'WriteMutation',
   226      'MutationGroup'
   227  ]
   228  
   229  
   230  class _SPANNER_TRANSACTION(namedtuple("SPANNER_TRANSACTION", ["transaction"])):
   231    """
   232    Holds the spanner transaction details.
   233    """
   234  
   235    __slots__ = ()
   236  
   237  
   238  class ReadOperation(namedtuple(
   239      "ReadOperation", ["is_sql", "is_table", "read_operation", "kwargs"])):
   240    """
   241    Encapsulates a spanner read operation.
   242    """
   243  
   244    __slots__ = ()
   245  
   246    @classmethod
   247    def query(cls, sql, params=None, param_types=None):
   248      """
    249      A convenience method to construct a ReadOperation from a SQL query.
   250  
   251      Args:
   252        sql: SQL query statement
   253        params: (optional) values for parameter replacement. Keys must match the
   254          names used in sql
   255        param_types: (optional) maps explicit types for one or more param values;
   256          required if parameters are passed.
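         
             Example (illustrative values, mirroring the module-level docs):::
         
               ReadOperation.query(
                   sql='Select * from users where id <= @user_id',
                   params={'user_id': 100},
                   param_types={'user_id': param_types.INT64})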
   257      """
   258  
   259      if params:
   260        assert param_types is not None
   261  
   262      return cls(
   263          is_sql=True,
   264          is_table=False,
   265          read_operation="process_query_batch",
   266          kwargs={
   267              'sql': sql, 'params': params, 'param_types': param_types
   268          })
   269  
   270    @classmethod
   271    def table(cls, table, columns, index="", keyset=None):
   272      """
    273      A convenience method to construct a ReadOperation from a table.
   274  
   275      Args:
   276        table: name of the table from which to fetch data.
   277        columns: names of columns to be retrieved.
   278        index: (optional) name of index to use, rather than the table's primary
   279          key.
   280        keyset: (optional) `KeySet` keys / ranges identifying rows to be
   281          retrieved.
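         
             Example (illustrative table and column names):::
         
               ReadOperation.table(table='users', columns=['id', 'name', 'email'])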
   282      """
   283      keyset = keyset or KeySet(all_=True)
   284      if not isinstance(keyset, KeySet):
   285        raise ValueError(
   286            "keyset must be an instance of class "
   287            "google.cloud.spanner.KeySet")
   288      return cls(
   289          is_sql=False,
   290          is_table=True,
   291          read_operation="process_read_batch",
   292          kwargs={
   293              'table': table,
   294              'columns': columns,
   295              'index': index,
   296              'keyset': keyset
   297          })
   298  
   299  
   300  class _BeamSpannerConfiguration(namedtuple("_BeamSpannerConfiguration",
   301                                             ["project",
   302                                              "instance",
   303                                              "database",
   304                                              "table",
   305                                              "query_name",
   306                                              "credentials",
   307                                              "pool",
   308                                              "snapshot_read_timestamp",
   309                                              "snapshot_exact_staleness"])):
   310    """
    311    A namedtuple that holds the immutable connection details for Cloud
    312    Spanner.
   313    """
   314    @property
   315    def snapshot_options(self):
   316      snapshot_options = {}
   317      if self.snapshot_exact_staleness:
   318        snapshot_options['exact_staleness'] = self.snapshot_exact_staleness
   319      if self.snapshot_read_timestamp:
   320        snapshot_options['read_timestamp'] = self.snapshot_read_timestamp
   321      return snapshot_options
   322  
   323  
   324  @with_input_types(ReadOperation, _SPANNER_TRANSACTION)
   325  @with_output_types(typing.List[typing.Any])
   326  class _NaiveSpannerReadDoFn(DoFn):
   327    def __init__(self, spanner_configuration):
   328      """
    329      A naive version of the Spanner read which uses the transaction API of
    330      Cloud Spanner.
    331      https://googleapis.dev/python/spanner/latest/transaction-api.html
    332      In naive reads, this transform performs single reads, whereas batch
    333      reads use the Spanner partitioning query to create batches.
   334  
   335      Args:
   336        spanner_configuration: (_BeamSpannerConfiguration) Connection details to
   337          connect with cloud spanner.
   338      """
   339      self._spanner_configuration = spanner_configuration
   340      self._snapshot = None
   341      self._session = None
   342      self.base_labels = {
   343          monitoring_infos.SERVICE_LABEL: 'Spanner',
   344          monitoring_infos.METHOD_LABEL: 'Read',
   345          monitoring_infos.SPANNER_PROJECT_ID: (
   346              self._spanner_configuration.project),
   347          monitoring_infos.SPANNER_DATABASE_ID: (
   348              self._spanner_configuration.database),
   349      }
   350  
   351    def _table_metric(self, table_id, status):
   352      database_id = self._spanner_configuration.database
   353      project_id = self._spanner_configuration.project
   354      resource = resource_identifiers.SpannerTable(
   355          project_id, database_id, table_id)
   356      labels = {
   357          **self.base_labels,
   358          monitoring_infos.RESOURCE_LABEL: resource,
   359          monitoring_infos.SPANNER_TABLE_ID: table_id
   360      }
   361      service_call_metric = ServiceCallMetric(
   362          request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
   363          base_labels=labels)
   364      service_call_metric.call(str(status))
   365  
   366    def _query_metric(self, query_name, status):
   367      project_id = self._spanner_configuration.project
   368      resource = resource_identifiers.SpannerSqlQuery(project_id, query_name)
   369      labels = {
   370          **self.base_labels,
   371          monitoring_infos.RESOURCE_LABEL: resource,
   372          monitoring_infos.SPANNER_QUERY_NAME: query_name
   373      }
   374      service_call_metric = ServiceCallMetric(
   375          request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
   376          base_labels=labels)
   377      service_call_metric.call(str(status))
   378  
   379    def _get_session(self):
   380      if self._session is None:
   381        session = self._session = self._database.session()
   382        session.create()
   383      return self._session
   384  
   385    def _close_session(self):
   386      if self._session is not None:
   387        self._session.delete()
   388  
   389    def setup(self):
   390      # setting up client to connect with cloud spanner
   391      spanner_client = Client(self._spanner_configuration.project)
   392      instance = spanner_client.instance(self._spanner_configuration.instance)
   393      self._database = instance.database(
   394          self._spanner_configuration.database,
   395          pool=self._spanner_configuration.pool)
   396  
   397    def process(self, element, spanner_transaction):
   398      # `spanner_transaction` should be the instance of the _SPANNER_TRANSACTION
   399      # object.
   400      if not isinstance(spanner_transaction, _SPANNER_TRANSACTION):
   401        raise ValueError(
   402            "Invalid transaction object: %s. It should be instance "
   403            "of SPANNER_TRANSACTION object created by "
   404            "spannerio.create_transaction transform." % type(spanner_transaction))
   405  
   406      transaction_info = spanner_transaction.transaction
   407  
    408      # We use a batch snapshot to reuse the same transaction passed through
    409      # the side input.
   410      self._snapshot = BatchSnapshot.from_dict(self._database, transaction_info)
   411  
    412      # Get a transaction from the session to run the read operation.
   413      # with self._snapshot.session().transaction() as transaction:
   414      with self._get_session().transaction() as transaction:
   415        table_id = self._spanner_configuration.table
   416        query_name = self._spanner_configuration.query_name or ''
   417  
   418        if element.is_sql is True:
   419          transaction_read = transaction.execute_sql
   420          metric_action = self._query_metric
   421          metric_id = query_name
   422        elif element.is_table is True:
   423          transaction_read = transaction.read
   424          metric_action = self._table_metric
   425          metric_id = table_id
   426        else:
   427          raise ValueError(
    428              "ReadOperation is improperly configured: %s" % str(element))
   429  
   430        try:
   431          for row in transaction_read(**element.kwargs):
   432            yield row
   433  
   434          metric_action(metric_id, 'ok')
   435        except (ClientError, GoogleAPICallError) as e:
   436          metric_action(metric_id, e.code.value)
   437          raise
   438        except HttpError as e:
   439          metric_action(metric_id, e)
   440          raise
   441  
   442  
   443  @with_input_types(ReadOperation)
   444  @with_output_types(typing.Dict[typing.Any, typing.Any])
   445  class _CreateReadPartitions(DoFn):
   446    """
   447    A DoFn to create partitions. Uses the Partitioning API (PartitionRead /
   448    PartitionQuery) request to start a partitioned query operation. Returns a
   449    list of batch information needed to perform the actual queries.
   450  
    451    If the element is an instance of :class:`ReadOperation` that performs a sql
    452    query, the `PartitionQuery` API is used to create partitions, and mappings
    453    of the information needed to perform the actual partitioned reads via
    454    :meth:`process_query_batch` are returned.
    455
    456    If the element is an instance of :class:`ReadOperation` that performs a read
    457    from a table, the `PartitionRead` API is used to create partitions, and
    458    mappings of the information needed to perform the actual partitioned reads
    459    via :meth:`process_read_batch` are returned.
   460    """
   461    def __init__(self, spanner_configuration):
   462      self._spanner_configuration = spanner_configuration
   463  
   464    def setup(self):
   465      spanner_client = Client(
   466          project=self._spanner_configuration.project,
   467          credentials=self._spanner_configuration.credentials)
   468      instance = spanner_client.instance(self._spanner_configuration.instance)
   469      self._database = instance.database(
   470          self._spanner_configuration.database,
   471          pool=self._spanner_configuration.pool)
   472      self._snapshot = self._database.batch_snapshot(
   473          **self._spanner_configuration.snapshot_options)
   474      self._snapshot_dict = self._snapshot.to_dict()
   475  
   476    def process(self, element):
   477      if element.is_sql is True:
   478        partitioning_action = self._snapshot.generate_query_batches
   479      elif element.is_table is True:
   480        partitioning_action = self._snapshot.generate_read_batches
   481      else:
   482        raise ValueError(
    483            "ReadOperation is improperly configured: %s" % str(element))
   484  
   485      for p in partitioning_action(**element.kwargs):
   486        yield {
   487            "is_sql": element.is_sql,
   488            "is_table": element.is_table,
   489            "read_operation": element.read_operation,
   490            "partitions": p,
   491            "transaction_info": self._snapshot_dict
   492        }
   493  
   494  
   495  @with_input_types(int)
   496  @with_output_types(_SPANNER_TRANSACTION)
   497  class _CreateTransactionFn(DoFn):
   498    """
    499    A DoFn to create a transaction on Cloud Spanner.
    500    It connects to the database and returns the transaction_id and session_id
    501    by using the batch_snapshot.to_dict() method available in the Google Cloud
    502    Spanner SDK.
   503  
   504    https://googleapis.dev/python/spanner/latest/database-api.html?highlight=
   505    batch_snapshot#google.cloud.spanner_v1.database.BatchSnapshot.to_dict
   506    """
   507    def __init__(
   508        self,
   509        project_id,
   510        instance_id,
   511        database_id,
   512        credentials,
   513        pool,
   514        read_timestamp,
   515        exact_staleness):
   516      self._project_id = project_id
   517      self._instance_id = instance_id
   518      self._database_id = database_id
   519      self._credentials = credentials
   520      self._pool = pool
   521  
   522      self._snapshot_options = {}
   523      if read_timestamp:
   524        self._snapshot_options['read_timestamp'] = read_timestamp
   525      if exact_staleness:
   526        self._snapshot_options['exact_staleness'] = exact_staleness
   527      self._snapshot = None
   528  
   529    def setup(self):
   530      self._spanner_client = Client(
   531          project=self._project_id, credentials=self._credentials)
   532      self._instance = self._spanner_client.instance(self._instance_id)
   533      self._database = self._instance.database(self._database_id, pool=self._pool)
   534  
   535    def process(self, element, *args, **kwargs):
   536      self._snapshot = self._database.batch_snapshot(**self._snapshot_options)
   537      return [_SPANNER_TRANSACTION(self._snapshot.to_dict())]
   538  
   539  
   540  @ptransform_fn
   541  def create_transaction(
   542      pbegin,
   543      project_id,
   544      instance_id,
   545      database_id,
   546      credentials=None,
   547      pool=None,
   548      read_timestamp=None,
   549      exact_staleness=None):
   550    """
   551    A PTransform method to create a batch transaction.
   552  
   553    Args:
   554      pbegin: Root of the pipeline
   555      project_id: Cloud spanner project id. Be sure to use the Project ID,
   556        not the Project Number.
   557      instance_id: Cloud spanner instance id.
   558      database_id: Cloud spanner database id.
   559      credentials: (optional) The authorization credentials to attach to requests.
   560        These credentials identify this application to the service.
   561        If none are specified, the client will attempt to ascertain
   562        the credentials from the environment.
    563      pool: (optional) session pool to be used by database. If not passed,
    564        the Cloud Spanner SDK uses BurstyPool
    565        (`google.cloud.spanner.BurstyPool`) by default. Ref:
   566        https://googleapis.dev/python/spanner/latest/database-api.html?#google.
   567        cloud.spanner_v1.database.Database
   568      read_timestamp: (optional) An instance of the `datetime.datetime` object to
   569        execute all reads at the given timestamp.
   570      exact_staleness: (optional) An instance of the `datetime.timedelta`
   571        object. These timestamp bounds execute reads at a user-specified
   572        timestamp.
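         
           Example (placeholder ids, mirroring the module-level docs):::
         
             transaction = (pipeline | create_transaction(PROJECT_ID, INSTANCE_ID,
                                                          DB_NAME))
         
             users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
                                                sql='Select * from users',
                                                transaction=transaction)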
   573    """
   574  
   575    assert isinstance(pbegin, PBegin)
   576  
   577    return (
   578        pbegin | Create([1]) | ParDo(
   579            _CreateTransactionFn(
   580                project_id,
   581                instance_id,
   582                database_id,
   583                credentials,
   584                pool,
   585                read_timestamp,
   586                exact_staleness)))
   587  
   588  
   589  @with_input_types(typing.Dict[typing.Any, typing.Any])
   590  @with_output_types(typing.List[typing.Any])
   591  class _ReadFromPartitionFn(DoFn):
   592    """
   593    A DoFn to perform reads from the partition.
   594    """
   595    def __init__(self, spanner_configuration):
   596      self._spanner_configuration = spanner_configuration
   597      self.base_labels = {
   598          monitoring_infos.SERVICE_LABEL: 'Spanner',
   599          monitoring_infos.METHOD_LABEL: 'Read',
   600          monitoring_infos.SPANNER_PROJECT_ID: (
   601              self._spanner_configuration.project),
   602          monitoring_infos.SPANNER_DATABASE_ID: (
   603              self._spanner_configuration.database),
   604      }
   605      self.service_metric = None
   606  
   607    def _table_metric(self, table_id):
   608      database_id = self._spanner_configuration.database
   609      project_id = self._spanner_configuration.project
   610      resource = resource_identifiers.SpannerTable(
   611          project_id, database_id, table_id)
   612      labels = {
   613          **self.base_labels,
   614          monitoring_infos.RESOURCE_LABEL: resource,
   615          monitoring_infos.SPANNER_TABLE_ID: table_id
   616      }
   617      service_call_metric = ServiceCallMetric(
   618          request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
   619          base_labels=labels)
   620      return service_call_metric
   621  
   622    def _query_metric(self, query_name):
   623      project_id = self._spanner_configuration.project
   624      resource = resource_identifiers.SpannerSqlQuery(project_id, query_name)
   625      labels = {
   626          **self.base_labels,
   627          monitoring_infos.RESOURCE_LABEL: resource,
   628          monitoring_infos.SPANNER_QUERY_NAME: query_name
   629      }
   630      service_call_metric = ServiceCallMetric(
   631          request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
   632          base_labels=labels)
   633      return service_call_metric
   634  
   635    def setup(self):
   636      spanner_client = Client(self._spanner_configuration.project)
   637      instance = spanner_client.instance(self._spanner_configuration.instance)
   638      self._database = instance.database(
   639          self._spanner_configuration.database,
   640          pool=self._spanner_configuration.pool)
   641      self._snapshot = self._database.batch_snapshot(
   642          **self._spanner_configuration.snapshot_options)
   643  
   644    def process(self, element):
   645      self._snapshot = BatchSnapshot.from_dict(
   646          self._database, element['transaction_info'])
   647  
   648      table_id = self._spanner_configuration.table
   649      query_name = self._spanner_configuration.query_name or ''
   650  
   651      if element['is_sql'] is True:
   652        read_action = self._snapshot.process_query_batch
   653        self.service_metric = self._query_metric(query_name)
   654      elif element['is_table'] is True:
   655        read_action = self._snapshot.process_read_batch
   656        self.service_metric = self._table_metric(table_id)
   657      else:
   658        raise ValueError(
    659            "ReadOperation is improperly configured: %s" % str(element))
   660  
   661      try:
   662        for row in read_action(element['partitions']):
   663          yield row
   664  
   665        self.service_metric.call('ok')
   666      except (ClientError, GoogleAPICallError) as e:
    667        self.service_metric.call(str(e.code.value))
   668        raise
   669      except HttpError as e:
    670        self.service_metric.call(str(e))
   671        raise
   672  
   673    def teardown(self):
   674      if self._snapshot:
   675        self._snapshot.close()
   676  
   677  
   678  class ReadFromSpanner(PTransform):
   679    """
   680    A PTransform to perform reads from cloud spanner.
    681    ReadFromSpanner uses the Batch API to perform all read operations.
   682    """
   683  
   684    def __init__(self, project_id, instance_id, database_id, pool=None,
   685                 read_timestamp=None, exact_staleness=None, credentials=None,
   686                 sql=None, params=None, param_types=None,  # with_query
   687                 table=None, query_name=None, columns=None, index="",
   688                 keyset=None,  # with_table
   689                 read_operations=None,  # for read all
   690                 transaction=None
   691                ):
   692      """
   693      A PTransform that uses Spanner Batch API to perform reads.
   694  
   695      Args:
   696        project_id: Cloud spanner project id. Be sure to use the Project ID,
   697          not the Project Number.
   698        instance_id: Cloud spanner instance id.
   699        database_id: Cloud spanner database id.
    700        pool: (optional) session pool to be used by database. If not passed,
    701          the Cloud Spanner SDK uses BurstyPool
    702          (`google.cloud.spanner.BurstyPool`) by default. Ref:
   703          https://googleapis.dev/python/spanner/latest/database-api.html?#google.
   704          cloud.spanner_v1.database.Database
   705        read_timestamp: (optional) An instance of the `datetime.datetime` object
   706          to execute all reads at the given timestamp. By default, set to `None`.
   707        exact_staleness: (optional) An instance of the `datetime.timedelta`
   708          object. These timestamp bounds execute reads at a user-specified
   709          timestamp. By default, set to `None`.
   710        credentials: (optional) The authorization credentials to attach to
   711          requests. These credentials identify this application to the service.
   712          If none are specified, the client will attempt to ascertain
   713          the credentials from the environment. By default, set to `None`.
   714        sql: (optional) SQL query statement.
   715        params: (optional) Values for parameter replacement. Keys must match the
   716          names used in sql. By default, set to `None`.
   717        param_types: (optional) maps explicit types for one or more param values;
   718          required if params are passed. By default, set to `None`.
   719        table: (optional) Name of the table from which to fetch data. By
   720          default, set to `None`.
   721        columns: (optional) List of names of columns to be retrieved; required if
   722          the table is passed. By default, set to `None`.
    723        index: (optional) name of index to use, rather than the table's primary
    724          key. By default, set to an empty string.
   725        keyset: (optional) keys / ranges identifying rows to be retrieved. By
   726          default, set to `None`.
   727        read_operations: (optional) List of the objects of :class:`ReadOperation`
   728          to perform read all. By default, set to `None`.
   729        transaction: (optional) PTransform of the :meth:`create_transaction` to
   730          perform naive read on cloud spanner. By default, set to `None`.
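         
             Example (placeholder ids, mirroring the module-level docs):::
         
               records = (pipeline
                          | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
                                            sql='Select * from users'))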
   731      """
   732      self._configuration = _BeamSpannerConfiguration(
   733          project=project_id,
   734          instance=instance_id,
   735          database=database_id,
   736          table=table,
   737          query_name=query_name,
   738          credentials=credentials,
   739          pool=pool,
   740          snapshot_read_timestamp=read_timestamp,
   741          snapshot_exact_staleness=exact_staleness)
   742  
   743      self._read_operations = read_operations
   744      self._transaction = transaction
   745  
   746      if self._read_operations is None:
   747        if table is not None:
   748          if columns is None:
   749            raise ValueError("Columns are required with the table name.")
   750          self._read_operations = [
   751              ReadOperation.table(
   752                  table=table, columns=columns, index=index, keyset=keyset)
   753          ]
   754        elif sql is not None:
   755          self._read_operations = [
   756              ReadOperation.query(
   757                  sql=sql, params=params, param_types=param_types)
   758          ]
   759  
   760    def expand(self, pbegin):
   761      if self._read_operations is not None and isinstance(pbegin, PBegin):
   762        pcoll = pbegin.pipeline | Create(self._read_operations)
   763      elif not isinstance(pbegin, PBegin):
   764        if self._read_operations is not None:
   765          raise ValueError(
   766              "Read operation in the constructor only works with "
   767              "the root of the pipeline.")
   768        pcoll = pbegin
   769      else:
   770        raise ValueError(
    771            "ReadFromSpanner requires a read operation, sql or table "
   772            "with columns.")
   773  
   774      if self._transaction is None:
   775        # reading as batch read using the spanner partitioning query to create
   776        # batches.
   777        p = (
   778            pcoll
   779            | 'Generate Partitions' >> ParDo(
   780                _CreateReadPartitions(spanner_configuration=self._configuration))
   781            | 'Reshuffle' >> Reshuffle()
   782            | 'Read From Partitions' >> ParDo(
   783                _ReadFromPartitionFn(spanner_configuration=self._configuration)))
   784      else:
   785        # reading as naive read, in which we don't make batches and execute the
   786        # queries as a single read.
   787        p = (
   788            pcoll
   789            | 'Reshuffle' >> Reshuffle().with_input_types(ReadOperation)
   790            | 'Perform Read' >> ParDo(
   791                _NaiveSpannerReadDoFn(spanner_configuration=self._configuration),
   792                AsSingleton(self._transaction)))
   793      return p
   794  
   795    def display_data(self):
   796      res = {}
   797      sql = []
   798      table = []
   799      if self._read_operations is not None:
   800        for ro in self._read_operations:
   801          if ro.is_sql is True:
   802            sql.append(ro.kwargs)
   803          elif ro.is_table is True:
   804            table.append(ro.kwargs)
   805  
   806        if sql:
   807          res['sql'] = DisplayDataItem(str(sql), label='Sql')
   808        if table:
   809          res['table'] = DisplayDataItem(str(table), label='Table')
   810  
   811      if self._transaction:
   812        res['transaction'] = DisplayDataItem(
   813            str(self._transaction), label='transaction')
   814  
   815      return res
   816  
   817  
   818  class WriteToSpanner(PTransform):
   819    def __init__(
   820        self,
   821        project_id,
   822        instance_id,
   823        database_id,
   824        pool=None,
   825        credentials=None,
   826        max_batch_size_bytes=1048576,
   827        max_number_rows=50,
   828        max_number_cells=500):
   829      """
    830      A PTransform to write to Google Cloud Spanner.
   831  
   832      Args:
   833        project_id: Cloud spanner project id. Be sure to use the Project ID,
   834          not the Project Number.
   835        instance_id: Cloud spanner instance id.
   836        database_id: Cloud spanner database id.
    837        max_batch_size_bytes: (optional) Split the mutations into batches to
    838          reduce the number of transactions sent to Spanner. By default it is
    839          set to 1 MB (1048576 Bytes).
    840        max_number_rows: (optional) Split the mutations into batches to
    841          reduce the number of transactions sent to Spanner. By default it is
    842          set to 50 rows per batch.
    843        max_number_cells: (optional) Split the mutations into batches to
    844          reduce the number of transactions sent to Spanner. By default it is
    845          set to 500 cells per batch.
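         
             Example (placeholder ids, mirroring the module-level docs):::
         
               _ = (pipeline
                    | beam.Create([WriteMutation.insert(
                        table='user', columns=('name', 'email'),
                        values=[('sara', 'sara@dev.com')])])
                    | WriteToSpanner(
                        project_id=SPANNER_PROJECT_ID,
                        instance_id=SPANNER_INSTANCE_ID,
                        database_id=SPANNER_DATABASE_NAME))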
   846      """
   847      self._configuration = _BeamSpannerConfiguration(
   848          project=project_id,
   849          instance=instance_id,
   850          database=database_id,
   851          table=None,
   852          query_name=None,
   853          credentials=credentials,
   854          pool=pool,
   855          snapshot_read_timestamp=None,
   856          snapshot_exact_staleness=None)
   857      self._max_batch_size_bytes = max_batch_size_bytes
   858      self._max_number_rows = max_number_rows
   859      self._max_number_cells = max_number_cells
   860      self._database_id = database_id
   861      self._project_id = project_id
   862      self._instance_id = instance_id
   863      self._pool = pool
   864  
   865    def display_data(self):
   866      res = {
   867          'project_id': DisplayDataItem(self._project_id, label='Project Id'),
   868          'instance_id': DisplayDataItem(self._instance_id, label='Instance Id'),
   869          'pool': DisplayDataItem(str(self._pool), label='Pool'),
   870          'database': DisplayDataItem(self._database_id, label='Database'),
   871          'batch_size': DisplayDataItem(
   872              self._max_batch_size_bytes, label="Batch Size"),
   873          'max_number_rows': DisplayDataItem(
   874              self._max_number_rows, label="Max Rows"),
   875          'max_number_cells': DisplayDataItem(
   876              self._max_number_cells, label="Max Cells"),
   877      }
   878      return res
   879  
   880    def expand(self, pcoll):
   881      return (
   882          pcoll
   883          | "make batches" >> _WriteGroup(
   884              max_batch_size_bytes=self._max_batch_size_bytes,
   885              max_number_rows=self._max_number_rows,
   886              max_number_cells=self._max_number_cells)
   887          |
   888          'Writing to spanner' >> ParDo(_WriteToSpannerDoFn(self._configuration)))
   889  
   890  
   891  class _Mutator(namedtuple('_Mutator',
   892                            ["mutation", "operation", "kwargs", "rows", "cells"])
   893                 ):
   894    __slots__ = ()
   895  
   896    @property
   897    def byte_size(self):
   898      if hasattr(self.mutation, '_pb'):
   899        # google-cloud-spanner 3.x
   900        return self.mutation._pb.ByteSize()
   901      else:
   902        # google-cloud-spanner 1.x
   903        return self.mutation.ByteSize()
   904  
   905  
   906  class MutationGroup(deque):
   907    """
   908    A Bundle of Spanner Mutations (_Mutator).
   909    """
   910    @property
   911    def info(self):
   912      cells = 0
   913      rows = 0
   914      bytes = 0
   915      for m in self.__iter__():
   916        bytes += m.byte_size
   917        rows += m.rows
   918        cells += m.cells
   919      return {"rows": rows, "cells": cells, "byte_size": bytes}
   920  
   921    def primary(self):
   922      return next(self.__iter__())
   923  
   924  
   925  class WriteMutation(object):
   926  
   927    _OPERATION_DELETE = "delete"
   928    _OPERATION_INSERT = "insert"
   929    _OPERATION_INSERT_OR_UPDATE = "insert_or_update"
   930    _OPERATION_REPLACE = "replace"
   931    _OPERATION_UPDATE = "update"
   932  
   933    def __init__(
   934        self,
   935        insert=None,
   936        update=None,
   937        insert_or_update=None,
   938        replace=None,
   939        delete=None,
   940        columns=None,
   941        values=None,
   942        keyset=None):
   943      """
    944      A convenience class to create Spanner Mutations for writes. Users can
    945      provide the operation via the constructor or via the static methods.
    946
    947      Note: When passing the operation via the constructor, only one operation
    948      is accepted at a time. For example, passing a table name in the `insert`
    949      parameter together with a value for the `update` parameter will cause
    950      an error.
   951  
   952      Args:
   953        insert: (Optional) Name of the table in which rows will be inserted.
   954        update: (Optional) Name of the table in which existing rows will be
   955          updated.
   956        insert_or_update: (Optional) Table name in which rows will be written.
   957          Like insert, except that if the row already exists, then its column
   958          values are overwritten with the ones provided. Any column values not
   959          explicitly written are preserved.
   960        replace: (Optional) Table name in which rows will be replaced. Like
   961          insert, except that if the row already exists, it is deleted, and the
   962          column values provided are inserted instead. Unlike `insert_or_update`,
   963          this means any values not explicitly written become `NULL`.
   964        delete: (Optional) Table name from which rows will be deleted. Succeeds
   965          whether or not the named rows were present.
   966        columns: The names of the columns in table to be written. The list of
   967          columns must contain enough columns to allow Cloud Spanner to derive
   968          values for all primary key columns in the row(s) to be modified.
   969        values: The values to be written. `values` can contain more than one
   970          list of values. If it does, then multiple rows are written, one for
   971          each entry in `values`. Each list in `values` must have exactly as
   972          many entries as there are entries in columns above. Sending multiple
   973          lists is equivalent to sending multiple Mutations, each containing one
   974          `values` entry and repeating table and columns.
   975        keyset: (Optional) The primary keys of the rows within table to delete.
   976          Delete is idempotent. The transaction will succeed even if some or
   977          all rows do not exist.
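         
             Example (mirroring the module-level docs):::
         
               WriteMutation(insert='users', columns=('name', 'email'),
                             values=[('sara', 'sara@example.com')])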
   978      """
   979      self._columns = columns
   980      self._values = values
   981      self._keyset = keyset
   982  
   983      self._insert = insert
   984      self._update = update
   985      self._insert_or_update = insert_or_update
   986      self._replace = replace
   987      self._delete = delete
   988  
   989      if sum([1 for x in [self._insert,
   990                          self._update,
   991                          self._insert_or_update,
   992                          self._replace,
   993                          self._delete] if x is not None]) != 1:
   994        raise ValueError(
    995            "Exactly one write mutation operation must be "
   996            "provided: <%s: %s>" % (self.__class__.__name__, str(self.__dict__)))
   997  
   998    def __call__(self, *args, **kwargs):
   999      if self._insert is not None:
  1000        return WriteMutation.insert(
  1001            table=self._insert, columns=self._columns, values=self._values)
  1002      elif self._update is not None:
  1003        return WriteMutation.update(
  1004            table=self._update, columns=self._columns, values=self._values)
  1005      elif self._insert_or_update is not None:
  1006        return WriteMutation.insert_or_update(
  1007            table=self._insert_or_update,
  1008            columns=self._columns,
  1009            values=self._values)
  1010      elif self._replace is not None:
  1011        return WriteMutation.replace(
  1012            table=self._replace, columns=self._columns, values=self._values)
  1013      elif self._delete is not None:
  1014        return WriteMutation.delete(table=self._delete, keyset=self._keyset)
  1015  
  1016    @staticmethod
  1017    def insert(table, columns, values):
  1018      """Insert one or more new table rows.
  1019  
  1020      Args:
  1021        table: Name of the table to be modified.
  1022        columns: Name of the table columns to be modified.
  1023        values: Values to be modified.
  1024      """
  1025      rows = len(values)
  1026      cells = len(columns) * len(values)
  1027      return _Mutator(
  1028          mutation=Mutation(insert=batch._make_write_pb(table, columns, values)),
  1029          operation=WriteMutation._OPERATION_INSERT,
  1030          rows=rows,
  1031          cells=cells,
  1032          kwargs={
  1033              "table": table, "columns": columns, "values": values
  1034          })
  1035  
  1036    @staticmethod
  1037    def update(table, columns, values):
  1038      """Update one or more existing table rows.
  1039  
  1040      Args:
  1041        table: Name of the table to be modified.
  1042        columns: Name of the table columns to be modified.
  1043        values: Values to be modified.
  1044      """
  1045      rows = len(values)
  1046      cells = len(columns) * len(values)
  1047      return _Mutator(
  1048          mutation=Mutation(update=batch._make_write_pb(table, columns, values)),
  1049          operation=WriteMutation._OPERATION_UPDATE,
  1050          rows=rows,
  1051          cells=cells,
  1052          kwargs={
  1053              "table": table, "columns": columns, "values": values
  1054          })
  1055  
  1056    @staticmethod
  1057    def insert_or_update(table, columns, values):
  1058      """Insert/update one or more table rows.
  1059      Args:
  1060        table: Name of the table to be modified.
  1061        columns: Name of the table columns to be modified.
  1062        values: Values to be modified.
  1063      """
  1064      rows = len(values)
  1065      cells = len(columns) * len(values)
  1066      return _Mutator(
  1067          mutation=Mutation(
  1068              insert_or_update=batch._make_write_pb(table, columns, values)),
  1069          operation=WriteMutation._OPERATION_INSERT_OR_UPDATE,
  1070          rows=rows,
  1071          cells=cells,
  1072          kwargs={
  1073              "table": table, "columns": columns, "values": values
  1074          })
  1075  
  1076    @staticmethod
  1077    def replace(table, columns, values):
  1078      """Replace one or more table rows.
  1079  
  1080      Args:
  1081        table: Name of the table to be modified.
  1082        columns: Name of the table columns to be modified.
  1083        values: Values to be modified.
  1084      """
  1085      rows = len(values)
  1086      cells = len(columns) * len(values)
  1087      return _Mutator(
  1088          mutation=Mutation(replace=batch._make_write_pb(table, columns, values)),
  1089          operation=WriteMutation._OPERATION_REPLACE,
  1090          rows=rows,
  1091          cells=cells,
  1092          kwargs={
  1093              "table": table, "columns": columns, "values": values
  1094          })
  1095  
  1096    @staticmethod
  1097    def delete(table, keyset):
  1098      """Delete one or more table rows.
  1099  
  1100      Args:
  1101        table: Name of the table to be modified.
  1102        keyset: Keys/ranges identifying rows to delete.
  1103      """
  1104      delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
  1105      return _Mutator(
  1106          mutation=Mutation(delete=delete),
  1107          rows=0,
  1108          cells=0,
  1109          operation=WriteMutation._OPERATION_DELETE,
  1110          kwargs={
  1111              "table": table, "keyset": keyset
  1112          })
  1113  
  1114  
  1115  @with_input_types(typing.Union[MutationGroup, TaggedOutput])
  1116  @with_output_types(MutationGroup)
  1117  class _BatchFn(DoFn):
  1118    """
  1119    Batches mutations together.
  1120    """
  1121    def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
  1122      self._max_batch_size_bytes = max_batch_size_bytes
  1123      self._max_number_rows = max_number_rows
  1124      self._max_number_cells = max_number_cells
  1125  
  1126    def start_bundle(self):
  1127      self._batch = MutationGroup()
  1128      self._size_in_bytes = 0
  1129      self._rows = 0
  1130      self._cells = 0
  1131  
  1132    def _reset_count(self):
  1133      self._batch = MutationGroup()
  1134      self._size_in_bytes = 0
  1135      self._rows = 0
  1136      self._cells = 0
  1137  
  1138    def process(self, element):
  1139      mg_info = element.info
  1140  
  1141      if mg_info['byte_size'] + self._size_in_bytes > self._max_batch_size_bytes \
  1142          or mg_info['cells'] + self._cells > self._max_number_cells \
  1143          or mg_info['rows'] + self._rows > self._max_number_rows:
   1144        # Batch is full: output the batch and reset the count.
  1145        if self._batch:
  1146          yield self._batch
  1147        self._reset_count()
  1148  
  1149      self._batch.extend(element)
  1150  
  1151      # total byte size of the mutation group.
  1152      self._size_in_bytes += mg_info['byte_size']
  1153  
  1154      # total rows in the mutation group.
  1155      self._rows += mg_info['rows']
  1156  
  1157      # total cells in the mutation group.
  1158      self._cells += mg_info['cells']
  1159  
  1160    def finish_bundle(self):
  1161      if self._batch is not None:
  1162        yield window.GlobalWindows.windowed_value(self._batch)
  1163        self._batch = None
  1164  
  1165  
  1166  @with_input_types(MutationGroup)
  1167  @with_output_types(MutationGroup)
  1168  class _BatchableFilterFn(DoFn):
  1169    """
  1170    Filters MutationGroups larger than the batch size to the output tagged with
  1171    OUTPUT_TAG_UNBATCHABLE.
  1172    """
  1173    OUTPUT_TAG_UNBATCHABLE = 'unbatchable'
  1174  
  1175    def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
  1176      self._max_batch_size_bytes = max_batch_size_bytes
  1177      self._max_number_rows = max_number_rows
  1178      self._max_number_cells = max_number_cells
  1179      self._batchable = None
  1180      self._unbatchable = None
  1181  
  1182    def process(self, element):
  1183      if element.primary().operation == WriteMutation._OPERATION_DELETE:
   1184        # Delete mutations are not batchable.
  1185        yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element)
  1186      else:
  1187        mg_info = element.info
  1188        if mg_info['byte_size'] > self._max_batch_size_bytes \
  1189            or mg_info['cells'] > self._max_number_cells \
  1190            or mg_info['rows'] > self._max_number_rows:
  1191          yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element)
  1192        else:
  1193          yield element
  1194  
  1195  
  1196  class _WriteToSpannerDoFn(DoFn):
  1197    def __init__(self, spanner_configuration):
  1198      self._spanner_configuration = spanner_configuration
  1199      self._db_instance = None
  1200      self.batches = Metrics.counter(self.__class__, 'SpannerBatches')
  1201      self.base_labels = {
  1202          monitoring_infos.SERVICE_LABEL: 'Spanner',
  1203          monitoring_infos.METHOD_LABEL: 'Write',
  1204          monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project,
  1205          monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database,
  1206      }
  1207      # table_id to metrics
  1208      self.service_metrics = {}
  1209  
  1210    def _register_table_metric(self, table_id):
  1211      if table_id in self.service_metrics:
  1212        return
  1213      database_id = self._spanner_configuration.database
  1214      project_id = self._spanner_configuration.project
  1215      resource = resource_identifiers.SpannerTable(
  1216          project_id, database_id, table_id)
  1217      labels = {
  1218          **self.base_labels,
  1219          monitoring_infos.RESOURCE_LABEL: resource,
  1220          monitoring_infos.SPANNER_TABLE_ID: table_id
  1221      }
  1222      service_call_metric = ServiceCallMetric(
  1223          request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
  1224          base_labels=labels)
  1225      self.service_metrics[table_id] = service_call_metric
  1226  
  1227    def setup(self):
  1228      spanner_client = Client(self._spanner_configuration.project)
  1229      instance = spanner_client.instance(self._spanner_configuration.instance)
  1230      self._db_instance = instance.database(
  1231          self._spanner_configuration.database,
  1232          pool=self._spanner_configuration.pool)
  1233  
  1234    def start_bundle(self):
  1235      self.service_metrics = {}
  1236  
  1237    def process(self, element):
  1238      self.batches.inc()
  1239      try:
  1240        with self._db_instance.batch() as b:
  1241          for m in element:
  1242            table_id = m.kwargs['table']
  1243            self._register_table_metric(table_id)
  1244  
  1245            if m.operation == WriteMutation._OPERATION_DELETE:
  1246              batch_func = b.delete
  1247            elif m.operation == WriteMutation._OPERATION_REPLACE:
  1248              batch_func = b.replace
  1249            elif m.operation == WriteMutation._OPERATION_INSERT_OR_UPDATE:
  1250              batch_func = b.insert_or_update
  1251            elif m.operation == WriteMutation._OPERATION_INSERT:
  1252              batch_func = b.insert
  1253            elif m.operation == WriteMutation._OPERATION_UPDATE:
  1254              batch_func = b.update
  1255            else:
  1256              raise ValueError("Unknown operation action: %s" % m.operation)
  1257            batch_func(**m.kwargs)
  1258      except (ClientError, GoogleAPICallError) as e:
  1259        for service_metric in self.service_metrics.values():
  1260          service_metric.call(str(e.code.value))
  1261        raise
  1262      except HttpError as e:
  1263        for service_metric in self.service_metrics.values():
  1264          service_metric.call(str(e))
  1265        raise
  1266      else:
  1267        for service_metric in self.service_metrics.values():
  1268          service_metric.call('ok')
  1269  
  1270  
  1271  @with_input_types(typing.Union[MutationGroup, _Mutator])
  1272  @with_output_types(MutationGroup)
  1273  class _MakeMutationGroupsFn(DoFn):
  1274    """
   1275    Makes a MutationGroup object if the element is an instance of _Mutator.
  1276    """
  1277    def process(self, element):
  1278      if isinstance(element, MutationGroup):
  1279        yield element
  1280      elif isinstance(element, _Mutator):
  1281        yield MutationGroup([element])
  1282      else:
  1283        raise ValueError(
  1284            "Invalid object type: %s. Object must be an instance of "
   1285            "MutationGroup or WriteMutation" % str(element))
  1286  
  1287  
  1288  class _WriteGroup(PTransform):
  1289    def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
  1290      self._max_batch_size_bytes = max_batch_size_bytes
  1291      self._max_number_rows = max_number_rows
  1292      self._max_number_cells = max_number_cells
  1293  
  1294    def expand(self, pcoll):
  1295      filter_batchable_mutations = (
  1296          pcoll
  1297          | 'Making mutation groups' >> ParDo(_MakeMutationGroupsFn())
  1298          | 'Filtering Batchable Mutations' >> ParDo(
  1299              _BatchableFilterFn(
  1300                  max_batch_size_bytes=self._max_batch_size_bytes,
  1301                  max_number_rows=self._max_number_rows,
  1302                  max_number_cells=self._max_number_cells)).with_outputs(
  1303                      _BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, main='batchable')
  1304      )
  1305  
  1306      batching_batchables = (
  1307          filter_batchable_mutations['batchable']
  1308          | ParDo(
  1309              _BatchFn(
  1310                  max_batch_size_bytes=self._max_batch_size_bytes,
  1311                  max_number_rows=self._max_number_rows,
  1312                  max_number_cells=self._max_number_cells)))
  1313  
  1314      return ((
  1315          batching_batchables,
  1316          filter_batchable_mutations[_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE])
  1317              | 'Merging batchable and unbatchable' >> Flatten())