github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/display/pipeline_graph.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """For generating Beam pipeline graph in DOT representation.
    19  
    20  This module is experimental. No backwards-compatibility guarantees.
    21  """
    22  
    23  # pytype: skip-file
    24  
    25  import collections
    26  import logging
    27  import threading
    28  from typing import DefaultDict
    29  from typing import Dict
    30  from typing import Iterator
    31  from typing import List
    32  from typing import Tuple
    33  from typing import Union
    34  
    35  import pydot
    36  
    37  import apache_beam as beam
    38  from apache_beam.portability.api import beam_runner_api_pb2
    39  from apache_beam.runners.interactive import interactive_environment as ie
    40  from apache_beam.runners.interactive import pipeline_instrument as inst
    41  from apache_beam.runners.interactive.display import pipeline_graph_renderer
    42  
    43  # pylint does not understand context
    44  # pylint:disable=dangerous-default-value
    45  
    46  
    47  class PipelineGraph(object):
    48    """Creates a DOT representing the pipeline. Thread-safe. Runner agnostic."""
    49    def __init__(
    50        self,
    51        pipeline,  # type: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline]
    52        default_vertex_attrs={'shape': 'box'},
    53        default_edge_attrs=None,
    54        render_option=None):
    55      """Constructor of PipelineGraph.
    56  
    57      Examples:
    58        graph = pipeline_graph.PipelineGraph(pipeline_proto)
    59        graph.get_dot()
    60  
    61        or
    62  
    63        graph = pipeline_graph.PipelineGraph(pipeline)
    64        graph.get_dot()
    65  
    66      Args:
    67        pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
    68        default_vertex_attrs: (Dict[str, str]) a dict of default vertex attributes
    69        default_edge_attrs: (Dict[str, str]) a dict of default edge attributes
    70        render_option: (str) this parameter decides how the pipeline graph is
    71            rendered. See display.pipeline_graph_renderer for available options.
    72      """
    73      self._lock = threading.Lock()
    74      self._graph = None  # type: pydot.Dot
    75      self._pipeline_instrument = None
    76      if isinstance(pipeline, beam.Pipeline):
    77        self._pipeline_instrument = inst.PipelineInstrument(
    78            pipeline, pipeline._options)
    79        # The pre-process links user pipeline to runner pipeline through analysis
    80        # but without mutating runner pipeline.
    81        self._pipeline_instrument.preprocess()
    82  
    83      if isinstance(pipeline, beam_runner_api_pb2.Pipeline):
    84        self._pipeline_proto = pipeline
    85      elif isinstance(pipeline, beam.Pipeline):
    86        self._pipeline_proto = pipeline.to_runner_api()
    87      else:
    88        raise TypeError(
    89            'pipeline should either be a %s or %s, while %s is given' %
    90            (beam_runner_api_pb2.Pipeline, beam.Pipeline, type(pipeline)))
    91  
    92      # A dict from PCollection ID to a list of its consuming Transform IDs
    93      self._consumers = collections.defaultdict(
    94          list)  # type: DefaultDict[str, List[str]]
    95      # A dict from PCollection ID to its producing Transform ID
    96      self._producers = {}  # type: Dict[str, str]
    97  
    98      for transform_id, transform_proto in self._top_level_transforms():
    99        for pcoll_id in transform_proto.inputs.values():
   100          self._consumers[pcoll_id].append(transform_id)
   101        for pcoll_id in transform_proto.outputs.values():
   102          self._producers[pcoll_id] = transform_id
   103  
   104      default_vertex_attrs = default_vertex_attrs or {'shape': 'box'}
   105      if 'color' not in default_vertex_attrs:
   106        default_vertex_attrs['color'] = 'blue'
   107      if 'fontcolor' not in default_vertex_attrs:
   108        default_vertex_attrs['fontcolor'] = 'blue'
   109  
   110      vertex_dict, edge_dict = self._generate_graph_dicts()
   111      self._construct_graph(
   112          vertex_dict, edge_dict, default_vertex_attrs, default_edge_attrs)
   113  
   114      self._renderer = pipeline_graph_renderer.get_renderer(render_option)
   115  
   116    def get_dot(self):
   117      # type: () -> str
   118      return self._get_graph().to_string()
   119  
   120    def display_graph(self):
   121      """Displays the graph generated."""
   122      rendered_graph = self._renderer.render_pipeline_graph(self)
   123      if ie.current_env().is_in_notebook:
   124        try:
   125          from IPython import display
   126          display.display(display.HTML(rendered_graph))
   127        except ImportError:  # Unlikely to happen when is_in_notebook.
   128          logging.warning(
   129              'Failed to import IPython display module when current '
   130              'environment is in a notebook. Cannot display the '
   131              'pipeline graph.')
   132  
   133    def _top_level_transforms(self):
   134      # type: () -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]]
   135  
   136      """Yields all top level PTransforms (subtransforms of the root PTransform).
   137  
   138      Yields: (str, PTransform proto) ID, proto pair of top level PTransforms.
   139      """
   140      transforms = self._pipeline_proto.components.transforms
   141      for root_transform_id in self._pipeline_proto.root_transform_ids:
   142        root_transform_proto = transforms[root_transform_id]
   143        for top_level_transform_id in root_transform_proto.subtransforms:
   144          top_level_transform_proto = transforms[top_level_transform_id]
   145          yield top_level_transform_id, top_level_transform_proto
   146  
   147    def _decorate(self, value):
   148      """Decorates label-ish values used for rendering in dot language.
   149  
   150      Escapes special characters in the given str value for dot language. All
   151      PTransform unique names are escaped implicitly in this module when building
   152      dot representation. Otherwise, special characters will break the graph
   153      rendered or cause runtime errors.
   154      """
   155      # Replace py str literal `\\` which is `\` in dot with py str literal
   156      # `\\\\` which is `\\` in dot so that dot `\\` can be rendered as `\`. Then
   157      # replace `"` with `\\"` so that the dot generated will be `\"` and be
   158      # rendered as `"`.
   159      return '"{}"'.format(value.replace('\\', '\\\\').replace('"', '\\"'))
   160  
   161    def _generate_graph_dicts(self):
   162      """From pipeline_proto and other info, generate the graph.
   163  
   164      Returns:
   165        vertex_dict: (Dict[str, Dict[str, str]]) vertex mapped to attributes.
   166        edge_dict: (Dict[(str, str), Dict[str, str]]) vertex pair mapped to the
   167            edge's attribute.
   168      """
   169      transforms = self._pipeline_proto.components.transforms
   170  
   171      # A dict from vertex name (i.e. PCollection ID) to its attributes.
   172      vertex_dict = collections.defaultdict(dict)
   173      # A dict from vertex name pairs defining the edge (i.e. a pair of PTransform
   174      # IDs defining the PCollection) to its attributes.
   175      edge_dict = collections.defaultdict(dict)
   176  
   177      self._edge_to_vertex_pairs = collections.defaultdict(list)
   178  
   179      for _, transform in self._top_level_transforms():
   180        vertex_dict[self._decorate(transform.unique_name)] = {}
   181  
   182        for pcoll_id in transform.outputs.values():
   183          pcoll_node = None
   184          if self._pipeline_instrument:
   185            cacheable = self._pipeline_instrument.cacheables.get(pcoll_id)
   186            pcoll_node = cacheable.var if cacheable else None
   187          # If no PipelineInstrument is available or the PCollection is not
   188          # watched.
   189          if not pcoll_node:
   190            pcoll_node = 'pcoll%s' % (hash(pcoll_id) % 10000)
   191            vertex_dict[pcoll_node] = {
   192                'shape': 'circle',
   193                'label': '',  # The pcoll node has no name.
   194            }
   195          # There is PipelineInstrument and the PCollection is watched with an
   196          # assigned variable.
   197          else:
   198            vertex_dict[pcoll_node] = {'shape': 'circle'}
   199          if pcoll_id not in self._consumers:
   200            self._edge_to_vertex_pairs[pcoll_id].append(
   201                (self._decorate(transform.unique_name), pcoll_node))
   202            edge_dict[(self._decorate(transform.unique_name), pcoll_node)] = {}
   203          else:
   204            for consumer in self._consumers[pcoll_id]:
   205              producer_name = self._decorate(transform.unique_name)
   206              consumer_name = self._decorate(transforms[consumer].unique_name)
   207              self._edge_to_vertex_pairs[pcoll_id].append(
   208                  (producer_name, pcoll_node))
   209              edge_dict[(producer_name, pcoll_node)] = {}
   210              self._edge_to_vertex_pairs[pcoll_id].append(
   211                  (pcoll_node, consumer_name))
   212              edge_dict[(pcoll_node, consumer_name)] = {}
   213  
   214      return vertex_dict, edge_dict
   215  
   216    def _get_graph(self):
   217      """Returns pydot.Dot object for the pipeline graph.
   218  
   219      The purpose of this method is to avoid accessing the graph while it is
   220      updated. No one except for this method should be accessing _graph directly.
   221  
   222      Returns:
   223        (pydot.Dot)
   224      """
   225      with self._lock:
   226        return self._graph
   227  
   228    def _construct_graph(
   229        self, vertex_dict, edge_dict, default_vertex_attrs, default_edge_attrs):
   230      """Constructs the pydot.Dot object for the pipeline graph.
   231  
   232      Args:
   233        vertex_dict: (Dict[str, Dict[str, str]]) maps vertex names to attributes
   234        edge_dict: (Dict[(str, str), Dict[str, str]]) maps vertex name pairs to
   235            attributes
   236        default_vertex_attrs: (Dict[str, str]) a dict of attributes
   237        default_edge_attrs: (Dict[str, str]) a dict of attributes
   238      """
   239      with self._lock:
   240        self._graph = pydot.Dot()
   241  
   242        if default_vertex_attrs:
   243          self._graph.set_node_defaults(**default_vertex_attrs)
   244        if default_edge_attrs:
   245          self._graph.set_edge_defaults(**default_edge_attrs)
   246  
   247        self._vertex_refs = {}  # Maps vertex name to pydot.Node
   248        self._edge_refs = {}  # Maps vertex name pairs to pydot.Edge
   249  
   250        for vertex, vertex_attrs in vertex_dict.items():
   251          vertex_ref = pydot.Node(vertex, **vertex_attrs)
   252          self._vertex_refs[vertex] = vertex_ref
   253          self._graph.add_node(vertex_ref)
   254  
   255        for edge, edge_attrs in edge_dict.items():
   256          vertex_src = self._vertex_refs[edge[0]]
   257          vertex_dst = self._vertex_refs[edge[1]]
   258  
   259          edge_ref = pydot.Edge(vertex_src, vertex_dst, **edge_attrs)
   260          self._edge_refs[edge] = edge_ref
   261          self._graph.add_edge(edge_ref)
   262  
   263    def _update_graph(self, vertex_dict=None, edge_dict=None):
   264      """Updates the pydot.Dot object with the given attribute update
   265  
   266      Args:
   267        vertex_dict: (Dict[str, Dict[str, str]]) maps vertex names to attributes
   268        edge_dict: This should be
   269            Either (Dict[str, Dict[str, str]]) which maps edge names to attributes
   270            Or (Dict[(str, str), Dict[str, str]]) which maps vertex pairs to edge
   271            attributes
   272      """
   273      def set_attrs(ref, attrs):
   274        for attr_name, attr_val in attrs.items():
   275          ref.set(attr_name, attr_val)
   276  
   277      with self._lock:
   278        if vertex_dict:
   279          for vertex, vertex_attrs in vertex_dict.items():
   280            set_attrs(self._vertex_refs[vertex], vertex_attrs)
   281        if edge_dict:
   282          for edge, edge_attrs in edge_dict.items():
   283            if isinstance(edge, tuple):
   284              set_attrs(self._edge_refs[edge], edge_attrs)
   285            else:
   286              for vertex_pair in self._edge_to_vertex_pairs[edge]:
   287                set_attrs(self._edge_refs[vertex_pair], edge_attrs)