github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/bigquery_read_internal.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Internal library for reading data from BigQuery.

NOTHING IN THIS FILE HAS BACKWARDS COMPATIBILITY GUARANTEES.
"""
import collections
import decimal
import json
import logging
import random
import time
import uuid
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.io.avroio import _create_avro_source
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp import bigquery_tools
from apache_beam.io.gcp.bigquery_io_metadata import create_bigquery_io_metadata
from apache_beam.io.iobase import BoundedSource
from apache_beam.io.textio import _TextSource
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms import PTransform

if TYPE_CHECKING:
  from apache_beam.io.gcp.bigquery import ReadFromBigQueryRequest

try:
  from apache_beam.io.gcp.internal.clients.bigquery import DatasetReference
  from apache_beam.io.gcp.internal.clients.bigquery import TableReference
except ImportError:
  DatasetReference = None
  TableReference = None

_LOGGER = logging.getLogger(__name__)


def bigquery_export_destination_uri(
    gcs_location_vp: Optional[ValueProvider],
    temp_location: Optional[str],
    unique_id: str,
    directory_only: bool = False,
) -> str:
  """Returns the fully qualified Google Cloud Storage URI where the
  extracted table should be written.
  """
  file_pattern = 'bigquery-table-dump-*.json'

  gcs_location = None
  if gcs_location_vp is not None:
    gcs_location = gcs_location_vp.get()

  if gcs_location is not None:
    gcs_base = gcs_location
  elif temp_location is not None:
    gcs_base = temp_location
    _LOGGER.debug("gcs_location is empty, using temp_location instead")
  else:
    raise ValueError(
        'ReadFromBigQuery requires a GCS location to be provided. Neither '
        'gcs_location in the constructor nor the fallback option '
        '--temp_location is set.')

  if not unique_id:
    unique_id = uuid.uuid4().hex

  if directory_only:
    return FileSystems.join(gcs_base, unique_id)
  else:
    return FileSystems.join(gcs_base, unique_id, file_pattern)

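# A minimal sketch of what bigquery_export_destination_uri produces. The
# bucket and id below are illustrative, not values taken from this module:
#
#   bigquery_export_destination_uri(
#       gcs_location_vp=None,
#       temp_location='gs://my-bucket/tmp',
#       unique_id='abc123')
#   # -> 'gs://my-bucket/tmp/abc123/bigquery-table-dump-*.json'
#
#   bigquery_export_destination_uri(
#       gcs_location_vp=None,
#       temp_location='gs://my-bucket/tmp',
#       unique_id='abc123',
#       directory_only=True)
#   # -> 'gs://my-bucket/tmp/abc123'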

class _PassThroughThenCleanup(PTransform):
  """A PTransform that invokes a cleanup DoFn after the input PCollection has
    been processed.

    The cleanup DoFn takes the arguments (element, cleanup_signal, side_input).

    The readiness of the cleanup_signal PCollection is what triggers the
    cleanup DoFn, so it only runs once the main input has been fully processed.
  """
  def __init__(self, side_input=None):
    self.side_input = side_input

  def expand(self, input):
    class PassThrough(beam.DoFn):
      def process(self, element):
        yield element

    class RemoveExtractedFiles(beam.DoFn):
      def process(self, unused_element, unused_signal, gcs_locations):
        FileSystems.delete(list(gcs_locations))

    main_output, cleanup_signal = input | beam.ParDo(
        PassThrough()).with_outputs(
        'cleanup_signal', main='main')

    cleanup_input = input.pipeline | beam.Create([None])

    _ = cleanup_input | beam.ParDo(
        RemoveExtractedFiles(),
        beam.pvalue.AsSingleton(cleanup_signal),
        self.side_input,
    )

    return main_output

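# A rough usage sketch for _PassThroughThenCleanup; the names below are
# illustrative and not part of this module:
#
#   files_to_remove = beam.pvalue.AsIter(paths_pcoll)
#   cleaned = rows | _PassThroughThenCleanup(files_to_remove)
#
# Because RemoveExtractedFiles consumes the cleanup signal as a side input,
# the extracted files are deleted only after the main output has been fully
# produced.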

class _PassThroughThenCleanupTempDatasets(PTransform):
  """A PTransform that invokes a cleanup DoFn after the input PCollection has
    been processed.

    The cleanup DoFn takes the arguments (element, cleanup_signal, side_input).

    The readiness of the cleanup_signal PCollection is what triggers the
    cleanup DoFn, so temporary datasets are only removed once the main input
    has been fully processed.
  """
  def __init__(self, side_input=None):
    self.side_input = side_input

  def expand(self, input):
    class PassThrough(beam.DoFn):
      def process(self, element):
        yield element

    class CleanUpProjects(beam.DoFn):
      def process(self, unused_element, unused_signal, pipeline_details):
        bq = bigquery_tools.BigQueryWrapper()
        pipeline_details = pipeline_details[0]
        if 'temp_table_ref' in pipeline_details.keys():
          temp_table_ref = pipeline_details['temp_table_ref']
          bq._clean_up_beam_labelled_temporary_datasets(
              project_id=temp_table_ref.projectId,
              dataset_id=temp_table_ref.datasetId,
              table_id=temp_table_ref.tableId)
        elif 'project_id' in pipeline_details.keys():
          bq._clean_up_beam_labelled_temporary_datasets(
              project_id=pipeline_details['project_id'],
              labels=pipeline_details['bigquery_dataset_labels'])

    main_output, cleanup_signal = input | beam.ParDo(
        PassThrough()).with_outputs(
        'cleanup_signal', main='main')

    cleanup_input = input.pipeline | beam.Create([None])

    _ = cleanup_input | beam.ParDo(
        CleanUpProjects(),
        beam.pvalue.AsSingleton(cleanup_signal),
        self.side_input,
    )

    return main_output

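# The side input given to _PassThroughThenCleanupTempDatasets is expected to
# yield a single dict describing what to clean up, shaped roughly like one of
# the following (values are illustrative):
#
#   {'temp_table_ref': TableReference(
#       projectId=..., datasetId=..., tableId=...)}
#   {'project_id': 'my-project', 'bigquery_dataset_labels': {'label': 'value'}}
#
# CleanUpProjects reads pipeline_details[0] and dispatches on which of these
# keys is present.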

class _BigQueryReadSplit(beam.transforms.DoFn):
  """Starts the process of reading from BigQuery.

  This transform will start a BigQuery export job, and output a number of
  file sources that are consumed downstream.
  """
  def __init__(
      self,
      options: PipelineOptions,
      gcs_location: Union[str, ValueProvider] = None,
      use_json_exports: bool = False,
      bigquery_job_labels: Dict[str, str] = None,
      step_name: str = None,
      job_name: str = None,
      unique_id: str = None,
      kms_key: str = None,
      project: str = None,
      temp_dataset: Union[str, DatasetReference] = None,
      query_priority: Optional[str] = None):
    self.options = options
    self.use_json_exports = use_json_exports
    self.gcs_location = gcs_location
    self.bigquery_job_labels = bigquery_job_labels or {}
    self._step_name = step_name
    self._job_name = job_name or 'BQ_READ_SPLIT'
    self._source_uuid = unique_id
    self.kms_key = kms_key
    self.project = project
    self.temp_dataset = temp_dataset or 'bq_read_all_%s' % uuid.uuid4().hex
    self.query_priority = query_priority
    self.bq_io_metadata = None

  def display_data(self):
    return {
        'use_json_exports': str(self.use_json_exports),
        'gcs_location': str(self.gcs_location),
        'bigquery_job_labels': json.dumps(self.bigquery_job_labels),
        'kms_key': str(self.kms_key),
        'project': str(self.project),
        'temp_dataset': str(self.temp_dataset)
    }

  def _get_temp_dataset(self):
    if isinstance(self.temp_dataset, str):
      return DatasetReference(
          datasetId=self.temp_dataset, projectId=self._get_project())
    else:
      return self.temp_dataset

  def process(self,
              element: 'ReadFromBigQueryRequest') -> Iterable[BoundedSource]:
    bq = bigquery_tools.BigQueryWrapper(
        temp_dataset_id=self._get_temp_dataset().datasetId)

    if element.query is not None:
      self._setup_temporary_dataset(bq, element)
      table_reference = self._execute_query(bq, element)
    else:
      assert element.table
      table_reference = bigquery_tools.parse_table_reference(
          element.table, project=self._get_project())

    if not table_reference.projectId:
      table_reference.projectId = self._get_project()

    schema, metadata_list = self._export_files(bq, element, table_reference)

    for metadata in metadata_list:
      yield self._create_source(metadata.path, schema)

    if element.query is not None:
      bq._delete_table(
          table_reference.projectId,
          table_reference.datasetId,
          table_reference.tableId)

    if bq.created_temp_dataset:
      self._clean_temporary_dataset(bq, element)

  def _get_bq_metadata(self):
    if not self.bq_io_metadata:
      self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)
    return self.bq_io_metadata

  def _create_source(self, path, schema):
    if not self.use_json_exports:
      return _create_avro_source(path)
    else:
      return _TextSource(
          path,
          min_bundle_size=0,
          compression_type=CompressionTypes.UNCOMPRESSED,
          strip_trailing_newlines=True,
          coder=_JsonToDictCoder(schema))

  def _setup_temporary_dataset(
      self,
      bq: bigquery_tools.BigQueryWrapper,
      element: 'ReadFromBigQueryRequest'):
    location = bq.get_query_location(
        self._get_project(), element.query, not element.use_standard_sql)
    bq.create_temporary_dataset(self._get_project(), location)

  def _clean_temporary_dataset(
      self,
      bq: bigquery_tools.BigQueryWrapper,
      element: 'ReadFromBigQueryRequest'):
    bq.clean_up_temporary_dataset(self._get_project())

  def _execute_query(
      self,
      bq: bigquery_tools.BigQueryWrapper,
      element: 'ReadFromBigQueryRequest'):
    query_job_name = bigquery_tools.generate_bq_job_name(
        self._job_name,
        self._source_uuid,
        bigquery_tools.BigQueryJobTypes.QUERY,
        '%s_%s' % (int(time.time()), random.randint(0, 1000)))
    job = bq._start_query_job(
        self._get_project(),
        element.query,
        not element.use_standard_sql,
        element.flatten_results,
        job_id=query_job_name,
        priority=self.query_priority,
        kms_key=self.kms_key,
        job_labels=self._get_bq_metadata().add_additional_bq_job_labels(
            self.bigquery_job_labels))
    job_ref = job.jobReference
    bq.wait_for_bq_job(job_ref, max_retries=0)
    return bq._get_temp_table(self._get_project())

  def _export_files(
      self,
      bq: bigquery_tools.BigQueryWrapper,
      element: 'ReadFromBigQueryRequest',
      table_reference: TableReference):
    """Runs a BigQuery export job.

    Returns:
      A tuple of (bigquery.TableSchema instance, list of FileMetadata
      instances for the exported files).
    """
    job_labels = self._get_bq_metadata().add_additional_bq_job_labels(
        self.bigquery_job_labels)
    export_job_name = bigquery_tools.generate_bq_job_name(
        self._job_name,
        self._source_uuid,
        bigquery_tools.BigQueryJobTypes.EXPORT,
        element.obj_id)
    temp_location = self.options.view_as(GoogleCloudOptions).temp_location
    gcs_location = bigquery_export_destination_uri(
        self.gcs_location,
        temp_location,
        '%s%s' % (self._source_uuid, element.obj_id))
    try:
      if self.use_json_exports:
        job_ref = bq.perform_extract_job([gcs_location],
                                         export_job_name,
                                         table_reference,
                                         bigquery_tools.FileFormat.JSON,
                                         project=self._get_project(),
                                         job_labels=job_labels,
                                         include_header=False)
      else:
        job_ref = bq.perform_extract_job([gcs_location],
                                         export_job_name,
                                         table_reference,
                                         bigquery_tools.FileFormat.AVRO,
                                         project=self._get_project(),
                                         include_header=False,
                                         job_labels=job_labels,
                                         use_avro_logical_types=True)
      bq.wait_for_bq_job(job_ref)
    except Exception as exn:  # pylint: disable=broad-except
      # The error messages thrown in this case are generic and misleading,
      # so leave this breadcrumb in case it's the root cause.
      logging.warning(
          "Error exporting table: %s. "
          "Note that external tables cannot be exported: "
          "https://cloud.google.com/bigquery/docs/external-tables"
          "#external_table_limitations",
          exn)
      raise
    metadata_list = FileSystems.match([gcs_location])[0].metadata_list

    if isinstance(table_reference, ValueProvider):
      table_ref = bigquery_tools.parse_table_reference(
          element.table, project=self._get_project())
    else:
      table_ref = table_reference
    table = bq.get_table(
        table_ref.projectId, table_ref.datasetId, table_ref.tableId)

    return table.schema, metadata_list

  def _get_project(self):
    """Returns the project that queries and exports will be billed to."""

    project = self.options.view_as(GoogleCloudOptions).project
    if isinstance(project, ValueProvider):
      project = project.get()
    if not project:
      project = self.project
    return project

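# A rough wiring sketch for _BigQueryReadSplit; the names are illustrative and
# the exact downstream read transform lives elsewhere (e.g. in
# apache_beam.io.gcp.bigquery):
#
#   sources = (
#       requests  # PCollection[ReadFromBigQueryRequest]
#       | beam.ParDo(_BigQueryReadSplit(options=pipeline_options)))
#
# Each emitted element is a BoundedSource over one exported file: an Avro
# source by default, or a _TextSource with _JsonToDictCoder when
# use_json_exports=True.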

FieldSchema = collections.namedtuple('FieldSchema', 'fields mode name type')


class _JsonToDictCoder(coders.Coder):
  """A coder for a JSON string to a Python dict."""
  def __init__(self, table_schema):
    self.fields = self._convert_to_tuple(table_schema.fields)
    self._converters = {
        'INTEGER': int,
        'INT64': int,
        'FLOAT': float,
        'FLOAT64': float,
        'NUMERIC': self._to_decimal,
        'BYTES': self._to_bytes,
    }

  @staticmethod
  def _to_decimal(value):
    return decimal.Decimal(value)

  @staticmethod
  def _to_bytes(value):
    """Converts value from str to bytes."""
    return value.encode('utf-8')

  @classmethod
  def _convert_to_tuple(cls, table_field_schemas):
    """Recursively converts a list of TableFieldSchema instances into a list
    of FieldSchema tuples, to prevent errors when pickling and unpickling
    TableFieldSchema instances.
    """
    if not table_field_schemas:
      return []

    return [
        FieldSchema(cls._convert_to_tuple(x.fields), x.mode, x.name, x.type)
        for x in table_field_schemas
    ]

  def decode(self, value):
    value = json.loads(value.decode('utf-8'))
    return self._decode_row(value, self.fields)

  def _decode_row(self, row: Dict[str, Any], schema_fields: List[FieldSchema]):
    for field in schema_fields:
      if field.name not in row:
        # The field exists in the schema but not in this row. Its value was
        # most likely null, since the extract-to-JSON job does not preserve
        # null fields.
        row[field.name] = None
        continue

      if field.mode == 'REPEATED':
        for i, elem in enumerate(row[field.name]):
          row[field.name][i] = self._decode_data(elem, field)
      else:
        row[field.name] = self._decode_data(row[field.name], field)
    return row

  def _decode_data(self, obj: Any, field: FieldSchema):
    if not field.fields:
      try:
        return self._converters[field.type](obj)
      except KeyError:
        # No need to do any conversion
        return obj
    return self._decode_row(obj, field.fields)

  def is_deterministic(self):
    return True

  def to_type_hint(self):
    return dict
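

# A small illustration of _JsonToDictCoder; the schema below is hypothetical
# and only meant to show the decoding behaviour:
#
#   from apache_beam.io.gcp.internal.clients import bigquery
#
#   schema = bigquery.TableSchema(fields=[
#       bigquery.TableFieldSchema(name='id', type='INTEGER', mode='NULLABLE'),
#       bigquery.TableFieldSchema(name='tags', type='STRING', mode='REPEATED'),
#   ])
#   coder = _JsonToDictCoder(schema)
#   coder.decode(b'{"id": "3", "tags": ["a", "b"]}')
#   # -> {'id': 3, 'tags': ['a', 'b']}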