github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/beam_sql_magics.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module of beam_sql cell magic that executes a Beam SQL.
    19  
    20  Only works within an IPython kernel.
    21  """
    22  
    23  import argparse
    24  import importlib
    25  import keyword
    26  import logging
    27  import traceback
    28  from typing import Dict
    29  from typing import List
    30  from typing import Optional
    31  from typing import Tuple
    32  from typing import Union
    33  
    34  import apache_beam as beam
    35  from apache_beam.pvalue import PValue
    36  from apache_beam.runners.interactive import interactive_environment as ie
    37  from apache_beam.runners.interactive.background_caching_job import has_source_to_cache
    38  from apache_beam.runners.interactive.caching.cacheable import CacheKey
    39  from apache_beam.runners.interactive.caching.reify import reify_to_cache
    40  from apache_beam.runners.interactive.caching.reify import unreify_from_cache
    41  from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
    42  from apache_beam.runners.interactive.sql.sql_chain import SqlChain
    43  from apache_beam.runners.interactive.sql.sql_chain import SqlNode
    44  from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
    45  from apache_beam.runners.interactive.sql.utils import find_pcolls
    46  from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
    47  from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
    48  from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
    49  from apache_beam.runners.interactive.utils import create_var_in_main
    50  from apache_beam.runners.interactive.utils import obfuscate
    51  from apache_beam.runners.interactive.utils import pcoll_by_name
    52  from apache_beam.runners.interactive.utils import progress_indicated
    53  from apache_beam.testing import test_stream
    54  from apache_beam.testing.test_stream_service import TestStreamServiceController
    55  from apache_beam.transforms.sql import SqlTransform
    56  from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
    57  from IPython.core.magic import Magics
    58  from IPython.core.magic import line_cell_magic
    59  from IPython.core.magic import magics_class
    60  
# Module-level logger used by every function in this module.
_LOGGER = logging.getLogger(__name__)

# Usage text passed to argparse.ArgumentParser(usage=...) in BeamSqlParser.
# NOTE(review): argparse %-formats the usage string before printing it, so the
# '%%' and '%%%%' prefixes below presumably render as the '%' (line magic) and
# '%%' (cell magic) forms respectively.
_EXAMPLE_USAGE = """beam_sql magic to execute Beam SQL in notebooks
---------------------------------------------------------
%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query
---------------------------------------------------------
Or
---------------------------------------------------------
%%%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query-line#1
query-line#2
...
query-line#N
---------------------------------------------------------
"""

# Warning logged by cache_output() when running the SQL pipeline fails. The two
# %s placeholders receive the formatted traceback and the pipeline's runner.
_NOT_SUPPORTED_MSG = """The query was valid and successfully applied.
    But beam_sql failed to execute the query: %s

    Runner used by beam_sql was %s.
    Some Beam features might have not been supported by the Python SDK and runner combination.
    Please check the runner output for more details about the failed items.

    In the meantime, you may check:
    https://beam.apache.org/documentation/runners/capability-matrix/
    to choose a runner other than the InteractiveRunner and explicitly apply SqlTransform
    to build Beam pipelines in a non-interactive manner.
"""

# Runners accepted by the -r/--runner option of beam_sql. Omitting the option
# or choosing DirectRunner runs the query locally. DataflowRunner only builds
# the pipeline and displays an options form for the user to fill in.
_SUPPORTED_RUNNERS = ['DirectRunner', 'DataflowRunner']
    90  
    91  
    92  class BeamSqlParser:
    93    """A parser to parse beam_sql inputs."""
    94    def __init__(self):
    95      self._parser = argparse.ArgumentParser(usage=_EXAMPLE_USAGE)
    96      self._parser.add_argument(
    97          '-o',
    98          '--output-name',
    99          dest='output_name',
   100          help=(
   101              'The output variable name of the magic, usually a PCollection. '
   102              'Auto-generated if omitted.'))
   103      self._parser.add_argument(
   104          '-v',
   105          '--verbose',
   106          action='store_true',
   107          help='Display more details about the magic execution.')
   108      self._parser.add_argument(
   109          '-r',
   110          '--runner',
   111          dest='runner',
   112          help=(
   113              'The runner to run the query. Supported runners are %s. If not '
   114              'provided, DirectRunner is used and results can be inspected '
   115              'locally.' % _SUPPORTED_RUNNERS))
   116      self._parser.add_argument(
   117          'query',
   118          type=str,
   119          nargs='*',
   120          help=(
   121              'The Beam SQL query to execute. '
   122              'Syntax: https://beam.apache.org/documentation/dsls/sql/calcite/'
   123              'query-syntax/. '
   124              'Please make sure that there is no conflict between your variable '
   125              'names and the SQL keywords, such as "SELECT", "FROM", "WHERE" and '
   126              'etc.'))
   127  
   128    def parse(self, args: List[str]) -> Optional[argparse.Namespace]:
   129      """Parses a list of string inputs.
   130  
   131      The parsed namespace contains these attributes:
   132        output_name: Optional[str], the output variable name.
   133        verbose: bool, whether to display more details of the magic execution.
   134        query: Optional[List[str]], the beam SQL query to execute.
   135  
   136      Returns:
   137        The parsed args or None if fail to parse.
   138      """
   139      try:
   140        return self._parser.parse_args(args)
   141      except KeyboardInterrupt:
   142        raise
   143      except:  # pylint: disable=bare-except
   144        # -h or --help results in SystemExit 0. Do not raise.
   145        return None
   146  
   147    def print_help(self) -> None:
   148      self._parser.print_help()
   149  
   150  
   151  def on_error(error_msg, *args):
   152    """Logs the error and the usage example."""
   153    _LOGGER.error(error_msg, *args)
   154    BeamSqlParser().print_help()
   155  
   156  
   157  @magics_class
   158  class BeamSqlMagics(Magics):
   159    def __init__(self, shell):
   160      super().__init__(shell)
   161      # Eagerly initializes the environment.
   162      _ = ie.current_env()
   163      self._parser = BeamSqlParser()
   164  
   165    @line_cell_magic
   166    def beam_sql(self, line: str, cell: Optional[str] = None) -> Optional[PValue]:
   167      """The beam_sql line/cell magic that executes a Beam SQL.
   168  
   169      Args:
   170        line: the string on the same line after the beam_sql magic.
   171        cell: everything else in the same notebook cell as a string. If None,
   172          beam_sql is used as line magic. Otherwise, cell magic.
   173  
   174      Returns None if running into an error or waiting for user input (running on
   175      a selected runner remotely), otherwise a PValue as if a SqlTransform is
   176      applied.
   177      """
   178      input_str = line
   179      if cell:
   180        input_str += ' ' + cell
   181      parsed = self._parser.parse(input_str.strip().split())
   182      if not parsed:
   183        # Failed to parse inputs, let the parser handle the exit.
   184        return
   185      output_name = parsed.output_name
   186      verbose = parsed.verbose
   187      query = parsed.query
   188      runner = parsed.runner
   189  
   190      if output_name and not output_name.isidentifier() or keyword.iskeyword(
   191          output_name):
   192        on_error(
   193            'The output_name "%s" is not a valid identifier. Please supply a '
   194            'valid identifier that is not a Python keyword.',
   195            line)
   196        return
   197      if not query:
   198        on_error('Please supply the SQL query to be executed.')
   199        return
   200      if runner and runner not in _SUPPORTED_RUNNERS:
   201        on_error(
   202            'Runner "%s" is not supported. Supported runners are %s.',
   203            runner,
   204            _SUPPORTED_RUNNERS)
   205        return
   206      query = ' '.join(query)
   207  
   208      found = find_pcolls(query, pcoll_by_name(), verbose=verbose)
   209      schemas = set()
   210      main_session = importlib.import_module('__main__')
   211      for _, pcoll in found.items():
   212        if not match_is_named_tuple(pcoll.element_type):
   213          on_error(
   214              'PCollection %s of type %s is not a NamedTuple. See '
   215              'https://beam.apache.org/documentation/programming-guide/#schemas '
   216              'for more details.',
   217              pcoll,
   218              pcoll.element_type)
   219          return
   220        register_coder_for_schema(pcoll.element_type, verbose=verbose)
   221        # Only care about schemas defined by the user in the main module.
   222        if hasattr(main_session, pcoll.element_type.__name__):
   223          schemas.add(pcoll.element_type)
   224  
   225      if runner in ('DirectRunner', None):
   226        collect_data_for_local_run(query, found)
   227        output_name, output, chain = apply_sql(query, output_name, found)
   228        chain.current.schemas = schemas
   229        cache_output(output_name, output)
   230        return output
   231  
   232      output_name, current_node, chain = apply_sql(
   233          query, output_name, found, False)
   234      current_node.schemas = schemas
   235      # TODO(BEAM-10708): Move the options setup and result handling to a
   236      # separate module when more runners are supported.
   237      if runner == 'DataflowRunner':
   238        _ = chain.to_pipeline()
   239        _ = DataflowOptionsForm(
   240            output_name, pcoll_by_name()[output_name],
   241            verbose).display_for_input()
   242        return None
   243      else:
   244        raise ValueError('Unsupported runner %s.', runner)
   245  
   246  
   247  @progress_indicated
   248  def collect_data_for_local_run(query: str, found: Dict[str, beam.PCollection]):
   249    from apache_beam.runners.interactive import interactive_beam as ib
   250    for name, pcoll in found.items():
   251      try:
   252        _ = ib.collect(pcoll)
   253      except (KeyboardInterrupt, SystemExit):
   254        raise
   255      except:  # pylint: disable=bare-except
   256        _LOGGER.error(
   257            'Cannot collect data for PCollection %s. Please make sure the '
   258            'PCollections queried in the sql "%s" are all from a single '
   259            'pipeline using an InteractiveRunner. Make sure there is no '
   260            'ambiguity, for example, same named PCollections from multiple '
   261            'pipelines or notebook re-executions.',
   262            name,
   263            query)
   264        raise
   265  
   266  
   267  @progress_indicated
   268  def apply_sql(
   269      query: str,
   270      output_name: Optional[str],
   271      found: Dict[str, beam.PCollection],
   272      run: bool = True) -> Tuple[str, Union[PValue, SqlNode], SqlChain]:
   273    """Applies a SqlTransform with the given sql and queried PCollections.
   274  
   275    Args:
   276      query: The SQL query executed in the magic.
   277      output_name: (optional) The output variable name in __main__ module.
   278      found: The PCollections with variable names found to be used in the query.
   279      run: Whether to prepare the SQL pipeline for a local run or not.
   280  
   281    Returns:
   282      A tuple of values. First str value is the output variable name in
   283      __main__ module, auto-generated if not provided. Second value: if run,
   284      it's a PValue; otherwise, a SqlNode tracks the SQL without applying it or
   285      executing it. Third value: SqlChain is a chain of SqlNodes that have been
   286      applied.
   287    """
   288    output_name = _generate_output_name(output_name, query, found)
   289    query, sql_source, chain = _build_query_components(
   290        query, found, output_name, run)
   291    if run:
   292      try:
   293        output = sql_source | SqlTransform(query)
   294        # Declare a variable with the output_name and output value in the
   295        # __main__ module so that the user can use the output smoothly.
   296        output_name, output = create_var_in_main(output_name, output)
   297        _LOGGER.info(
   298            "The output PCollection variable is %s with element_type %s",
   299            output_name,
   300            pformat_namedtuple(output.element_type))
   301        return output_name, output, chain
   302      except (KeyboardInterrupt, SystemExit):
   303        raise
   304      except:  # pylint: disable=bare-except
   305        on_error('Error when applying the Beam SQL: %s', traceback.format_exc())
   306        raise
   307    else:
   308      return output_name, chain.current, chain
   309  
   310  
   311  def pcolls_from_streaming_cache(
   312      user_pipeline: beam.Pipeline,
   313      query_pipeline: beam.Pipeline,
   314      name_to_pcoll: Dict[str, beam.PCollection]) -> Dict[str, beam.PCollection]:
   315    """Reads PCollection cache through the TestStream.
   316  
   317    Args:
   318      user_pipeline: The beam.Pipeline object defined by the user in the
   319          notebook.
   320      query_pipeline: The beam.Pipeline object built by the magic to execute the
   321          SQL query.
   322      name_to_pcoll: PCollections with variable names used in the SQL query.
   323  
   324    Returns:
   325      A Dict[str, beam.PCollection], where each PCollection is tagged with
   326      their PCollection variable names, read from the cache.
   327  
   328    When the user_pipeline has unbounded sources, we force all cache reads to go
   329    through the TestStream even if they are bounded sources.
   330    """
   331    def exception_handler(e):
   332      _LOGGER.error(str(e))
   333      return True
   334  
   335    cache_manager = ie.current_env().get_cache_manager(
   336        user_pipeline, create_if_absent=True)
   337    test_stream_service = ie.current_env().get_test_stream_service_controller(
   338        user_pipeline)
   339    if not test_stream_service:
   340      test_stream_service = TestStreamServiceController(
   341          cache_manager, exception_handler=exception_handler)
   342      test_stream_service.start()
   343      ie.current_env().set_test_stream_service_controller(
   344          user_pipeline, test_stream_service)
   345  
   346    tag_to_name = {}
   347    for name, pcoll in name_to_pcoll.items():
   348      key = CacheKey.from_pcoll(name, pcoll).to_str()
   349      tag_to_name[key] = name
   350    output_pcolls = query_pipeline | test_stream.TestStream(
   351        output_tags=set(tag_to_name.keys()),
   352        coder=cache_manager._default_pcoder,
   353        endpoint=test_stream_service.endpoint)
   354    sql_source = {}
   355    for tag, output in output_pcolls.items():
   356      name = tag_to_name[tag]
   357      # Must mark the element_type to avoid introducing pickled Python coder
   358      # to the Java expansion service.
   359      output.element_type = name_to_pcoll[name].element_type
   360      sql_source[name] = output
   361    return sql_source
   362  
   363  
   364  def _generate_output_name(
   365      output_name: Optional[str], query: str,
   366      found: Dict[str, beam.PCollection]) -> str:
   367    """Generates a unique output name if None is provided.
   368  
   369    Otherwise, returns the given output name directly.
   370    The generated output name is sql_output_{uuid} where uuid is an obfuscated
   371    value from the query and PCollections found to be used in the query.
   372    """
   373    if not output_name:
   374      execution_id = obfuscate(query, found)[:12]
   375      output_name = 'sql_output_' + execution_id
   376    return output_name
   377  
   378  
   379  def _build_query_components(
   380      query: str,
   381      found: Dict[str, beam.PCollection],
   382      output_name: str,
   383      run: bool = True
   384  ) -> Tuple[str,
   385             Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
   386             SqlChain]:
   387    """Builds necessary components needed to apply the SqlTransform.
   388  
   389    Args:
   390      query: The SQL query to be executed by the magic.
   391      found: The PCollections with variable names found to be used by the query.
   392      output_name: The output variable name in __main__ module.
   393      run: Whether to prepare components for a local run or not.
   394  
   395    Returns:
   396      The processed query to be executed by the magic; a source to apply the
   397      SqlTransform to: a dictionary of tagged PCollections, or a single
   398      PCollection, or the pipeline to execute the query; the chain of applied
   399      beam_sql magics this one belongs to.
   400    """
   401    if found:
   402      user_pipeline = ie.current_env().user_pipeline(
   403          next(iter(found.values())).pipeline)
   404      sql_pipeline = beam.Pipeline(options=user_pipeline._options)
   405      ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
   406      sql_source = {}
   407      if run:
   408        if has_source_to_cache(user_pipeline):
   409          sql_source = pcolls_from_streaming_cache(
   410              user_pipeline, sql_pipeline, found)
   411        else:
   412          cache_manager = ie.current_env().get_cache_manager(
   413              user_pipeline, create_if_absent=True)
   414          for pcoll_name, pcoll in found.items():
   415            cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
   416            sql_source[pcoll_name] = unreify_from_cache(
   417                pipeline=sql_pipeline,
   418                cache_key=cache_key,
   419                cache_manager=cache_manager,
   420                element_type=pcoll.element_type)
   421      else:
   422        sql_source = found
   423      if len(sql_source) == 1:
   424        query = replace_single_pcoll_token(query, next(iter(sql_source.keys())))
   425        sql_source = next(iter(sql_source.values()))
   426  
   427      node = SqlNode(
   428          output_name=output_name, source=set(found.keys()), query=query)
   429      chain = ie.current_env().get_sql_chain(
   430          user_pipeline, set_user_pipeline=True).append(node)
   431    else:  # does not query any existing PCollection
   432      sql_source = beam.Pipeline()
   433      ie.current_env().add_user_pipeline(sql_source)
   434  
   435      # The node should be the root node of the chain created below.
   436      node = SqlNode(output_name=output_name, source=sql_source, query=query)
   437      chain = ie.current_env().get_sql_chain(sql_source).append(node)
   438    return query, sql_source, chain
   439  
   440  
   441  @progress_indicated
   442  def cache_output(output_name: str, output: PValue) -> None:
   443    user_pipeline = ie.current_env().user_pipeline(output.pipeline)
   444    if user_pipeline:
   445      cache_manager = ie.current_env().get_cache_manager(
   446          user_pipeline, create_if_absent=True)
   447    else:
   448      _LOGGER.warning(
   449          'Something is wrong with %s. Cannot introspect its data.', output)
   450      return
   451    key = CacheKey.from_pcoll(output_name, output).to_str()
   452    _ = reify_to_cache(pcoll=output, cache_key=key, cache_manager=cache_manager)
   453    try:
   454      output.pipeline.run().wait_until_finish()
   455    except (KeyboardInterrupt, SystemExit):
   456      raise
   457    except:  # pylint: disable=bare-except
   458      _LOGGER.warning(
   459          _NOT_SUPPORTED_MSG, traceback.format_exc(), output.pipeline.runner)
   460      return
   461    ie.current_env().mark_pcollection_computed([output])
   462    visualize_computed_pcoll(
   463        output_name, output, max_n=float('inf'), max_duration_secs=float('inf'))
   464  
   465  
   466  def load_ipython_extension(ipython):
   467    """Marks this module as an IPython extension.
   468  
   469    To load this magic in an IPython environment, execute:
   470    %load_ext apache_beam.runners.interactive.sql.beam_sql_magics.
   471    """
   472    ipython.register_magics(BeamSqlMagics)