gorgonia.org/gorgonia@v0.9.17/encoding/dot/topgraph.go (about)

     1  package dot
     2  
     3  import (
     4  	"sort"
     5  
     6  	"gonum.org/v1/gonum/graph"
     7  	"gonum.org/v1/gonum/graph/encoding"
     8  	gonumDot "gonum.org/v1/gonum/graph/encoding/dot"
     9  	"gonum.org/v1/gonum/graph/simple"
    10  	"gorgonia.org/gorgonia"
    11  	internalEncoding "gorgonia.org/gorgonia/internal/encoding"
    12  )
    13  
    14  func subGraphs() map[internalEncoding.Group]subgrapher {
    15  	return map[internalEncoding.Group]subgrapher{
    16  		internalEncoding.ConstantCluster: constantSubGraph{
    17  			DirectedBuilder: simple.NewDirectedGraph(),
    18  			name:            "Constants",
    19  		},
    20  		internalEncoding.InputCluster: inputSubGraph{
    21  			DirectedBuilder: simple.NewDirectedGraph(),
    22  			name:            "Inputs",
    23  		},
    24  		internalEncoding.ExprGraphCluster: exprSubGraph{
    25  			DirectedBuilder: simple.NewDirectedGraph(),
    26  			name:            "ExprGraph",
    27  			subs:            make(map[internalEncoding.Group]operatorSubGraph),
    28  		},
    29  		internalEncoding.UndefinedCluster: exprSubGraph{
    30  			DirectedBuilder: simple.NewDirectedGraph(),
    31  			name:            "Undefined",
    32  		},
    33  	}
    34  
    35  }
    36  
    37  type attributer []encoding.Attribute
    38  
    39  func (a attributer) Attributes() []encoding.Attribute { return a }
    40  
    41  func sortedKeys(m map[internalEncoding.Group]subgrapher) (retVal internalEncoding.Groups) {
    42  	for k := range m {
    43  		retVal = append(retVal, k)
    44  	}
    45  	sort.Sort(retVal)
    46  	return retVal
    47  }
    48  
    49  func generateDotGraph(g *gorgonia.ExprGraph) (graph.Graph, error) {
    50  	dg := simple.NewDirectedGraph()
    51  	copyGraph(dg, g)
    52  	nodes := dg.Nodes()
    53  	subgraphs := subGraphs()
    54  
    55  	for nodes.Next() {
    56  		n := nodes.Node()
    57  		if _, ok := n.(internalEncoding.Grouper); ok {
    58  			groups := n.(internalEncoding.Grouper).Groups()
    59  			for _, group := range groups {
    60  				if subgrapher, ok := subgraphs[group]; ok {
    61  					subgrapher.(graph.DirectedBuilder).AddNode(n)
    62  				} else {
    63  					// check if we are in the ExprGraphCluster
    64  					var subgraph operatorSubGraph
    65  					subgraph = operatorSubGraph{
    66  						DirectedBuilder: simple.NewDirectedGraph(),
    67  						id:              group.ID,
    68  						name:            group.Name,
    69  					}
    70  					if groups.Have(internalEncoding.ExprGraphCluster) {
    71  						exprSubg := subgraphs[internalEncoding.ExprGraphCluster].(exprSubGraph)
    72  						var ok bool
    73  						if _, ok = exprSubg.subs[group]; ok {
    74  							subgraph = exprSubg.subs[group]
    75  						} else {
    76  							exprSubg.subs[group] = subgraph
    77  						}
    78  						subgraph.AddNode(n)
    79  						continue
    80  					}
    81  					subgraph.AddNode(n)
    82  					subgraphs[group] = subgraph
    83  				}
    84  			}
    85  		}
    86  	}
    87  	subs := make([]gonumDot.Graph, 0, len(subgraphs))
    88  	keys := sortedKeys(subgraphs)
    89  	for _, k := range keys {
    90  		subs = append(subs, subgraphs[k])
    91  	}
    92  	return dotGraph{
    93  		Directed: dg,
    94  		subs:     subs,
    95  	}, nil
    96  }