github.com/whoyao/protocol@v0.0.0-20230519045905-2d8ace718ca5/utils/graph.go (about)

     1  package utils
     2  
     3  import (
     4  	"container/heap"
     5  	"log"
     6  	"math"
     7  )
     8  
// GraphNodeProps is the payload stored in each Graph node. ID returns the key
// the graph uses to index the node; inserting props whose ID is already
// present replaces that node's props rather than adding a new node.
type GraphNodeProps[K comparable] interface {
	ID() K
}

// GraphEdgeProps is the payload stored on each Graph edge. Length is the edge
// weight that ShortestPath sums along a path (presumably expected to be
// non-negative, since ShortestPath is a Dijkstra-style search — confirm with
// callers).
type GraphEdgeProps interface {
	Length() int64
}

// Graph is a directed graph whose nodes are keyed by K and carry props N, and
// whose edges carry props E. Edges live in a dense adjacency matrix, so memory
// grows quadratically with the number of nodes.
type Graph[K comparable, N GraphNodeProps[K], E GraphEdgeProps] struct {
	// nodesByID maps a node's ID() to the node.
	nodesByID map[K]*GraphNode[N]
	// nodes holds nodes in insertion order; nodes[x].i == x.
	nodes     []*GraphNode[N]
	// edges[a][b] is the edge from nodes[a] to nodes[b], or nil if absent.
	edges     [][]*GraphEdge[N, E]
}
    22  
    23  func NewGraph[K comparable, N GraphNodeProps[K], E GraphEdgeProps]() *Graph[K, N, E] {
    24  	return &Graph[K, N, E]{
    25  		nodesByID: map[K]*GraphNode[N]{},
    26  	}
    27  }
    28  
    29  func (g *Graph[K, N, E]) Size() int {
    30  	return len(g.nodes)
    31  }
    32  
    33  func (g *Graph[K, N, E]) InsertNode(props N) {
    34  	if n, ok := g.nodesByID[props.ID()]; ok {
    35  		n.props = props
    36  		return
    37  	}
    38  
    39  	i := len(g.nodes)
    40  	n := &GraphNode[N]{
    41  		i:     i,
    42  		props: props,
    43  	}
    44  
    45  	g.nodes = append(g.nodes, n)
    46  	g.nodesByID[props.ID()] = n
    47  
    48  	for i := range g.edges {
    49  		g.edges[i] = append(g.edges[i], nil)
    50  	}
    51  	g.edges = append(g.edges, make([]*GraphEdge[N, E], len(g.nodes)))
    52  }
    53  
    54  func (g *Graph[K, N, E]) InsertEdge(src, dst K, props E) {
    55  	s := g.nodesByID[src]
    56  	d := g.nodesByID[dst]
    57  
    58  	g.edges[s.i][d.i] = &GraphEdge[N, E]{props}
    59  }
    60  
    61  func (g *Graph[K, N, E]) DeleteEdge(src, dst K) {
    62  	s := g.nodesByID[src]
    63  	d := g.nodesByID[dst]
    64  
    65  	g.edges[s.i][d.i] = nil
    66  }
    67  
    68  func (g *Graph[K, N, E]) Node(id K) N {
    69  	return g.nodesByID[id].props
    70  }
    71  
    72  func (g *Graph[K, N, E]) Edge(src, dst K) (p E, ok bool) {
    73  	s := g.nodesByID[src]
    74  	d := g.nodesByID[dst]
    75  
    76  	e := g.edges[s.i][d.i]
    77  	if e == nil {
    78  		return
    79  	}
    80  	return e.props, true
    81  }
    82  
    83  func (g *Graph[K, N, E]) OutEdges(src K) map[K]E {
    84  	s := g.nodesByID[src]
    85  
    86  	edges := make(map[K]E, len(g.nodes))
    87  	for i, e := range g.edges[s.i] {
    88  		if e != nil {
    89  			edges[g.nodes[i].props.ID()] = e.props
    90  		}
    91  	}
    92  	return edges
    93  }
    94  
    95  func (g *Graph[K, N, E]) InEdges(dst K) map[K]E {
    96  	d := g.nodesByID[dst]
    97  
    98  	edges := make(map[K]E, len(g.nodes))
    99  	for i, es := range g.edges {
   100  		if es[d.i] != nil {
   101  			edges[g.nodes[i].props.ID()] = es[d.i].props
   102  		}
   103  	}
   104  	return edges
   105  }
   106  
   107  func (g *Graph[K, N, E]) ShortestPath(src, dst K) ([]N, int64) {
   108  	paths := &graphPathMinHeap[N]{}
   109  	visited := map[*GraphNode[N]]*graphPath[N]{}
   110  
   111  	s := g.nodesByID[src]
   112  	d := g.nodesByID[dst]
   113  
   114  	path := &graphPath[N]{node: s}
   115  	heap.Push(paths, path)
   116  	visited[path.node] = path
   117  
   118  	for {
   119  		if paths.Len() == 0 {
   120  			return nil, 0
   121  		}
   122  
   123  		prev := heap.Pop(paths).(*graphPath[N])
   124  		for i, e := range g.edges[prev.node.i] {
   125  			if e == nil {
   126  				continue
   127  			}
   128  
   129  			path := &graphPath[N]{
   130  				prev:   prev,
   131  				node:   g.nodes[i],
   132  				length: prev.length + e.props.Length(),
   133  				num:    prev.num + 1,
   134  			}
   135  
   136  			if p, ok := visited[path.node]; ok && p.Less(path) {
   137  				continue
   138  			}
   139  			visited[path.node] = path
   140  
   141  			if path.node == d {
   142  				return path.Nodes(), path.length
   143  			}
   144  
   145  			heap.Push(paths, path)
   146  		}
   147  	}
   148  }
   149  
   150  func (g *Graph[K, N, E]) TopologicalSort() []N {
   151  	if g.Size() == 0 {
   152  		return nil
   153  	}
   154  
   155  	log.Println(len(g.nodes))
   156  	nodes := make([]N, 0, len(g.nodes))
   157  	acyclic := true
   158  
   159  	temporary := make(map[*GraphNode[N]]struct{}, len(g.nodes))
   160  	permanent := make(map[*GraphNode[N]]struct{}, len(g.nodes))
   161  
   162  	for _, n := range g.nodes {
   163  		if _, ok := permanent[n]; ok {
   164  			continue
   165  		}
   166  
   167  		g.traverseDepthFirst(n, func(n *GraphNode[N], next func()) {
   168  			if _, ok := permanent[n]; ok {
   169  				return
   170  			}
   171  			if _, ok := temporary[n]; ok {
   172  				acyclic = false
   173  				return
   174  			}
   175  
   176  			temporary[n] = struct{}{}
   177  
   178  			next()
   179  
   180  			delete(temporary, n)
   181  			permanent[n] = struct{}{}
   182  			nodes = append(nodes, n.props)
   183  		})
   184  	}
   185  
   186  	if !acyclic {
   187  		return nil
   188  	}
   189  
   190  	for i := 0; i < len(nodes)/2; i++ {
   191  		nodes[i], nodes[len(nodes)-1-i] = nodes[len(nodes)-1-i], nodes[i]
   192  	}
   193  	return nodes
   194  }
   195  
   196  func (g *Graph[K, N, E]) traverseDepthFirst(n *GraphNode[N], fn func(n *GraphNode[N], next func())) {
   197  	fn(n, func() {
   198  		for i, e := range g.edges[n.i] {
   199  			if e != nil {
   200  				g.traverseDepthFirst(g.nodes[i], fn)
   201  			}
   202  		}
   203  	})
   204  }
   205  
   206  type graphPath[T any] struct {
   207  	prev   *graphPath[T]
   208  	node   *GraphNode[T]
   209  	length int64
   210  	num    int
   211  }
   212  
   213  func (p *graphPath[T]) nodes(i int) []T {
   214  	if p.prev == nil {
   215  		return append(make([]T, 0, i), p.node.props)
   216  	} else {
   217  		return append(p.prev.nodes(i+1), p.node.props)
   218  	}
   219  }
   220  
   221  func (p *graphPath[T]) Nodes() []T {
   222  	return p.nodes(1)
   223  }
   224  
   225  func (p *graphPath[T]) Less(o *graphPath[T]) bool {
   226  	return (p.length == o.length && p.num < o.num) || p.length < o.length
   227  }
   228  
   229  type graphPathMinHeap[T any] []*graphPath[T]
   230  
   231  func (h *graphPathMinHeap[T]) Len() int {
   232  	return len(*h)
   233  }
   234  
   235  func (h *graphPathMinHeap[T]) Less(i, j int) bool {
   236  	return (*h)[i].Less((*h)[j])
   237  }
   238  
   239  func (h *graphPathMinHeap[T]) Swap(i, j int) {
   240  	(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
   241  }
   242  
   243  func (h *graphPathMinHeap[T]) Push(x any) {
   244  	*h = append(*h, x.(*graphPath[T]))
   245  }
   246  
   247  func (h *graphPathMinHeap[T]) Pop() any {
   248  	x := (*h)[len(*h)-1]
   249  	(*h)[len(*h)-1] = nil
   250  	*h = (*h)[:len(*h)-1]
   251  	return x
   252  }
   253  
   254  type GraphNode[T any] struct {
   255  	i     int
   256  	props T
   257  }
   258  
   259  type GraphEdge[N, E any] struct {
   260  	props E
   261  }
   262  
   263  const inf = int64(math.MaxInt64/2 - 1)
   264  
   265  func NewFlowGraph(n int64) FlowGraph {
   266  	cap := make([]int64, n*n)
   267  	cost := make([]int64, n*n)
   268  	return FlowGraph{n, cap, cost}
   269  }
   270  
   271  type FlowGraph struct {
   272  	n         int64
   273  	cap, cost []int64
   274  }
   275  
   276  func (g *FlowGraph) AddEdge(s, t, cap, cost int64) {
   277  	g.cap[s*g.n+t] = cap
   278  	g.cap[t*g.n+s] = cap
   279  	g.cost[s*g.n+t] = cost
   280  	g.cost[t*g.n+s] = cost
   281  }
   282  
// MinCostMaxFlow holds the working state of a min-cost max-flow computation
// over a dense FlowGraph. ComputeMaxFlow (re)initialises every field, so the
// zero value is ready to use; afterwards Flow reads the per-pair flow.
type MinCostMaxFlow struct {
	// found marks vertices settled during the current search pass.
	found           []bool
	// n is the number of vertices.
	n               int64
	// cap and cost are the graph's n*n row-major matrices (shared with the
	// FlowGraph, not copied); flow is the n*n matrix of flow pushed so far.
	cap, flow, cost []int64
	// Per vertex: prev is the predecessor on the last search tree, dist the
	// reduced distance (length n+1; slot n is a sentinel), pi the potential.
	prev, dist, pi  []int64
}
   289  
// search runs one dense O(n^2) Dijkstra pass from s over the residual graph,
// using costs reduced by the potentials pi. It fills dist with reduced
// distances and prev with each reached vertex's predecessor. It then adds
// dist into pi (capped at inf) for the next pass, and reports whether t was
// reached.
//
// The parameter s is reused as the "current vertex" of the Dijkstra loop.
func (f *MinCostMaxFlow) search(s, t int64) bool {
	for i := range f.found {
		f.found[i] = false
	}
	// Includes the sentinel slot dist[f.n], which is never assigned again.
	for i := range f.dist {
		f.dist[i] = inf
	}

	f.dist[s] = 0

	// f.n doubles as "no candidate": best starts there, and since dist[f.n]
	// is inf, best stays f.n when no unsettled vertex is reachable.
	for s != f.n {
		best := f.n
		f.found[s] = true

		for i := int64(0); i < f.n; i++ {
			if f.found[i] {
				continue
			}

			// Residual arc s->i that cancels existing flow on i->s; its
			// reduced cost uses -cost[i][s].
			// ORDER MATTERS: this is relaxed before the forward arc, and
			// ComputeMaxFlow later infers which arc was used from whether
			// flow[u][prev[u]] is non-zero. That is consistent when costs are
			// symmetric (AddEdge makes them so) and non-negative.
			// NOTE(review): confirm callers never pass negative costs.
			if f.flow[i*f.n+s] != 0 {
				val := f.dist[s] + f.pi[s] - f.pi[i] - f.cost[i*f.n+s]
				if f.dist[i] > val {
					f.dist[i] = val
					f.prev[i] = s
				}
			}

			// Forward residual arc s->i while spare capacity remains.
			if f.flow[s*f.n+i] < f.cap[s*f.n+i] {
				val := f.dist[s] + f.pi[s] - f.pi[i] + f.cost[s*f.n+i]
				if f.dist[i] > val {
					f.dist[i] = val
					f.prev[i] = s
				}
			}

			if f.dist[i] < f.dist[best] {
				best = i
			}
		}

		s = best
	}

	// Fold this pass's distances into the potentials. Unreached vertices
	// have dist == inf, so cap the sum to keep it bounded.
	for i := int64(0); i < f.n; i++ {
		pi := f.pi[i] + f.dist[i]
		if pi > inf {
			pi = inf
		}
		f.pi[i] = pi
	}

	return f.found[t]
}
   343  
   344  func (f *MinCostMaxFlow) Flow(s, t int64) int64 {
   345  	return f.flow[s*f.n+t]
   346  }
   347  
   348  func (f *MinCostMaxFlow) ComputeMaxFlow(g FlowGraph, s, t int64) (flow, cost int64) {
   349  	f.cap = g.cap
   350  	f.cost = g.cost
   351  	f.n = g.n
   352  
   353  	f.found = make([]bool, f.n)
   354  	f.flow = make([]int64, f.n*f.n)
   355  	f.dist = make([]int64, f.n+1)
   356  	f.prev = make([]int64, f.n)
   357  	f.pi = make([]int64, f.n)
   358  
   359  	for f.search(s, t) {
   360  		pathFlow := inf
   361  		for u := t; u != s; u = f.prev[u] {
   362  			var pf int64
   363  			if f.flow[u*f.n+f.prev[u]] != 0 {
   364  				pf = f.flow[u*f.n+f.prev[u]]
   365  			} else {
   366  				pf = f.cap[f.prev[u]*f.n+u] - f.flow[f.prev[u]*f.n+u]
   367  			}
   368  			if pf < pathFlow {
   369  				pathFlow = pf
   370  			}
   371  		}
   372  
   373  		for u := t; u != s; u = f.prev[u] {
   374  			if f.flow[u*f.n+f.prev[u]] != 0 {
   375  				f.flow[u*f.n+f.prev[u]] -= pathFlow
   376  				cost -= pathFlow * f.cost[u*f.n+f.prev[u]]
   377  			} else {
   378  				f.flow[f.prev[u]*f.n+u] += pathFlow
   379  				cost += pathFlow * f.cost[f.prev[u]*f.n+u]
   380  			}
   381  		}
   382  		flow += pathFlow
   383  	}
   384  
   385  	return flow, cost
   386  }