github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/graph/path/a_star.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package path
     6  
     7  import (
     8  	"container/heap"
     9  
    10  	"github.com/jingcheng-WU/gonum/graph"
    11  	"github.com/jingcheng-WU/gonum/graph/internal/set"
    12  	"github.com/jingcheng-WU/gonum/graph/traverse"
    13  )
    14  
    15  // AStar finds the A*-shortest path from s to t in g using the heuristic h. The path and
    16  // its cost are returned in a Shortest along with paths and costs to all nodes explored
    17  // during the search. The number of expanded nodes is also returned. This value may help
    18  // with heuristic tuning.
    19  //
    20  // The path will be the shortest path if the heuristic is admissible. A heuristic is
    21  // admissible if for any node, n, in the graph, the heuristic estimate of the cost of
    22  // the path from n to t is less than or equal to the true cost of that path.
    23  //
    24  // If h is nil, AStar will use the g.HeuristicCost method if g implements HeuristicCoster,
    25  // falling back to NullHeuristic otherwise. If the graph does not implement Weighted,
    26  // UniformCost is used. AStar will panic if g has an A*-reachable negative edge weight.
    27  func AStar(s, t graph.Node, g traverse.Graph, h Heuristic) (path Shortest, expanded int) {
    28  	if g, ok := g.(graph.Graph); ok {
    29  		if g.Node(s.ID()) == nil || g.Node(t.ID()) == nil {
    30  			return Shortest{from: s}, 0
    31  		}
    32  	}
    33  	var weight Weighting
    34  	if wg, ok := g.(Weighted); ok {
    35  		weight = wg.Weight
    36  	} else {
    37  		weight = UniformCost(g)
    38  	}
    39  	if h == nil {
    40  		if g, ok := g.(HeuristicCoster); ok {
    41  			h = g.HeuristicCost
    42  		} else {
    43  			h = NullHeuristic
    44  		}
    45  	}
    46  
    47  	path = newShortestFrom(s, []graph.Node{s, t})
    48  	tid := t.ID()
    49  
    50  	visited := make(set.Int64s)
    51  	open := &aStarQueue{indexOf: make(map[int64]int)}
    52  	heap.Push(open, aStarNode{node: s, gscore: 0, fscore: h(s, t)})
    53  
    54  	for open.Len() != 0 {
    55  		u := heap.Pop(open).(aStarNode)
    56  		uid := u.node.ID()
    57  		i := path.indexOf[uid]
    58  		expanded++
    59  
    60  		if uid == tid {
    61  			break
    62  		}
    63  
    64  		visited.Add(uid)
    65  		to := g.From(u.node.ID())
    66  		for to.Next() {
    67  			v := to.Node()
    68  			vid := v.ID()
    69  			if visited.Has(vid) {
    70  				continue
    71  			}
    72  			j, ok := path.indexOf[vid]
    73  			if !ok {
    74  				j = path.add(v)
    75  			}
    76  
    77  			w, ok := weight(u.node.ID(), vid)
    78  			if !ok {
    79  				panic("path: A* unexpected invalid weight")
    80  			}
    81  			if w < 0 {
    82  				panic("path: A* negative edge weight")
    83  			}
    84  			g := u.gscore + w
    85  			if n, ok := open.node(vid); !ok {
    86  				path.set(j, g, i)
    87  				heap.Push(open, aStarNode{node: v, gscore: g, fscore: g + h(v, t)})
    88  			} else if g < n.gscore {
    89  				path.set(j, g, i)
    90  				open.update(vid, g, g+h(v, t))
    91  			}
    92  		}
    93  	}
    94  
    95  	return path, expanded
    96  }
    97  
    98  // NullHeuristic is an admissible, consistent heuristic that will not speed up computation.
    99  func NullHeuristic(_, _ graph.Node) float64 {
   100  	return 0
   101  }
   102  
   103  // aStarNode adds A* accounting to a graph.Node.
   104  type aStarNode struct {
   105  	node   graph.Node
   106  	gscore float64
   107  	fscore float64
   108  }
   109  
   110  // aStarQueue is an A* priority queue.
   111  type aStarQueue struct {
   112  	indexOf map[int64]int
   113  	nodes   []aStarNode
   114  }
   115  
   116  func (q *aStarQueue) Less(i, j int) bool {
   117  	return q.nodes[i].fscore < q.nodes[j].fscore
   118  }
   119  
   120  func (q *aStarQueue) Swap(i, j int) {
   121  	q.indexOf[q.nodes[i].node.ID()] = j
   122  	q.indexOf[q.nodes[j].node.ID()] = i
   123  	q.nodes[i], q.nodes[j] = q.nodes[j], q.nodes[i]
   124  }
   125  
   126  func (q *aStarQueue) Len() int {
   127  	return len(q.nodes)
   128  }
   129  
   130  func (q *aStarQueue) Push(x interface{}) {
   131  	n := x.(aStarNode)
   132  	q.indexOf[n.node.ID()] = len(q.nodes)
   133  	q.nodes = append(q.nodes, n)
   134  }
   135  
   136  func (q *aStarQueue) Pop() interface{} {
   137  	n := q.nodes[len(q.nodes)-1]
   138  	q.nodes = q.nodes[:len(q.nodes)-1]
   139  	delete(q.indexOf, n.node.ID())
   140  	return n
   141  }
   142  
   143  func (q *aStarQueue) update(id int64, g, f float64) {
   144  	i, ok := q.indexOf[id]
   145  	if !ok {
   146  		return
   147  	}
   148  	q.nodes[i].gscore = g
   149  	q.nodes[i].fscore = f
   150  	heap.Fix(q, i)
   151  }
   152  
   153  func (q *aStarQueue) node(id int64) (aStarNode, bool) {
   154  	loc, ok := q.indexOf[id]
   155  	if ok {
   156  		return q.nodes[loc], true
   157  	}
   158  	return aStarNode{}, false
   159  }