github.com/nevalang/neva@v0.23.1-0.20240507185603-7696a9bb8dda/internal/compiler/backend/dot/graphviz.go (about)

     1  package dot
     2  
     3  import (
     4  	"embed"
     5  	"fmt"
     6  	"html"
     7  	"io"
     8  	"strconv"
     9  	"strings"
    10  	"sync"
    11  	"text/template"
    12  
    13  	"github.com/nevalang/neva/internal/runtime/ir"
    14  )
    15  
    // tmplFS holds the *.tmpl files from this package's directory, embedded at
// build time. initTemplates parses all of them into a single template set,
// and Build executes the "graph.dot.tmpl" template from that set.
//
//go:embed *.tmpl
var tmplFS embed.FS
    18  
    // Port wraps an IR port address so the DOT backend can attach its own
// formatting helpers (Format, FormatName, FormatLabel) to it. Port values are
// used as map keys in Node.In and Node.Out.
type Port struct {
	ir.PortAddr
}
    22  
    // trimPortPath strips a single trailing "/in" or "/out" segment from a port
// path, yielding the path of the node that owns the port. A path with neither
// suffix is returned unchanged.
//
// Exactly one suffix is removed, the same way Port.Format derives the node half
// of an edge endpoint. Chaining two TrimSuffix calls would turn a path such as
// "main/out/in" into "main" while Port.Format emits "main/out", so the edge
// would point at a node that was never created.
func trimPortPath(path string) string {
	switch {
	case strings.HasSuffix(path, "/in"):
		return strings.TrimSuffix(path, "/in")
	case strings.HasSuffix(path, "/out"):
		return strings.TrimSuffix(path, "/out")
	default:
		return path
	}
}
    26  
    27  func (p Port) FormatName() string {
    28  	portStr := p.Port
    29  	switch {
    30  	case strings.HasSuffix(p.Path, "/in"):
    31  		portStr += "/in"
    32  	case strings.HasSuffix(p.Path, "/out"):
    33  		portStr += "/out"
    34  	}
    35  	if p.Idx != 0 {
    36  		portStr = fmt.Sprintf("%s/%d", portStr, p.Idx)
    37  	}
    38  	return strconv.Quote(portStr)
    39  }
    40  
    41  func (p Port) FormatLabel() string {
    42  	escapePort := html.EscapeString(p.Port)
    43  	if p.Idx != 0 {
    44  		return html.EscapeString(fmt.Sprintf("%s[%d]", p.Port, p.Idx))
    45  	}
    46  	return escapePort
    47  }
    48  
    49  func (p Port) Format() string {
    50  	path := p.Path
    51  	portStr := p.Port
    52  	switch {
    53  	case strings.HasSuffix(p.Path, "/in"):
    54  		path = path[:len(path)-3] // Trim /in
    55  		portStr += "/in"
    56  	case strings.HasSuffix(p.Path, "/out"):
    57  		path = path[:len(path)-4] // Trim /out
    58  		portStr += "/out"
    59  	}
    60  	if p.Idx != 0 {
    61  		portStr = fmt.Sprint(portStr, "/", p.Idx)
    62  	}
    63  	return fmt.Sprintf("%q:%q", path, portStr)
    64  }
    65  
    // Node is a leaf of the cluster tree built by ClusterBuilder.
//
// Name is the node's full path, i.e. a port path with its trailing "/in" or
// "/out" removed by trimPortPath. Extra is not written anywhere in this file;
// presumably the template reads it (TODO confirm). In and Out are the sets of
// ports seen on inserted edges whose paths end in "/in" and "/out"
// respectively. Both maps are allocated lazily and may be nil.
type Node struct {
	Name  string
	Extra string
	In    map[Port]struct{}
	Out   map[Port]struct{}
}
    72  
    73  func (n Node) Format() string {
    74  	return fmt.Sprintf("%q", n.Name)
    75  }
    76  
    77  func (n Node) FormatLabel() string {
    78  	i := strings.LastIndexByte(n.Name, '/')
    79  	if i == -1 {
    80  		return n.Name
    81  	}
    82  	return n.Name[i+1:]
    83  }
    84  
    // Edge is one connection recorded by ClusterBuilder.InsertEdge, from the
// sending port Send to the receiving port Recv.
type Edge struct {
	Send Port
	Recv Port
}
    89  
    // Cluster is one level of the tree of path prefixes built by ClusterBuilder.
// Presumably the template renders each one as a DOT subgraph (TODO confirm
// against graph.dot.tmpl).
//
// Index is taken from ClusterBuilder.nextId when the cluster is created. The
// root cluster, ClusterBuilder.Main, is created with the zero Index. Prefix is
// the slash-joined path from the root down to this cluster ("" for the root).
// Nodes holds the leaf Nodes directly inside this cluster and Clusters holds
// its child clusters, both keyed by their last path segment. Both maps are
// allocated lazily and may be nil.
type Cluster struct {
	Index    int
	Prefix   string
	Nodes    map[string]*Node
	Clusters map[string]*Cluster
}
    96  
    97  func (c *Cluster) getOrCreateClusterNode(b *ClusterBuilder, path string) *Node {
    98  	path = trimPortPath(path)
    99  	return c.getOrCreateClusterNodeRec(b, path, "", path)
   100  }
   101  
   102  func (c *Cluster) getOrCreateClusterNodeRec(b *ClusterBuilder, path, prefix, remaining string) *Node {
   103  	before, after, found := strings.Cut(remaining, "/")
   104  	if !found {
   105  		if c.Nodes == nil {
   106  			c.Nodes = map[string]*Node{}
   107  		}
   108  		n, ok := c.Nodes[before]
   109  		if ok {
   110  			return n
   111  		}
   112  		n = &Node{
   113  			Name: path,
   114  		}
   115  		c.Nodes[before] = n
   116  		return n
   117  	}
   118  	if prefix == "" {
   119  		prefix = before
   120  	} else {
   121  		prefix = prefix + "/" + before
   122  	}
   123  	next := c.Clusters[before]
   124  	if next == nil {
   125  		if c.Clusters == nil {
   126  			c.Clusters = map[string]*Cluster{}
   127  		}
   128  		next = &Cluster{Index: b.nextId, Prefix: prefix}
   129  		c.Clusters[before] = next
   130  		b.nextId++
   131  	}
   132  	return next.getOrCreateClusterNodeRec(b, path, prefix, after)
   133  }
   134  
   135  func (c *Cluster) Label() string {
   136  	i := strings.LastIndexByte(c.Prefix, '/')
   137  	if i == -1 {
   138  		return c.Prefix
   139  	}
   140  	return c.Prefix[i+1:]
   141  }
   142  
   // ClusterBuilder collects edges between IR ports and groups the ports' nodes
// into a tree of Clusters keyed by path segment. Build renders the result with
// the embedded "graph.dot.tmpl" template.
//
// Main is the root cluster, created by the first InsertEdge call. Edges holds
// every inserted edge in insertion order. nextId is the Index the next new
// cluster will receive. once, tmpl and err cache the parsed template set and
// any parse error across Build calls.
//
// The zero value is ready for InsertEdge. Because a ClusterBuilder contains a
// sync.Once, it must not be copied after first use; pass it by pointer.
type ClusterBuilder struct {
	Main  *Cluster
	Edges []Edge

	nextId int
	once   sync.Once
	tmpl   *template.Template
	err    error
}
   152  
   153  func (b *ClusterBuilder) initTemplates() {
   154  	b.tmpl, b.err = template.New("").ParseFS(tmplFS, "*.tmpl")
   155  }
   156  
   157  func (b *ClusterBuilder) insertClusterNode(addr ir.PortAddr) {
   158  	if b.Main == nil {
   159  		cluster := &Cluster{}
   160  		b.Main = cluster
   161  		b.nextId++
   162  	}
   163  	switch n := b.Main.getOrCreateClusterNode(b, addr.Path); {
   164  	case strings.HasSuffix(addr.Path, "/in"):
   165  		if n.In == nil {
   166  			n.In = map[Port]struct{}{}
   167  		}
   168  		n.In[Port{addr}] = struct{}{}
   169  	case strings.HasSuffix(addr.Path, "/out"):
   170  		if n.Out == nil {
   171  			n.Out = map[Port]struct{}{}
   172  		}
   173  		n.Out[Port{addr}] = struct{}{}
   174  	}
   175  }
   176  
   177  func (b *ClusterBuilder) InsertEdge(send, recv ir.PortAddr) {
   178  	b.insertClusterNode(send)
   179  	b.insertClusterNode(recv)
   180  	b.Edges = append(b.Edges, Edge{Send: Port{send}, Recv: Port{recv}})
   181  }
   182  
   183  func (b *ClusterBuilder) Build(w io.Writer) error {
   184  	if b.once.Do(b.initTemplates); b.err != nil {
   185  		return b.err
   186  	}
   187  	if err := b.tmpl.ExecuteTemplate(w, "graph.dot.tmpl", b); err != nil {
   188  		return err
   189  	}
   190  	return nil
   191  }