github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/utils.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #  http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Utilities to be used in  Interactive Beam.
    19  """
    20  
    21  import functools
    22  import hashlib
    23  import importlib
    24  import json
    25  import logging
    26  from typing import Any
    27  from typing import Dict
    28  from typing import Tuple
    29  
    30  import pandas as pd
    31  
    32  import apache_beam as beam
    33  from apache_beam.dataframe.convert import to_pcollection
    34  from apache_beam.dataframe.frame_base import DeferredBase
    35  from apache_beam.internal.gcp import auth
    36  from apache_beam.internal.http_client import get_new_http
    37  from apache_beam.io.gcp.internal.clients import storage
    38  from apache_beam.options.pipeline_options import PipelineOptions
    39  from apache_beam.pipeline import Pipeline
    40  from apache_beam.portability.api import beam_runner_api_pb2
    41  from apache_beam.runners.interactive.caching.cacheable import Cacheable
    42  from apache_beam.runners.interactive.caching.cacheable import CacheKey
    43  from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
    44  from apache_beam.testing.test_stream import WindowedValueHolder
    45  from apache_beam.typehints.schemas import named_fields_from_element_type
    46  
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Add line breaks to the IPythonLogHandler's HTML output.
# This style snippet is displayed by IPythonLogHandler.emit ahead of each
# rendered log message; `white-space: pre-line` keeps the newline characters
# of a message visible inside the alert <div>.
_INTERACTIVE_LOG_STYLE = """
  <style>
    div.alert {
      white-space: pre-line;
    }
  </style>
"""
    57  
    58  
    59  def to_element_list(
    60      reader,  # type: Generator[Union[beam_runner_api_pb2.TestStreamPayload.Event, WindowedValueHolder]] # noqa: F821
    61      coder,  # type: Coder # noqa: F821
    62      include_window_info,  # type: bool
    63      n=None,  # type: int
    64      include_time_events=False, # type: bool
    65  ):
    66    # type: (...) -> List[WindowedValue] # noqa: F821
    67  
    68    """Returns an iterator that properly decodes the elements from the reader.
    69    """
    70  
    71    # Defining a generator like this makes it easier to limit the count of
    72    # elements read. Otherwise, the count limit would need to be duplicated.
    73    def elements():
    74      for e in reader:
    75        if isinstance(e, beam_runner_api_pb2.TestStreamPayload.Event):
    76          if (e.HasField('watermark_event') or
    77              e.HasField('processing_time_event')):
    78            if include_time_events:
    79              yield e
    80          else:
    81            for tv in e.element_event.elements:
    82              decoded = coder.decode(tv.encoded_element)
    83              yield (
    84                  decoded.windowed_value
    85                  if include_window_info else decoded.windowed_value.value)
    86        elif isinstance(e, WindowedValueHolder):
    87          yield (
    88              e.windowed_value if include_window_info else e.windowed_value.value)
    89        else:
    90          yield e
    91  
    92    # Because we can yield multiple elements from a single TestStreamFileRecord,
    93    # we have to limit the count here to ensure that `n` is fulfilled.
    94    count = 0
    95    for e in elements():
    96      if n and count >= n:
    97        break
    98  
    99      yield e
   100  
   101      if not isinstance(e, beam_runner_api_pb2.TestStreamPayload.Event):
   102        count += 1
   103  
   104  
   105  def elements_to_df(elements, include_window_info=False, element_type=None):
   106    # type: (List[WindowedValue], bool, Any) -> DataFrame # noqa: F821
   107  
   108    """Parses the given elements into a Dataframe.
   109  
   110    If the elements are a list of WindowedValues, then it will break out the
   111    elements into their own DataFrame and return it. If include_window_info is
   112    True, then it will concatenate the windowing information onto the elements
   113    DataFrame.
   114    """
   115    try:
   116      columns_names = [
   117          name for name, _ in named_fields_from_element_type(element_type)
   118      ]
   119    except TypeError:
   120      columns_names = None
   121  
   122    rows = []
   123    windowed_info = []
   124    for e in elements:
   125      rows.append(e.value)
   126      if include_window_info:
   127        windowed_info.append([e.timestamp.micros, e.windows, e.pane_info])
   128  
   129    using_dataframes = isinstance(element_type, pd.DataFrame)
   130    using_series = isinstance(element_type, pd.Series)
   131    if using_dataframes or using_series:
   132      rows_df = pd.concat(rows)
   133    else:
   134      rows_df = pd.DataFrame(rows, columns=columns_names)
   135  
   136    if include_window_info and not using_series:
   137      windowed_info_df = pd.DataFrame(
   138          windowed_info, columns=['event_time', 'windows', 'pane_info'])
   139      final_df = pd.concat([rows_df, windowed_info_df], axis=1)
   140    else:
   141      final_df = rows_df
   142  
   143    return final_df
   144  
   145  
   146  def register_ipython_log_handler():
   147    # type: () -> None
   148  
   149    """Adds the IPython handler to a dummy parent logger (named
   150    'apache_beam.runners.interactive') of all interactive modules' loggers so that
   151    if is_in_notebook, logging displays the logs as HTML in frontends.
   152    """
   153  
   154    # apache_beam.runners.interactive is not a module, thus this "root" logger is
   155    # a dummy one created to hold the IPython log handler. When children loggers
   156    # have propagate as True (by default) and logging level as NOTSET (by default,
   157    # so the "root" logger's logging level takes effect), the IPython log handler
   158    # will be triggered at the "root"'s own logging level. And if a child logger
   159    # sets its logging level, it can take control back.
   160    interactive_root_logger = logging.getLogger('apache_beam.runners.interactive')
   161    if any(isinstance(h, IPythonLogHandler)
   162           for h in interactive_root_logger.handlers):
   163      return
   164    interactive_root_logger.setLevel(logging.INFO)
   165    interactive_root_logger.addHandler(IPythonLogHandler())
   166    # Disable the propagation so that logs emitted from interactive modules should
   167    # only be handled by loggers and handlers defined within interactive packages.
   168    interactive_root_logger.propagate = False
   169  
   170  
   171  class IPythonLogHandler(logging.Handler):
   172    """A logging handler to display logs as HTML in IPython backed frontends."""
   173    # TODO(BEAM-7923): Switch to Google hosted CDN once
   174    # https://code.google.com/archive/p/google-ajax-apis/issues/637 is resolved.
   175    log_template = """
   176              <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.4.1/css/bootstrap.min.css" integrity="sha384-Vkoo8x4CGsO3+Hhxv8T/Q5PaXtkKtu6ug5TOeNV6gBiFeWPGFN9MuhOf23Q9Ifjh" crossorigin="anonymous">
   177              <div class="alert alert-{level}">{msg}</div>"""
   178  
   179    logging_to_alert_level_map = {
   180        logging.CRITICAL: 'danger',
   181        logging.ERROR: 'danger',
   182        logging.WARNING: 'warning',
   183        logging.INFO: 'info',
   184        logging.DEBUG: 'dark',
   185        logging.NOTSET: 'light'
   186    }
   187  
   188    def emit(self, record):
   189      try:
   190        from html import escape
   191        from IPython.display import HTML
   192        from IPython.display import display
   193        display(HTML(_INTERACTIVE_LOG_STYLE))
   194        display(
   195            HTML(
   196                self.log_template.format(
   197                    level=self.logging_to_alert_level_map[record.levelno],
   198                    msg=escape(record.msg % record.args))))
   199      except ImportError:
   200        pass  # NOOP when dependencies are not available.
   201  
   202  
   203  def obfuscate(*inputs):
   204    # type: (*Any) -> str
   205  
   206    """Obfuscates any inputs into a hexadecimal string."""
   207    str_inputs = [str(input) for input in inputs]
   208    merged_inputs = '_'.join(str_inputs)
   209    return hashlib.md5(merged_inputs.encode('utf-8')).hexdigest()
   210  
   211  
   212  class ProgressIndicator(object):
   213    """An indicator visualizing code execution in progress."""
   214    # TODO(BEAM-7923): Switch to Google hosted CDN once
   215    # https://code.google.com/archive/p/google-ajax-apis/issues/637 is resolved.
   216    spinner_template = """
   217              <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.4.1/css/bootstrap.min.css" integrity="sha384-Vkoo8x4CGsO3+Hhxv8T/Q5PaXtkKtu6ug5TOeNV6gBiFeWPGFN9MuhOf23Q9Ifjh" crossorigin="anonymous">
   218              <div id="{id}">
   219                <div class="spinner-border text-info" role="status"></div>
   220                <span class="text-info">{text}</span>
   221              </div>
   222              """
   223    spinner_removal_template = """
   224              $("#{id}").remove();"""
   225  
   226    def __init__(self, enter_text, exit_text):
   227      # type: (str, str) -> None
   228  
   229      self._id = 'progress_indicator_{}'.format(obfuscate(id(self)))
   230      self._enter_text = enter_text
   231      self._exit_text = exit_text
   232  
   233    def __enter__(self):
   234      try:
   235        from IPython.display import HTML
   236        from IPython.display import display
   237        from apache_beam.runners.interactive import interactive_environment as ie
   238        if ie.current_env().is_in_notebook:
   239          display(
   240              HTML(
   241                  self.spinner_template.format(
   242                      id=self._id, text=self._enter_text)))
   243        else:
   244          display(self._enter_text)
   245      except ImportError as e:
   246        _LOGGER.error(
   247            'Please use interactive Beam features in an IPython'
   248            'or notebook environment: %s' % e)
   249  
   250    def __exit__(self, exc_type, exc_value, traceback):
   251      try:
   252        from IPython.display import Javascript
   253        from IPython.display import display
   254        from IPython.display import display_javascript
   255        from apache_beam.runners.interactive import interactive_environment as ie
   256        if ie.current_env().is_in_notebook:
   257          script = self.spinner_removal_template.format(id=self._id)
   258          display_javascript(
   259              Javascript(
   260                  ie._JQUERY_WITH_DATATABLE_TEMPLATE.format(
   261                      customized_script=script)))
   262        else:
   263          display(self._exit_text)
   264      except ImportError as e:
   265        _LOGGER.error(
   266            'Please use interactive Beam features in an IPython'
   267            'or notebook environment: %s' % e)
   268  
   269  
   270  def progress_indicated(func):
   271    # type: (Callable[..., Any]) -> Callable[..., Any] # noqa: F821
   272  
   273    """A decorator using a unique progress indicator as a context manager to
   274    execute the given function within."""
   275    @functools.wraps(func)
   276    def run_within_progress_indicator(*args, **kwargs):
   277      with ProgressIndicator(f'Processing... {func.__name__}', 'Done.'):
   278        return func(*args, **kwargs)
   279  
   280    return run_within_progress_indicator
   281  
   282  
   283  def as_json(func):
   284    # type: (Callable[..., Any]) -> Callable[..., str] # noqa: F821
   285  
   286    """A decorator convert python objects returned by callables to json
   287    string.
   288  
   289    The decorated function should always return an object parsable by json.dumps.
   290    If the object is not parsable, the str() of original object is returned
   291    instead.
   292    """
   293    def return_as_json(*args, **kwargs):
   294      try:
   295        return_value = func(*args, **kwargs)
   296        return json.dumps(return_value)
   297      except TypeError:
   298        return str(return_value)
   299  
   300    return return_as_json
   301  
   302  
   303  def deferred_df_to_pcollection(df):
   304    assert isinstance(df, DeferredBase), '{} is not a DeferredBase'.format(df)
   305  
   306    # The proxy is used to output a DataFrame with the correct columns.
   307    #
   308    # TODO(https://github.com/apache/beam/issues/20577): Once type hints are
   309    # implemented for pandas, use those instead of the proxy.
   310    cache = ExpressionCache()
   311    cache.replace_with_cached(df._expr)
   312  
   313    proxy = df._expr.proxy()
   314    return to_pcollection(df, yield_elements='pandas', label=str(df._expr)), proxy
   315  
   316  
   317  def pcoll_by_name() -> Dict[str, beam.PCollection]:
   318    """Finds all PCollections by their variable names defined in the notebook."""
   319    from apache_beam.runners.interactive import interactive_environment as ie
   320  
   321    inspectables = ie.current_env().inspector_with_synthetic.inspectables
   322    pcolls = {}
   323    for _, inspectable in inspectables.items():
   324      metadata = inspectable['metadata']
   325      if metadata['type'] == 'pcollection':
   326        pcolls[metadata['name']] = inspectable['value']
   327    return pcolls
   328  
   329  
   330  def find_pcoll_name(pcoll: beam.PCollection) -> str:
   331    """Finds the variable name of a PCollection defined by the user.
   332  
   333    Returns None if not assigned to any variable.
   334    """
   335    from apache_beam.runners.interactive import interactive_environment as ie
   336  
   337    inspectables = ie.current_env().inspector.inspectables
   338    for _, inspectable in inspectables.items():
   339      if inspectable['value'] is pcoll:
   340        return inspectable['metadata']['name']
   341    return None
   342  
   343  
   344  def cacheables() -> Dict[CacheKey, Cacheable]:
   345    """Finds all Cacheables with their CacheKeys."""
   346    from apache_beam.runners.interactive import interactive_environment as ie
   347  
   348    inspectables = ie.current_env().inspector_with_synthetic.inspectables
   349    cacheables = {}
   350    for _, inspectable in inspectables.items():
   351      metadata = inspectable['metadata']
   352      if metadata['type'] == 'pcollection':
   353        cacheable = Cacheable.from_pcoll(metadata['name'], inspectable['value'])
   354        cacheables[cacheable.to_key()] = cacheable
   355    return cacheables
   356  
   357  
   358  def watch_sources(pipeline):
   359    """Watches the unbounded sources in the pipeline.
   360  
   361    Sources can output to a PCollection without a user variable reference. In
   362    this case the source is not cached. We still want to cache the data so we
   363    synthetically create a variable to the intermediate PCollection.
   364    """
   365    from apache_beam.pipeline import PipelineVisitor
   366    from apache_beam.runners.interactive import interactive_environment as ie
   367  
   368    retrieved_user_pipeline = ie.current_env().user_pipeline(pipeline)
   369    pcoll_to_name = {v: k for k, v in pcoll_by_name().items()}
   370  
   371    class CacheableUnboundedPCollectionVisitor(PipelineVisitor):
   372      def __init__(self):
   373        self.unbounded_pcolls = set()
   374  
   375      def enter_composite_transform(self, transform_node):
   376        self.visit_transform(transform_node)
   377  
   378      def visit_transform(self, transform_node):
   379        if isinstance(transform_node.transform,
   380                      tuple(ie.current_env().options.recordable_sources)):
   381          for pcoll in transform_node.outputs.values():
   382            # Only generate a synthetic var when it's not already watched. For
   383            # example, the user could have assigned the unbounded source output
   384            # to a variable, watching it again with a different variable name
   385            # creates ambiguity.
   386            if pcoll not in pcoll_to_name:
   387              ie.current_env().watch({'synthetic_var_' + str(id(pcoll)): pcoll})
   388  
   389    retrieved_user_pipeline.visit(CacheableUnboundedPCollectionVisitor())
   390  
   391  
   392  def has_unbounded_sources(pipeline):
   393    """Checks if a given pipeline has recordable sources."""
   394    return len(unbounded_sources(pipeline)) > 0
   395  
   396  
   397  def unbounded_sources(pipeline):
   398    """Returns a pipeline's recordable sources."""
   399    from apache_beam.pipeline import PipelineVisitor
   400    from apache_beam.runners.interactive import interactive_environment as ie
   401  
   402    class CheckUnboundednessVisitor(PipelineVisitor):
   403      """Visitor checks if there are any unbounded read sources in the Pipeline.
   404  
   405      Visitor visits all nodes and checks if it is an instance of recordable
   406      sources.
   407      """
   408      def __init__(self):
   409        self.unbounded_sources = []
   410  
   411      def enter_composite_transform(self, transform_node):
   412        self.visit_transform(transform_node)
   413  
   414      def visit_transform(self, transform_node):
   415        if isinstance(transform_node.transform,
   416                      tuple(ie.current_env().options.recordable_sources)):
   417          self.unbounded_sources.append(transform_node)
   418  
   419    v = CheckUnboundednessVisitor()
   420    pipeline.visit(v)
   421    return v.unbounded_sources
   422  
   423  
   424  def create_var_in_main(name: str,
   425                         value: Any,
   426                         watch: bool = True) -> Tuple[str, Any]:
   427    """Declares a variable in the main module.
   428  
   429    Args:
   430      name: the variable name in the main module.
   431      value: the value of the variable.
   432      watch: whether to watch it in the interactive environment.
   433    Returns:
   434      A 2-entry tuple of the variable name and value.
   435    """
   436    setattr(importlib.import_module('__main__'), name, value)
   437    if watch:
   438      from apache_beam.runners.interactive import interactive_environment as ie
   439      ie.current_env().watch({name: value})
   440    return name, value
   441  
   442  
   443  def assert_bucket_exists(bucket_name):
   444    # type: (str) -> None
   445  
   446    """Asserts whether the specified GCS bucket with the name
   447    bucket_name exists.
   448  
   449      Logs an error and raises a ValueError if the bucket does not exist.
   450  
   451      Logs a warning if the bucket cannot be verified to exist.
   452    """
   453    try:
   454      from apitools.base.py.exceptions import HttpError
   455      storage_client = storage.StorageV1(
   456          credentials=auth.get_service_credentials(PipelineOptions()),
   457          get_credentials=False,
   458          http=get_new_http(),
   459          response_encoding='utf8')
   460      request = storage.StorageBucketsGetRequest(bucket=bucket_name)
   461      storage_client.buckets.Get(request)
   462    except HttpError as e:
   463      if e.status_code == 404:
   464        _LOGGER.error('%s bucket does not exist!', bucket_name)
   465        raise ValueError('Invalid GCS bucket provided!')
   466      else:
   467        _LOGGER.warning(
   468            'HttpError - unable to verify whether bucket %s exists', bucket_name)
   469    except ImportError:
   470      _LOGGER.warning(
   471          'ImportError - unable to verify whether bucket %s exists', bucket_name)
   472  
   473  
   474  def detect_pipeline_runner(pipeline):
   475    if isinstance(pipeline, Pipeline):
   476      from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
   477      if isinstance(pipeline.runner, InteractiveRunner):
   478        pipeline_runner = pipeline.runner._underlying_runner
   479      else:
   480        pipeline_runner = pipeline.runner
   481    else:
   482      pipeline_runner = None
   483    return pipeline_runner