github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/render.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A portable "runner" that renders a beam graph.
    19  
    20  This runner can either render the graph to a (set of) output path(s), as
    21  designated by (possibly repeated) --render_output, or serve the pipeline as
    22  an interactive graph, if --render_port is set.
    23  
    24  In Python, this runner can be passed directly at pipeline construction, e.g.::
    25  
    26     with beam.Pipeline(runner=beam.runners.render.RenderRunner(), options=...)
    27  
For other languages, start this service by running::
    29  
    30    python -m apache_beam.runners.render --job_port=PORT ...
    31  
and then run your pipeline with the PortableRunner setting the job endpoint
    33  to `localhost:PORT`.
    34  
    35  If any `--render_output=path.ext` flags are passed, each submitted job will
    36  get written to the given output (overwriting any previously existing file).
    37  
    38  If `--render_port` is set to a non-negative value, a local http server will
    39  be started which allows for interactive exploration of the pipeline graph.
    40  
    41  As an alternative to starting a job server, a single pipeline can be rendered
    42  by passing a pipeline proto file to `--pipeline_proto`.  For example
    43  
    44    python -m apache_beam.runners.render  \\
    45        --pipeline_proto gs://<staging_location>/pipeline.pb  \\
    46        --render_output=/tmp/pipeline.svg
    47  
    48  Requires the graphviz dot executable to be available in the path.
    49  """
    50  
    51  import argparse
    52  import base64
    53  import collections
    54  import http.server
    55  import json
    56  import logging
    57  import os
    58  import re
    59  import subprocess
    60  import sys
    61  import tempfile
    62  import threading
    63  import time
    64  import urllib.parse
    65  
    66  from google.protobuf import json_format
    67  from google.protobuf import text_format  # type: ignore
    68  
    69  from apache_beam.options import pipeline_options
    70  from apache_beam.portability.api import beam_runner_api_pb2
    71  from apache_beam.runners import runner
    72  from apache_beam.runners.portability import local_job_service
    73  from apache_beam.runners.portability import local_job_service_main
    74  from apache_beam.runners.portability.fn_api_runner import translations
    75  
    76  try:
    77    from apache_beam.io.gcp import gcsio
    78  except ImportError:
    79    gcsio = None  # type: ignore
    80  
    81  # From the Beam site, circa November 2022.
    82  DEFAULT_EDGE_STYLE = 'color="#ff570b"'
    83  DEFAULT_TRANSFORM_STYLE = (
    84      'shape=rect style="rounded, filled" color="#ff570b" fillcolor="#fff6dd"')
    85  DEFAULT_HIGHLIGHT_STYLE = (
    86      'shape=rect style="rounded, filled" color="#ff570b" fillcolor="#ffdb97"')
    87  
    88  
    89  class RenderOptions(pipeline_options.PipelineOptions):
    90    """Rendering options."""
    91    @classmethod
    92    def _add_argparse_args(cls, parser):
    93      parser.add_argument(
    94          '--render_port',
    95          type=int,
    96          default=-1,
    97          help='The port at which to serve the graph. '
    98          'If 0, an unused port will be chosen. '
    99          'If -1, the server will not be started.')
   100      parser.add_argument(
   101          '--render_output',
   102          action='append',
   103          help='A path or paths to which to write rendered output. '
   104          'The output type will be deduced from the file extension.')
   105      parser.add_argument(
   106          '--render_leaf_composite_nodes',
   107          action='append',
   108          help='A set of regular expressions for transform names that should '
   109          'not be expanded.  For example, one could pass "\bRead.*" to indicate '
   110          'the inner structure of read nodes should not be expanded. '
   111          'If not given, defaults to the top-level nodes if interactively '
   112          'serving the graph and expanding all nodes otherwise.')
   113      parser.add_argument(
   114          '--render_edge_attributes',
   115          default='',
   116          help='Graphviz attributes to add to all edges.')
   117      parser.add_argument(
   118          '--render_node_attributes',
   119          default='',
   120          help='Graphviz attributes to add to all nodes.')
   121      parser.add_argument(
   122          '--render_highlight_attributes',
   123          default='',
   124          help='Graphviz attributes to add to all highlighted nodes.')
   125      parser.add_argument(
   126          '--log_proto',
   127          default=False,
   128          action='store_true',
   129          help='Set to also log input pipeline proto to stdout.')
   130      return parser
   131  
   132  
   133  class PipelineRenderer:
   134    def __init__(self, pipeline, options):
   135      self.pipeline = pipeline
   136      self.options = options
   137  
   138      # Drill down into any uninteresting, top-level transforms that contain
   139      # the whole pipeline (often added by the SDK).
   140      roots = self.pipeline.root_transform_ids
   141      while len(roots) == 1:
   142        root = self.pipeline.components.transforms[roots[0]]
   143        if not root.subtransforms:
   144          break
   145        roots = root.subtransforms
   146      self.roots = roots
   147  
   148      # Figure out at what point to stop rendering composite internals.
   149      if options.render_leaf_composite_nodes:
   150        is_leaf = lambda name: any(
   151            re.match(pattern, name)
   152            for patterns in options.render_leaf_composite_nodes
   153            for pattern in patterns.split(','))
   154        self.leaf_composites = set()
   155  
   156        def mark_leaves(transform_ids):
   157          for transform_id in transform_ids:
   158            if is_leaf(transform_id):
   159              self.leaf_composites.add(transform_id)
   160            else:
   161              mark_leaves(
   162                  self.pipeline.components.transforms[transform_id].subtransforms)
   163  
   164        mark_leaves(self.roots)
   165  
   166      elif options.render_port >= 0:
   167        # Start interactive with no unfolding.
   168        self.leaf_composites = set(self.roots)
   169      else:
   170        # For non-interactive, expand fully.
   171        self.leaf_composites = set()
   172  
   173      # Useful for attempting graph layout consistency.
   174      self.latest_positions = {}
   175      self.highlighted = []
   176  
   177    def update(self, toggle=None):
   178      if toggle:
   179        transform_id = toggle[0]
   180        self.highlighted = [transform_id]
   181        if transform_id in self.leaf_composites:
   182          transform = self.pipeline.components.transforms[transform_id]
   183          if transform.subtransforms:
   184            self.leaf_composites.remove(transform_id)
   185            for subtransform in transform.subtransforms:
   186              self.leaf_composites.add(subtransform)
   187              if transform_id in self.latest_positions:
   188                self.latest_positions[subtransform] = self.latest_positions[
   189                    transform_id]
   190        else:
   191          self.leaf_composites.add(transform_id)
   192  
   193    def style(self, transform_id):
   194      base = ' '.join(
   195          [DEFAULT_TRANSFORM_STYLE, self.options.render_node_attributes])
   196      if transform_id in self.highlighted:
   197        return ' '.join([
   198            base,
   199            DEFAULT_HIGHLIGHT_STYLE,
   200            self.options.render_highlight_attributes
   201        ])
   202      else:
   203        return base
   204  
   205    def to_dot(self):
   206      return '\n'.join(self.to_dot_iter())
   207  
   208    def to_dot_iter(self):
   209      yield 'digraph G {'
   210      # Defer drawing any edges until the end lest we declare nodes too early.
   211      edges_out = []
   212      for transform_id in self.roots:
   213        yield from self.transform_to_dot(
   214            transform_id, self.pcoll_leaf_consumers(), edges_out)
   215      yield from edges_out
   216      yield '}'
   217  
   218    def transform_to_dot(self, transform_id, pcoll_leaf_consumers, edges_out):
   219      transform = self.pipeline.components.transforms[transform_id]
   220      if self.is_leaf(transform_id):
   221        yield self.transform_node(transform_id)
   222        transform_inputs = set(transform.inputs.values())
   223        for name, output in transform.outputs.items():
   224          # For outputs that are also inputs, it's ambiguous whether they are
   225          # consumed as the outputs of this transform, or of the upstream
   226          # transform. Render the latter.
   227          if output in transform_inputs:
   228            continue
   229          output_label = name if len(transform.outputs) > 1 else ''
   230          for consumer, is_side_input in pcoll_leaf_consumers[output]:
   231            # Can't yield this here as the consumer might not be in this cluster.
   232            edge_style = 'dashed' if is_side_input else 'solid'
   233            edge_attributes = ' '.join([
   234                f'label="{output_label}" style={edge_style}',
   235                DEFAULT_EDGE_STYLE,
   236                self.options.render_edge_attributes
   237            ])
   238            edges_out.append(
   239                f'"{transform_id}" -> "{consumer}" [{edge_attributes}]')
   240      else:
   241        yield f'subgraph "cluster_{transform_id}" {{'
   242        yield self.transform_attributes(transform_id)
   243        for subtransform in transform.subtransforms:
   244          yield from self.transform_to_dot(
   245              subtransform, pcoll_leaf_consumers, edges_out)
   246        yield '}'
   247  
   248    def transform_node(self, transform_id):
   249      return f'"{transform_id}" [{self.transform_attributes(transform_id)}]'
   250  
   251    def transform_attributes(self, transform_id):
   252      transform = self.pipeline.components.transforms[transform_id]
   253      local_name = transform.unique_name.split('/')[-1]
   254      if transform_id in self.latest_positions:
   255        pos_str = f'pos="{self.latest_positions[transform_id]}"'
   256      else:
   257        pos_str = ''
   258      return (
   259          f'label="{local_name}" {self.style(transform_id)} '
   260          f'URL="javascript:click(\'{transform_id}\')" {pos_str}')
   261  
   262    def pcoll_leaf_consumers_iter(self, transform_id):
   263      transform = self.pipeline.components.transforms[transform_id]
   264      transform_inputs = set(transform.inputs.values())
   265      side_inputs = set(translations.side_inputs(transform).values())
   266      if self.is_leaf(transform_id):
   267        for pcoll in transform.inputs.values():
   268          yield pcoll, (transform_id, pcoll in side_inputs)
   269      for subtransform in transform.subtransforms:
   270        for pcoll, (consumer,
   271                    annotation) in self.pcoll_leaf_consumers_iter(subtransform):
   272          if self.is_leaf(transform_id):
   273            if pcoll not in transform_inputs:
   274              yield pcoll, (transform_id, annotation)
   275          else:
   276            yield pcoll, (consumer, annotation)
   277  
   278    def pcoll_leaf_consumers(self):
   279      result = collections.defaultdict(list)
   280      for transform_id in self.roots:
   281        for pcoll, consumer_info in self.pcoll_leaf_consumers_iter(transform_id):
   282          result[pcoll].append(consumer_info)
   283      return result
   284  
   285    def is_leaf(self, transform_id):
   286      return (
   287          transform_id in self.leaf_composites or
   288          not self.pipeline.components.transforms[transform_id].subtransforms)
   289  
   290    def info(self):
   291      if len(self.highlighted) != 1:
   292        return ''
   293      transform_id = self.highlighted[0]
   294      return f'<pre>{self.pipeline.components.transforms[transform_id]}</pre>'
   295  
   296    def layout_dot(self):
   297      layout = subprocess.run(['dot', '-Tdot'],
   298                              input=self.to_dot().encode('utf-8'),
   299                              capture_output=True,
   300                              check=True).stdout
   301  
   302      # Try to capture the positions for layout consistency.
   303      json_out = json.loads(
   304          subprocess.run(['dot', '-n2', '-Kneato', '-Tjson'],
   305                         input=layout,
   306                         capture_output=True,
   307                         check=True).stdout)
   308      for box in json_out['objects']:
   309        name = box.get('name', None)
   310        if name in self.pipeline.components.transforms:
   311          if 'pos' in box:
   312            self.latest_positions[name] = box['pos']
   313          elif 'bb' in box:
   314            x0, y0, x1, y1 = [float(r) for r in box['bb'].split(',')]
   315            self.latest_positions[name] = f'{(x0+x1)/2},{(y0+y1)/2}'
   316  
   317      return layout
   318  
   319    def page_callback_data(self, layout):
   320      svg = subprocess.run(['dot', '-Kneato', '-n2', '-Tsvg'],
   321                           input=layout,
   322                           capture_output=True,
   323                           check=True).stdout
   324      cmapx = subprocess.run(['dot', '-Kneato', '-n2', '-Tcmapx'],
   325                             input=layout,
   326                             capture_output=True,
   327                             check=True).stdout
   328  
   329      return {
   330          'src': 'data:image/svg+xml;base64,' +
   331          base64.b64encode(svg).decode('utf-8'),
   332          'cmapx': cmapx.decode('utf-8'),
   333          'info': self.info(),
   334      }
   335  
   336    def render_data(self):
   337      logging.info("Re-rendering pipeline...")
   338      layout = self.layout_dot()
   339      if self.options.render_output:
   340        for path in self.options.render_output:
   341          format = os.path.splitext(path)[-1][1:]
   342          result = subprocess.run(
   343              ['dot', '-Kneato', '-n2', '-T' + format, '-o', path],
   344              input=layout,
   345              check=False)
   346          if result.returncode:
   347            logging.error(
   348                "Failed render pipeline as %r: exit %s", path, result.returncode)
   349          else:
   350            logging.info("Rendered pipeline as %r", path)
   351      return self.page_callback_data(layout)
   352  
   353    def render_json(self):
   354      return json.dumps(self.render_data())
   355  
   356    def page(self):
   357      data = self.render_data()
   358      src = data['src']
   359      cmapx = data['cmapx']
   360      return """
   361          <html>
   362            <head>
   363            <script>
   364              function click(transform_id) {
   365                var xhttp = new XMLHttpRequest();
   366                xhttp.onreadystatechange = function() {
   367                  render_data = JSON.parse(this.responseText);
   368                  document.getElementById('image_map_holder').innerHTML =
   369                      render_data.cmapx;
   370                  document.getElementById('image_tag').src = render_data.src
   371                  document.getElementById('info').innerHTML = render_data.info
   372                };
   373                xhttp.open("GET", "render?toggle=" + transform_id, true);
   374                xhttp.send();
   375              }
   376  
   377            </script>
   378            </head>
   379            """ + f"""
   380            <body>
   381              Click on a composite transform to expand.
   382              <br>
   383              <img id='image_tag' src='{src}' usemap='#G'>
   384              <hr>
   385              <div id='info'></div>
   386              <div id='image_map_holder'>
   387              {cmapx}
   388              </div>
   389            </body>
   390          </html>
   391      """
   392  
   393  
   394  class RenderRunner(runner.PipelineRunner):
   395    # TODO(robertwb): Consider making this a runner wrapper, where live status
   396    # (such as counters, stage completion status, or possibly even PCollection
   397    # samples) queryable and/or displayed.  This could evolve into a full Beam
   398    # UI.
   399    def run_pipeline(self, pipeline_object, options, pipeline_proto=None):
   400      if not pipeline_proto:
   401        pipeline_proto = pipeline_object.to_runner_api()
   402      render_options = options.view_as(RenderOptions)
   403      if render_options.log_proto:
   404        logging.info(pipeline_proto)
   405      renderer = PipelineRenderer(pipeline_proto, render_options)
   406      renderer.page()
   407  
   408      if render_options.render_port >= 0:
   409        # TODO: If this gets more complex, we could consider taking on a
   410        # framework like Flask as a dependency.
   411        class RequestHandler(http.server.BaseHTTPRequestHandler):
   412          def do_GET(self):
   413            parts = urllib.parse.urlparse(self.path)
   414            args = urllib.parse.parse_qs(parts.query)
   415            renderer.update(**args)
   416  
   417            if parts.path == '/':
   418              response = renderer.page()
   419            elif parts.path == '/render':
   420              response = renderer.render_json()
   421            else:
   422              self.send_response(400)
   423              return
   424  
   425            self.send_response(200)
   426            self.send_header("Content-type", "text/html")
   427            self.end_headers()
   428            self.wfile.write(response.encode('utf-8'))
   429  
   430        server = http.server.HTTPServer(('localhost', render_options.render_port),
   431                                        RequestHandler)
   432        server_thread = threading.Thread(target=server.serve_forever, daemon=True)
   433        server_thread.start()
   434        print('Serving at http://%s:%s' % server.server_address)
   435        return RenderPipelineResult(server)
   436  
   437      else:
   438        return RenderPipelineResult(None)
   439  
   440  
   441  class RenderPipelineResult(runner.PipelineResult):
   442    def __init__(self, server):
   443      super().__init__(runner.PipelineState.RUNNING)
   444      self.server = server
   445  
   446    def wait_until_finish(self, duration=None):
   447      if self.server:
   448        time.sleep(duration or 1e8)
   449        self.server.shutdown()
   450      self._state = runner.PipelineState.DONE
   451  
   452    def monitoring_infos(self):
   453      return []
   454  
   455  
   456  def run(argv):
   457    if argv[0] == __file__:
   458      argv = argv[1:]
   459    parser = argparse.ArgumentParser(
   460        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
   461    parser.add_argument(
   462        '--job_port',
   463        type=int,
   464        default=0,
   465        help='port on which to serve the job api')
   466    parser.add_argument(
   467        '--pipeline_proto', help='file containing the beam pipeline definition')
   468    RenderOptions._add_argparse_args(parser)
   469    options = parser.parse_args(argv)
   470  
   471    if options.pipeline_proto:
   472      if not options.render_output and options.render_port < 0:
   473        options.render_port = 0
   474  
   475      render_one(options)
   476  
   477      if options.render_output:
   478        return
   479  
   480    run_server(options)
   481  
   482  
   483  def render_one(options):
   484    if options.pipeline_proto == '-':
   485      content = sys.stdin.buffer.read()
   486      if content[0] == b'{':
   487        ext = '.json'
   488      else:
   489        try:
   490          content.decode('utf-8')
   491          ext = '.textproto'
   492        except UnicodeDecodeError:
   493          ext = '.pb'
   494    else:
   495      if options.pipeline_proto.startswith('gs://'):
   496        if gcsio is None:
   497          raise ImportError('GCS not available; please install apache_beam[gcp]')
   498        open_fn = gcsio.GcsIO().open
   499      else:
   500        open_fn = open
   501  
   502      with open_fn(options.pipeline_proto, 'rb') as fin:
   503        content = fin.read()
   504      ext = os.path.splitext(options.pipeline_proto)[-1]
   505  
   506    if ext == '.textproto':
   507      pipeline_proto = text_format.Parse(content, beam_runner_api_pb2.Pipeline())
   508    elif ext == '.json':
   509      pipeline_proto = json_format.Parse(content, beam_runner_api_pb2.Pipeline())
   510    else:
   511      pipeline_proto = beam_runner_api_pb2.Pipeline()
   512      pipeline_proto.ParseFromString(content)
   513  
   514    RenderRunner().run_pipeline(
   515        None, pipeline_options.PipelineOptions(**vars(options)), pipeline_proto)
   516  
   517  
   518  def run_server(options):
   519    class RenderBeamJob(local_job_service.BeamJob):
   520      def _invoke_runner(self):
   521        return RenderRunner().run_pipeline(
   522            None,
   523            pipeline_options.PipelineOptions(**vars(options)),
   524            self._pipeline_proto)
   525  
   526    with tempfile.TemporaryDirectory() as staging_dir:
   527      job_servicer = local_job_service.LocalJobServicer(
   528          staging_dir, beam_job_type=RenderBeamJob)
   529      port = job_servicer.start_grpc_server(options.job_port)
   530      try:
   531        local_job_service_main.serve(
   532            "Listening for beam jobs on port %d." % port, job_servicer)
   533      finally:
   534        job_servicer.stop()
   535  
   536  
   537  if __name__ == '__main__':
   538    logging.basicConfig()
   539    logging.getLogger().setLevel(logging.INFO)
   540    run(sys.argv)