github.com/gopherd/gonum@v0.0.4/graph/path/spanning_tree.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package path
     6  
     7  import (
     8  	"container/heap"
     9  	"math"
    10  	"sort"
    11  
    12  	"github.com/gopherd/gonum/graph"
    13  	"github.com/gopherd/gonum/graph/simple"
    14  )
    15  
    16  // WeightedBuilder is a type that can add nodes and weighted edges.
    17  type WeightedBuilder interface {
    18  	AddNode(graph.Node)
    19  	SetWeightedEdge(graph.WeightedEdge)
    20  }
    21  
    22  // Prim generates a minimum spanning tree of g by greedy tree extension, placing
    23  // the result in the destination, dst. If the edge weights of g are distinct
    24  // it will be the unique minimum spanning tree of g. The destination is not cleared
    25  // first. The weight of the minimum spanning tree is returned. If g is not connected,
    26  // a minimum spanning forest will be constructed in dst and the sum of minimum
    27  // spanning tree weights will be returned.
    28  //
    29  // Nodes and Edges from g are used to construct dst, so if the Node and Edge
    30  // types used in g are pointer or reference-like, then the values will be shared
    31  // between the graphs.
    32  //
    33  // If dst has nodes that exist in g, Prim will panic.
    34  func Prim(dst WeightedBuilder, g graph.WeightedUndirected) float64 {
    35  	nodes := graph.NodesOf(g.Nodes())
    36  	if len(nodes) == 0 {
    37  		return 0
    38  	}
    39  
    40  	q := &primQueue{
    41  		indexOf: make(map[int64]int, len(nodes)-1),
    42  		nodes:   make([]simple.WeightedEdge, 0, len(nodes)-1),
    43  	}
    44  	dst.AddNode(nodes[0])
    45  	for _, u := range nodes[1:] {
    46  		dst.AddNode(u)
    47  		heap.Push(q, simple.WeightedEdge{F: u, W: math.Inf(1)})
    48  	}
    49  
    50  	u := nodes[0]
    51  	uid := u.ID()
    52  	for _, v := range graph.NodesOf(g.From(uid)) {
    53  		w, ok := g.Weight(uid, v.ID())
    54  		if !ok {
    55  			panic("prim: unexpected invalid weight")
    56  		}
    57  		q.update(v, u, w)
    58  	}
    59  
    60  	var w float64
    61  	for q.Len() > 0 {
    62  		e := heap.Pop(q).(simple.WeightedEdge)
    63  		if e.To() != nil && g.HasEdgeBetween(e.From().ID(), e.To().ID()) {
    64  			dst.SetWeightedEdge(g.WeightedEdge(e.From().ID(), e.To().ID()))
    65  			w += e.Weight()
    66  		}
    67  
    68  		u = e.From()
    69  		uid := u.ID()
    70  		for _, n := range graph.NodesOf(g.From(uid)) {
    71  			if key, ok := q.key(n); ok {
    72  				w, ok := g.Weight(uid, n.ID())
    73  				if !ok {
    74  					panic("prim: unexpected invalid weight")
    75  				}
    76  				if w < key {
    77  					q.update(n, u, w)
    78  				}
    79  			}
    80  		}
    81  	}
    82  	return w
    83  }
    84  
    85  // primQueue is a Prim's priority queue. The priority queue is a
    86  // queue of edge From nodes keyed on the minimum edge weight to
    87  // a node in the set of nodes already connected to the minimum
    88  // spanning forest.
    89  type primQueue struct {
    90  	indexOf map[int64]int
    91  	nodes   []simple.WeightedEdge
    92  }
    93  
    94  func (q *primQueue) Less(i, j int) bool {
    95  	return q.nodes[i].Weight() < q.nodes[j].Weight()
    96  }
    97  
    98  func (q *primQueue) Swap(i, j int) {
    99  	q.indexOf[q.nodes[i].From().ID()] = j
   100  	q.indexOf[q.nodes[j].From().ID()] = i
   101  	q.nodes[i], q.nodes[j] = q.nodes[j], q.nodes[i]
   102  }
   103  
   104  func (q *primQueue) Len() int {
   105  	return len(q.nodes)
   106  }
   107  
   108  func (q *primQueue) Push(x interface{}) {
   109  	n := x.(simple.WeightedEdge)
   110  	q.indexOf[n.From().ID()] = len(q.nodes)
   111  	q.nodes = append(q.nodes, n)
   112  }
   113  
   114  func (q *primQueue) Pop() interface{} {
   115  	n := q.nodes[len(q.nodes)-1]
   116  	q.nodes = q.nodes[:len(q.nodes)-1]
   117  	delete(q.indexOf, n.From().ID())
   118  	return n
   119  }
   120  
   121  // key returns the key for the node u and whether the node is
   122  // in the queue. If the node is not in the queue, key is returned
   123  // as +Inf.
   124  func (q *primQueue) key(u graph.Node) (key float64, ok bool) {
   125  	i, ok := q.indexOf[u.ID()]
   126  	if !ok {
   127  		return math.Inf(1), false
   128  	}
   129  	return q.nodes[i].Weight(), ok
   130  }
   131  
   132  // update updates u's position in the queue with the new closest
   133  // MST-connected neighbour, v, and the key weight between u and v.
   134  func (q *primQueue) update(u, v graph.Node, key float64) {
   135  	id := u.ID()
   136  	i, ok := q.indexOf[id]
   137  	if !ok {
   138  		return
   139  	}
   140  	q.nodes[i].T = v
   141  	q.nodes[i].W = key
   142  	heap.Fix(q, i)
   143  }
   144  
   145  // UndirectedWeightLister is an undirected graph that returns edge weights and
   146  // the set of edges in the graph.
   147  type UndirectedWeightLister interface {
   148  	graph.WeightedUndirected
   149  	WeightedEdges() graph.WeightedEdges
   150  }
   151  
   152  // Kruskal generates a minimum spanning tree of g by greedy tree coalescence, placing
   153  // the result in the destination, dst. If the edge weights of g are distinct
   154  // it will be the unique minimum spanning tree of g. The destination is not cleared
   155  // first. The weight of the minimum spanning tree is returned. If g is not connected,
   156  // a minimum spanning forest will be constructed in dst and the sum of minimum
   157  // spanning tree weights will be returned.
   158  //
   159  // Nodes and Edges from g are used to construct dst, so if the Node and Edge
   160  // types used in g are pointer or reference-like, then the values will be shared
   161  // between the graphs.
   162  //
   163  // If dst has nodes that exist in g, Kruskal will panic.
   164  func Kruskal(dst WeightedBuilder, g UndirectedWeightLister) float64 {
   165  	edges := graph.WeightedEdgesOf(g.WeightedEdges())
   166  	sort.Sort(byWeight(edges))
   167  
   168  	ds := make(djSet)
   169  	it := g.Nodes()
   170  	for it.Next() {
   171  		n := it.Node()
   172  		dst.AddNode(n)
   173  		ds.add(n.ID())
   174  	}
   175  
   176  	var w float64
   177  	for _, e := range edges {
   178  		if s1, s2 := ds.find(e.From().ID()), ds.find(e.To().ID()); s1 != s2 {
   179  			ds.union(s1, s2)
   180  			dst.SetWeightedEdge(g.WeightedEdge(e.From().ID(), e.To().ID()))
   181  			w += e.Weight()
   182  		}
   183  	}
   184  	return w
   185  }
   186  
   187  type byWeight []graph.WeightedEdge
   188  
   189  func (e byWeight) Len() int           { return len(e) }
   190  func (e byWeight) Less(i, j int) bool { return e[i].Weight() < e[j].Weight() }
   191  func (e byWeight) Swap(i, j int)      { e[i], e[j] = e[j], e[i] }