github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/sql_chain.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module for tracking a chain of beam_sql magics applied.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import importlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Union

import apache_beam as beam
from apache_beam.internal import pickler
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.transforms.sql import SqlTransform
from apache_beam.utils.interactive_utils import is_in_ipython

_LOGGER = logging.getLogger(__name__)


@dataclass
class SqlNode:
  """Each SqlNode represents one applied beam_sql magic.

  Attributes:
    output_name: the watched, unique name of the beam_sql output. Can be used
      as an identifier.
    source: the inputs consumed by this node. Can be a pipeline or a set of
      PCollections represented by the variable names under which they are
      watched. When it's a pipeline, the node computes from raw values in the
      query, so the output can be consumed by any SqlNode in any SqlChain.
    query: the SQL query applied by this node.
    schemas: the schemas (NamedTuple classes) used by this node.
    evaluated: the pipelines this node has been evaluated for.
    next: the next SqlNode applied chronologically.
    execution_count: the execution count, if in an IPython environment.
  """
  output_name: str
  source: Union[beam.Pipeline, Set[str]]
  query: str
  schemas: Optional[Set[Any]] = None
  evaluated: Optional[Set[beam.Pipeline]] = None
  next: Optional['SqlNode'] = None
  execution_count: int = 0

  def __post_init__(self):
    if not self.schemas:
      self.schemas = set()
    if not self.evaluated:
      self.evaluated = set()
    if is_in_ipython():
      from IPython import get_ipython
      self.execution_count = get_ipython().execution_count

  def __hash__(self):
    # A set-typed source is not hashable; hash a frozenset of it instead.
    source = self.source if isinstance(
        self.source, beam.Pipeline) else frozenset(self.source)
    return hash((self.output_name, source, self.query, self.execution_count))

  def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline:
    """Converts the chain into an executable pipeline."""
    if pipeline not in self.evaluated:
      # The whole chain should form a single pipeline.
      source = self.source
      if isinstance(self.source, beam.Pipeline):
        if pipeline:  # use the known pipeline
          source = pipeline
        else:  # use the source pipeline
          pipeline = self.source
      else:
        # The source is a set of variable names; resolve them into the watched
        # PCollections they refer to.
        name_to_pcoll = pcoll_by_name()
        if len(self.source) == 1:
          source = name_to_pcoll.get(next(iter(self.source)))
        else:
          source = {s: name_to_pcoll.get(s) for s in self.source}
      if isinstance(source, beam.Pipeline):
        # A query over raw values applies SqlTransform directly.
        output = source | 'beam_sql_{}_{}'.format(
            self.output_name, self.execution_count) >> SqlTransform(self.query)
      else:
        # A query over PCollections needs their schemas restored on workers.
        output = source | 'schema_loaded_beam_sql_{}_{}'.format(
            self.output_name, self.execution_count
        ) >> SchemaLoadedSqlTransform(
            self.output_name, self.query, self.schemas, self.execution_count)
      _ = create_var_in_main(self.output_name, output)
      self.evaluated.add(pipeline)
    if self.next:
      return self.next.to_pipeline(pipeline)
    return pipeline


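# A minimal, illustrative sketch of chaining SqlNodes by hand (the names
# 'base' and 'filtered' are hypothetical; in practice, nodes are created by
# the beam_sql magic and set-typed sources must be watched variables):
#
#   node = SqlNode(
#       output_name='base', source=beam.Pipeline(), query='SELECT 1 AS id')
#   node.next = SqlNode(
#       output_name='filtered',
#       source={'base'},
#       query='SELECT id FROM base WHERE id > 0')
#   pipeline_to_run = node.to_pipeline(pipeline=None)

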
class SchemaLoadedSqlTransform(beam.PTransform):
  """PTransform that loads schemas before executing SQL.

  When submitting a pipeline to a remote runner for execution, schemas defined
  in the main module are not available without save_main_session. However,
  save_main_session might fail when there is anything unpicklable. This
  transform makes sure that only the schemas needed are pickled locally and
  restored later on workers.
  """
  def __init__(self, output_name, query, schemas, execution_count):
    self.output_name = output_name
    self.query = query
    self.schemas = schemas
    self.execution_count = execution_count
    # TODO(BEAM-8123): clean up this attribute or the whole wrapper PTransform.
    # Dill does not preserve everything. On the other hand, save_main_session
    # is not stable. Until cloudpickle replaces dill in Beam, we work around
    # it by explicitly pickling annotations and loading schemas in remote main
    # sessions.
    self.schema_annotations = [s.__annotations__ for s in self.schemas]

  class _SqlTransformDoFn(beam.DoFn):
    """A DoFn that yields its input unchanged; its setup restores the
    needed schemas into the main session."""
    def __init__(self, schemas, annotations):
      self.pickled_schemas = [pickler.dumps(s) for s in schemas]
      self.pickled_annotations = [pickler.dumps(a) for a in annotations]

    def setup(self):
      main_session = importlib.import_module('__main__')
      for pickled_schema, pickled_annotation in zip(
          self.pickled_schemas, self.pickled_annotations):
        schema = pickler.loads(pickled_schema)
        schema.__annotations__ = pickler.loads(pickled_annotation)
        if not hasattr(main_session, schema.__name__) or not hasattr(
            getattr(main_session, schema.__name__), '__annotations__'):
          # Restore the schema in the main session on the [remote] worker.
          setattr(main_session, schema.__name__, schema)
        register_coder_for_schema(schema)

    def process(self, e):
      yield e

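  # A rough sketch of the pickling round trip performed by _SqlTransformDoFn
  # above ('Quote' stands for a hypothetical NamedTuple schema defined in
  # __main__):
  #
  #   pickled_schema = pickler.dumps(Quote)
  #   pickled_annotations = pickler.dumps(Quote.__annotations__)
  #   # ... later, on the worker ...
  #   schema = pickler.loads(pickled_schema)
  #   schema.__annotations__ = pickler.loads(pickled_annotations)
  #
  # The annotations travel separately because dill does not reliably preserve
  # them, and register_coder_for_schema relies on them to infer the schema.
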
  def expand(self, source):
    """Applies the SQL transform. If a PCollection uses a schema defined in
    the main session, prepends the additional DoFn to restore the schema on
    the worker."""
    if isinstance(source, dict):
      schema_loaded = {
          tag: pcoll | 'load_schemas_{}_tag_{}_{}'.format(
              self.output_name, tag, self.execution_count) >> beam.ParDo(
                  self._SqlTransformDoFn(self.schemas, self.schema_annotations))
          if pcoll.element_type in self.schemas else pcoll
          for tag, pcoll in source.items()
      }
    elif isinstance(source, beam.pvalue.PCollection):
      if source.element_type in self.schemas:
        schema_loaded = source | 'load_schemas_{}_{}'.format(
            self.output_name, self.execution_count) >> beam.ParDo(
                self._SqlTransformDoFn(self.schemas, self.schema_annotations))
      else:
        schema_loaded = source
    else:
      raise ValueError(
          '{} should be either a single PCollection or a dict of named '
          'PCollections.'.format(source))
    return schema_loaded | 'beam_sql_{}_{}'.format(
        self.output_name, self.execution_count) >> SqlTransform(self.query)


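# An illustrative, standalone sketch of applying SchemaLoadedSqlTransform
# (the 'Quote' schema and pipeline are hypothetical; the transform is normally
# applied by SqlNode.to_pipeline):
#
#   class Quote(typing.NamedTuple):
#     symbol: str
#     price: float
#
#   p = beam.Pipeline()
#   quotes = (
#       p | beam.Create([Quote('GOOG', 100.0)]).with_output_types(Quote))
#   _ = quotes | SchemaLoadedSqlTransform(
#       'quotes_out', 'SELECT symbol FROM PCOLLECTION', {Quote}, 1)

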
@dataclass
class SqlChain:
  """A chain of SqlNodes.

  Attributes:
    nodes: all nodes, keyed by their output_names.
    root: the first SqlNode applied chronologically.
    current: the last node applied.
    user_pipeline: the user-defined pipeline this chain originates from. If
      None, the whole chain just computes from raw values in queries.
      Otherwise, at least some of the nodes in the chain have queried against
      PCollections.
  """
  nodes: Optional[Dict[str, SqlNode]] = None
  root: Optional[SqlNode] = None
  current: Optional[SqlNode] = None
  user_pipeline: Optional[beam.Pipeline] = None

  def __post_init__(self):
    if not self.nodes:
      self.nodes = {}

  @progress_indicated
  def to_pipeline(self) -> beam.Pipeline:
    """Converts the chain into an executable Beam pipeline."""
    pipeline_to_execute = self.root.to_pipeline(self.user_pipeline)
    # The pipeline definitely contains an external transform: SqlTransform.
    pipeline_to_execute.contains_external_transforms = True
    return pipeline_to_execute

  def append(self, node: SqlNode) -> 'SqlChain':
    """Appends a node to the chain."""
    if self.current:
      self.current.next = node
    else:
      self.root = node
    self.current = node
    self.nodes[node.output_name] = node
    return self

  def get(self, output_name: str) -> Optional[SqlNode]:
    """Gets a node from the chain based on the given output_name."""
    return self.nodes.get(output_name, None)
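
# An illustrative sketch of chain bookkeeping (the nodes and queries are
# hypothetical; chains are normally built and appended to by the beam_sql
# magic):
#
#   chain = SqlChain()
#   chain.append(SqlNode('out_1', source=beam.Pipeline(), query='SELECT 1'))
#   chain.append(
#       SqlNode('out_2', source={'out_1'}, query='SELECT * FROM out_1'))
#   assert chain.get('out_2') is chain.current
#   pipeline_to_execute = chain.to_pipeline()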