# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/display/interactive_pipeline_graph.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Helper to render pipeline graph in IPython when running interactively.

This module is experimental. No backwards-compatibility guarantees.
"""

# pytype: skip-file

import itertools
import re

from apache_beam.runners.interactive.display import pipeline_graph

# Matches any character outside printable ASCII (0x20 ' ' through 0x7E '~').
# 0x7F (DEL) is a control character, so it is deliberately excluded.
_NON_PRINTABLE_ASCII = re.compile(r'[^\x20-\x7E]')


def nice_str(o, max_len=35):
  """Returns a short string representation of ``o`` usable in a DOT label.

  The ``repr`` of ``o`` is post-processed so it can sit inside a
  double-quoted DOT attribute without escaping:

  * double quotes become single quotes,
  * backslashes become ``|``,
  * non-printable / non-ASCII characters become spaces,
  * anything longer than ``max_len`` characters is cut to ``max_len``
    characters followed by ``'...'``.

  Args:
    o: any object.
    max_len: (int) maximum number of characters kept before truncation.

  Returns:
    (str) the sanitized, possibly truncated representation.
  """
  s = repr(o)
  s = s.replace('"', "'")
  s = s.replace('\\', '|')
  s = _NON_PRINTABLE_ASCII.sub(' ', s)
  assert '"' not in s
  if len(s) > max_len:
    s = s[:max_len] + '...'
  return s


def format_sample(contents, count=1000):
  """Formats the first ``count`` elements of ``contents`` as a set literal.

  At most ``count + 1`` elements are pulled from ``contents``. The extra
  element only detects whether there were more than ``count``. This avoids
  copying a large sample just to display a few of its elements.

  Args:
    contents: (iterable) the elements to render.
    count: (int) maximum number of elements shown.

  Returns:
    (str) e.g. ``{1, 2, ...}``. Each element is rendered with ``nice_str``,
    and a trailing ``, ...`` marks that ``contents`` had more than ``count``
    elements.
  """
  head = list(itertools.islice(contents, count + 1))
  elems = ', '.join(nice_str(o) for o in head[:count])
  if len(head) > count:
    elems += ', ...'
  assert '"' not in elems
  return '{%s}' % elems


class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
  """Creates the DOT representation of an interactive pipeline. Thread-safe."""
  def __init__(
      self,
      pipeline,
      required_transforms=None,
      referenced_pcollections=None,
      cached_pcollections=None):
    """Constructor of InteractivePipelineGraph.

    Args:
      pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
      required_transforms: (list/set of str) ID of top level PTransforms that
          lead to visible results.
      referenced_pcollections: (list/set of str) ID of PCollections that are
          referenced by top level PTransforms executed (i.e.
          required_transforms)
      cached_pcollections: (set of str) a set of PCollection IDs of those whose
          cached results are used in the execution.
    """
    self._required_transforms = required_transforms or set()
    self._referenced_pcollections = referenced_pcollections or set()
    self._cached_pcollections = cached_pcollections or set()

    # Everything starts out gray; _generate_graph_update_dicts() then
    # recolors the required transforms and cached/referenced PCollections.
    super().__init__(
        pipeline=pipeline,
        default_vertex_attrs={
            'color': 'gray', 'fontcolor': 'gray'
        },
        default_edge_attrs={'color': 'gray'})

    transform_updates, pcollection_updates = self._generate_graph_update_dicts()
    self._update_graph(transform_updates, pcollection_updates)

  def update_pcollection_stats(self, pcollection_stats):
    """Updates PCollection stats.

    Args:
      pcollection_stats: (dict of dict) maps PCollection IDs to information
          about each PCollection. Only the field 'sample' is used. It should
          hold the PCollection's (sampled) result as a list. A PCollection
          whose sample is missing or empty is labeled '?'.
    """
    edge_dict = {}
    for pcoll_id, stats in pcollection_stats.items():
      attrs = {}
      pcoll_list = stats.get('sample')
      if pcoll_list:
        # One element on the edge label itself; up to ten in the tooltip.
        attrs['label'] = format_sample(pcoll_list, 1)
        attrs['labeltooltip'] = format_sample(pcoll_list, 10)
      else:
        attrs['label'] = '?'
      edge_dict[pcoll_id] = attrs

    self._update_graph(edge_dict=edge_dict)

  def _generate_graph_update_dicts(self):
    """Generate updates specific to interactive pipeline.

    Returns:
      vertex_dict: (Dict[str, Dict[str, str]]) maps vertex name (the unique
          name of a top level PTransform) to DOT attributes
      edge_dict: (Dict[str, Dict[str, str]]) maps edge name (a PCollection ID
          produced by a top level PTransform) to DOT attributes
    """
    transform_dict = {}  # maps PTransform unique names to properties
    pcoll_dict = {}  # maps PCollection IDs to properties

    for transform_id, transform_proto in self._top_level_transforms():
      transform_dict[transform_proto.unique_name] = {
          'required': transform_id in self._required_transforms
      }

      for pcoll_id in transform_proto.outputs.values():
        pcoll_dict[pcoll_id] = {
            'cached': pcoll_id in self._cached_pcollections,
            'referenced': pcoll_id in self._referenced_pcollections
        }

    def vertex_properties_to_attributes(vertex):
      """Converts PTransform properties to DOT vertex attributes."""
      attrs = {}
      # NOTE: transform_dict above only ever sets 'required', so the 'leaf'
      # branch is not reached from this method.
      if 'leaf' in vertex:
        attrs['style'] = 'invis'
      elif vertex.get('required'):
        attrs['color'] = 'blue'
        attrs['fontcolor'] = 'blue'
      else:
        attrs['color'] = 'grey'
      return attrs

    def edge_properties_to_attributes(edge):
      """Converts PCollection properties to DOT edge attributes."""
      attrs = {}
      # 'cached' takes precedence over 'referenced'.
      if edge.get('cached'):
        attrs['color'] = 'red'
      elif edge.get('referenced'):
        attrs['color'] = 'black'
      else:
        attrs['color'] = 'grey'
      return attrs

    vertex_dict = {  # maps vertex names to attributes
        transform_name: vertex_properties_to_attributes(transform_properties)
        for transform_name, transform_properties in transform_dict.items()
    }
    edge_dict = {  # maps edge names to attributes
        pcoll_id: edge_properties_to_attributes(pcoll_properties)
        for pcoll_id, pcoll_properties in pcoll_dict.items()
    }

    return vertex_dict, edge_dict