github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/query/shortest.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package query
    18  
    19  import (
    20  	"container/heap"
    21  	"context"
    22  	"math"
    23  	"sync"
    24  
    25  	"github.com/dgraph-io/dgraph/algo"
    26  	"github.com/dgraph-io/dgraph/protos/pb"
    27  	"github.com/dgraph-io/dgraph/types"
    28  	"github.com/dgraph-io/dgraph/types/facets"
    29  	"github.com/dgraph-io/dgraph/x"
    30  	"github.com/pkg/errors"
    31  )
    32  
// pathInfo is one step of a route built by runKShortestPaths: a node together
// with the edge that was followed to reach it.
type pathInfo struct {
	// uid of the node at this position in the path.
	uid   uint64
	// attr is the attribute of the edge leading into uid (empty for the source
	// node, which is created without one).
	attr  string
	// facet is the facet recorded for the edge leading into uid; may be nil.
	facet *pb.Facets
}
    38  
// route is a path from the source node together with its accumulated weight.
type route struct {
	// route lists the steps of the path, starting with the source node.
	route       []pathInfo
	// totalWeight is filled in with the item's cost when the path reaches the
	// destination in runKShortestPaths.
	totalWeight float64
}
    43  
// queueItem is an entry of the priorityQueue used by shortestPath and
// runKShortestPaths.
type queueItem struct {
	uid   uint64  // uid of the node.
	cost  float64 // cost of taking the path till this uid.
	hop   int     // number of hops taken to reach this node.
	// index is the item's current position in the priority queue. Push and Swap
	// keep it up to date, Pop resets it to -1, and shortestPath passes it to
	// heap.Fix after lowering an item's cost.
	index int
	path  route // used in k shortest path.
}
    51  
// pathPool recycles []pathInfo slices. runKShortestPaths takes a slice from the
// pool for every candidate path it builds and returns the slice of each popped
// item once its neighbours have been processed.
var pathPool = sync.Pool{
	// New hands out an empty slice; callers check the capacity and allocate a
	// bigger slice themselves when it is too small.
	New: func() interface{} {
		return []pathInfo{}
	},
}
    57  
// errStop is sent by expandOut on its result channel when there is no subgraph
// left to expand; the path-finding loops treat it as "stop expanding", not as a
// failure.
var errStop = errors.Errorf("STOP")

// errFacet is returned by getCost when an edge has no usable cost facet;
// expandOut reacts by ignoring that edge.
var errFacet = errors.Errorf("Skip the edge")
    60  
    61  type priorityQueue []*queueItem
    62  
    63  func (h priorityQueue) Len() int           { return len(h) }
    64  func (h priorityQueue) Less(i, j int) bool { return h[i].cost < h[j].cost }
    65  func (h priorityQueue) Swap(i, j int) {
    66  	h[i], h[j] = h[j], h[i]
    67  	h[i].index = i
    68  	h[j].index = j
    69  }
    70  func (h *priorityQueue) Push(x interface{}) {
    71  	n := len(*h)
    72  	item := x.(*queueItem)
    73  	item.index = n
    74  	*h = append(*h, item)
    75  }
    76  
    77  func (h *priorityQueue) Pop() interface{} {
    78  	old := *h
    79  	n := len(old)
    80  	x := old[n-1]
    81  	*h = old[0 : n-1]
    82  	x.index = -1
    83  	return x
    84  }
    85  
// mapItem holds the attribute, cost and facet associated with an edge.
// expandOut stores one per (from, to) pair in the adjacency map. When embedded
// in nodeInfo, cost instead holds the total cost of the best known path to the
// node.
type mapItem struct {
	// attr is the Attr of the subgraph that produced the edge.
	attr  string
	// cost is the value returned by getCost for the edge (1 when the subgraph
	// has no facets matrix entry for it).
	cost  float64
	// facet is the facet returned by getCost for the edge; may be nil.
	facet *pb.Facets
}
    91  
// nodeInfo is the per-node bookkeeping for Dijkstra's algorithm: shortestPath
// maintains a map from UID to nodeInfo. The embedded mapItem carries the best
// known total cost to the node and the attr/facet of the edge from parent.
type nodeInfo struct {
	mapItem
	// parent is the uid preceding this node on the best known path; it is 0 for
	// the source node.
	parent uint64
	// Pointer to the item in the heap. Used to update its priority when a
	// cheaper path to the node is found.
	node *queueItem
}
    99  
   100  func (sg *SubGraph) getCost(matrix, list int) (cost float64,
   101  	fcs *pb.Facets, rerr error) {
   102  
   103  	cost = 1.0
   104  	if len(sg.facetsMatrix) <= matrix {
   105  		return cost, fcs, rerr
   106  	}
   107  	fcsList := sg.facetsMatrix[matrix].FacetsList
   108  	if len(fcsList) <= list {
   109  		rerr = errFacet
   110  		return cost, fcs, rerr
   111  	}
   112  	fcs = fcsList[list]
   113  	if len(fcs.Facets) == 0 {
   114  		rerr = errFacet
   115  		return cost, fcs, rerr
   116  	}
   117  	if len(fcs.Facets) > 1 {
   118  		rerr = errors.Errorf("Expected 1 but got %d facets", len(fcs.Facets))
   119  		return cost, fcs, rerr
   120  	}
   121  	tv, err := facets.ValFor(fcs.Facets[0])
   122  	if err != nil {
   123  		return 0.0, nil, err
   124  	}
   125  	if tv.Tid == types.IntID {
   126  		cost = float64(tv.Value.(int64))
   127  	} else if tv.Tid == types.FloatID {
   128  		cost = float64(tv.Value.(float64))
   129  	} else {
   130  		rerr = errFacet
   131  	}
   132  	return cost, fcs, rerr
   133  }
   134  
   135  func (sg *SubGraph) expandOut(ctx context.Context,
   136  	adjacencyMap map[uint64]map[uint64]mapItem, next chan bool, rch chan error) {
   137  
   138  	var numEdges uint64
   139  	var exec []*SubGraph
   140  	var err error
   141  	in := []uint64{sg.Params.From}
   142  	sg.SrcUIDs = &pb.List{Uids: in}
   143  	sg.uidMatrix = []*pb.List{{Uids: in}}
   144  	sg.DestUIDs = sg.SrcUIDs
   145  
   146  	for _, child := range sg.Children {
   147  		child.SrcUIDs = sg.DestUIDs
   148  		exec = append(exec, child)
   149  	}
   150  	dummy := &SubGraph{}
   151  	for {
   152  		isNext := <-next
   153  		if !isNext {
   154  			return
   155  		}
   156  		rrch := make(chan error, len(exec))
   157  		for _, subgraph := range exec {
   158  			go ProcessGraph(ctx, subgraph, dummy, rrch)
   159  		}
   160  
   161  		for range exec {
   162  			select {
   163  			case err = <-rrch:
   164  				if err != nil {
   165  					rch <- err
   166  					return
   167  				}
   168  			case <-ctx.Done():
   169  				rch <- ctx.Err()
   170  				return
   171  			}
   172  		}
   173  
   174  		for _, subgraph := range exec {
   175  			select {
   176  			case <-ctx.Done():
   177  				rch <- ctx.Err()
   178  				return
   179  			default:
   180  				if subgraph.UnknownAttr {
   181  					continue
   182  				}
   183  
   184  				// Send the destuids in res chan.
   185  				for mIdx, fromUID := range subgraph.SrcUIDs.Uids {
   186  					// This can happen when trying to go traverse a predicate of type password
   187  					// for example.
   188  					if mIdx >= len(subgraph.uidMatrix) {
   189  						continue
   190  					}
   191  
   192  					for lIdx, toUID := range subgraph.uidMatrix[mIdx].Uids {
   193  						if adjacencyMap[fromUID] == nil {
   194  							adjacencyMap[fromUID] = make(map[uint64]mapItem)
   195  						}
   196  						// The default cost we'd use is 1.
   197  						cost, facet, err := subgraph.getCost(mIdx, lIdx)
   198  						if err == errFacet {
   199  							// Ignore the edge and continue.
   200  							continue
   201  						} else if err != nil {
   202  							rch <- err
   203  							return
   204  						}
   205  						adjacencyMap[fromUID][toUID] = mapItem{
   206  							cost:  cost,
   207  							facet: facet,
   208  							attr:  subgraph.Attr,
   209  						}
   210  						numEdges++
   211  					}
   212  				}
   213  			}
   214  		}
   215  
   216  		if numEdges > x.Config.QueryEdgeLimit {
   217  			// If we've seen too many edges, stop the query.
   218  			rch <- errors.Errorf("Exceeded query edge limit = %v. Found %v edges.",
   219  				x.Config.QueryEdgeLimit, numEdges)
   220  			return
   221  		}
   222  
   223  		// modify the exec and attach child nodes.
   224  		var out []*SubGraph
   225  		for _, subgraph := range exec {
   226  			if len(subgraph.DestUIDs.Uids) == 0 {
   227  				continue
   228  			}
   229  			select {
   230  			case <-ctx.Done():
   231  				rch <- ctx.Err()
   232  				return
   233  			default:
   234  				for _, child := range sg.Children {
   235  					temp := new(SubGraph)
   236  					temp.copyFiltersRecurse(child)
   237  
   238  					temp.SrcUIDs = subgraph.DestUIDs
   239  					// Remove those nodes which we have already traversed. As this cannot be
   240  					// in the path again.
   241  					algo.ApplyFilter(temp.SrcUIDs, func(uid uint64, i int) bool {
   242  						_, ok := adjacencyMap[uid]
   243  						return !ok
   244  					})
   245  					subgraph.Children = append(subgraph.Children, temp)
   246  					out = append(out, temp)
   247  				}
   248  			}
   249  		}
   250  
   251  		if len(out) == 0 {
   252  			rch <- errStop
   253  			return
   254  		}
   255  		rch <- nil
   256  		exec = out
   257  	}
   258  }
   259  
   260  func (sg *SubGraph) copyFiltersRecurse(otherSubgraph *SubGraph) {
   261  	*sg = *otherSubgraph
   262  	sg.Children = []*SubGraph{}
   263  	sg.Filters = []*SubGraph{}
   264  	for _, fc := range otherSubgraph.Filters {
   265  		tempChild := new(SubGraph)
   266  		tempChild.copyFiltersRecurse(fc)
   267  		sg.Filters = append(sg.Filters, tempChild)
   268  	}
   269  }
   270  
   271  func runKShortestPaths(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) {
   272  	var err error
   273  	if sg.Params.Alias != "shortest" {
   274  		return nil, errors.Errorf("Invalid shortest path query")
   275  	}
   276  
   277  	numPaths := sg.Params.numPaths
   278  	var kroutes []route
   279  	pq := make(priorityQueue, 0)
   280  	heap.Init(&pq)
   281  
   282  	// Initialize and push the source node.
   283  	srcNode := &queueItem{
   284  		uid:  sg.Params.From,
   285  		cost: 0,
   286  		hop:  0,
   287  		path: route{route: []pathInfo{{uid: sg.Params.From}}},
   288  	}
   289  	heap.Push(&pq, srcNode)
   290  
   291  	numHops := -1
   292  	maxHops := int(sg.Params.ExploreDepth)
   293  	if maxHops == 0 {
   294  		maxHops = int(math.MaxInt32)
   295  	}
   296  	minWeight := sg.Params.MinWeight
   297  	maxWeight := sg.Params.MaxWeight
   298  	next := make(chan bool, 2)
   299  	expandErr := make(chan error, 2)
   300  	adjacencyMap := make(map[uint64]map[uint64]mapItem)
   301  	go sg.expandOut(ctx, adjacencyMap, next, expandErr)
   302  
   303  	// In k shortest path we can't have this. We store the path till a node in every
   304  	// node.
   305  	// map to store the min cost and parent of nodes.
   306  	var stopExpansion bool
   307  	for pq.Len() > 0 {
   308  		item := heap.Pop(&pq).(*queueItem)
   309  		if item.uid == sg.Params.To {
   310  			// Ignore paths that do not meet the minimum weight requirement.
   311  			if item.cost < minWeight {
   312  				continue
   313  			}
   314  
   315  			// Add path to list.
   316  			newRoute := item.path
   317  			newRoute.totalWeight = item.cost
   318  			kroutes = append(kroutes, newRoute)
   319  			if len(kroutes) == numPaths {
   320  				// We found the required number of paths.
   321  				break
   322  			}
   323  		}
   324  		if item.hop > numHops && numHops < maxHops {
   325  			// Explore the next level by calling processGraph and add them
   326  			// to the queue.
   327  			if !stopExpansion {
   328  				next <- true
   329  				select {
   330  				case err = <-expandErr:
   331  					if err != nil {
   332  						if err == errStop {
   333  							stopExpansion = true
   334  						} else {
   335  							return nil, err
   336  						}
   337  					}
   338  				case <-ctx.Done():
   339  					return nil, ctx.Err()
   340  				}
   341  				numHops++
   342  			}
   343  		}
   344  		select {
   345  		case <-ctx.Done():
   346  			return nil, ctx.Err()
   347  		default:
   348  			if stopExpansion {
   349  				continue
   350  			}
   351  		}
   352  		neighbours := adjacencyMap[item.uid]
   353  		for toUid, info := range neighbours {
   354  			cost := info.cost
   355  			// Skip neighbour if the cost is greater than the maximum weight allowed.
   356  			if item.cost+cost > maxWeight {
   357  				continue
   358  			}
   359  
   360  			curPath := pathPool.Get().([]pathInfo)
   361  			if cap(curPath) < len(item.path.route)+1 {
   362  				// We can't use it due to insufficient capacity. Put it back.
   363  				pathPool.Put(curPath)
   364  				curPath = make([]pathInfo, len(item.path.route)+1)
   365  			} else {
   366  				// Use the curPath from pathPool. Set length appropriately.
   367  				curPath = curPath[:len(item.path.route)+1]
   368  			}
   369  			n := copy(curPath, item.path.route)
   370  			curPath[n] = pathInfo{
   371  				uid:   toUid,
   372  				attr:  info.attr,
   373  				facet: info.facet,
   374  			}
   375  			node := &queueItem{
   376  				uid:  toUid,
   377  				cost: item.cost + cost,
   378  				hop:  item.hop + 1,
   379  				path: route{route: curPath},
   380  			}
   381  			heap.Push(&pq, node)
   382  		}
   383  		// Return the popped nodes path to pool.
   384  		pathPool.Put(item.path.route)
   385  	}
   386  
   387  	next <- false
   388  
   389  	if len(kroutes) == 0 {
   390  		sg.DestUIDs = &pb.List{}
   391  		return nil, nil
   392  	}
   393  	var res []uint64
   394  	for _, it := range kroutes[0].route {
   395  		res = append(res, it.uid)
   396  	}
   397  	sg.DestUIDs.Uids = res
   398  	shortestSg := createkroutesubgraph(ctx, kroutes)
   399  	return shortestSg, nil
   400  }
   401  
   402  // Djikstras algorithm pseudocode for reference.
   403  //
   404  //
   405  // 1  function Dijkstra(Graph, source):
   406  // 2      dist[source] ← 0                                    // Initialization
   407  // 3
   408  // 4      create vertex set Q
   409  // 5
   410  // 6      for each vertex v in Graph:
   411  // 7          if v ≠ source
   412  // 8              dist[v] ← INFINITY                          // Unknown distance from source to v
   413  // 9              prev[v] ← UNDEFINED                         // Predecessor of v
   414  // 10
   415  // 11         Q.add_with_priority(v, dist[v])
   416  // 12
   417  // 13
   418  // 14     while Q is not empty:                              // The main loop
   419  // 15         u ← Q.extract_min()                            // Remove and return best vertex
   420  // 16         for each neighbor v of u:                       // only v that is still in Q
   421  // 17             alt = dist[u] + length(u, v)
   422  // 18             if alt < dist[v]
   423  // 19                 dist[v] ← alt
   424  // 20                 prev[v] ← u
   425  // 21                 Q.decrease_priority(v, alt)
   426  // 22
   427  // 23     return dist[], prev[]
   428  func shortestPath(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) {
   429  	var err error
   430  	if sg.Params.Alias != "shortest" {
   431  		return nil, errors.Errorf("Invalid shortest path query")
   432  	}
   433  	if sg.Params.From == 0 || sg.Params.To == 0 {
   434  		return nil, nil
   435  	}
   436  	numPaths := sg.Params.numPaths
   437  	if numPaths == 0 {
   438  		// Return 1 path by default.
   439  		numPaths = 1
   440  	}
   441  
   442  	if numPaths > 1 {
   443  		return runKShortestPaths(ctx, sg)
   444  	}
   445  	pq := make(priorityQueue, 0)
   446  	heap.Init(&pq)
   447  
   448  	// Initialize and push the source node.
   449  	srcNode := &queueItem{
   450  		uid:  sg.Params.From,
   451  		cost: 0,
   452  		hop:  0,
   453  	}
   454  	heap.Push(&pq, srcNode)
   455  
   456  	numHops := -1
   457  	maxHops := int(sg.Params.ExploreDepth)
   458  	if maxHops == 0 {
   459  		maxHops = int(math.MaxInt32)
   460  	}
   461  	next := make(chan bool, 2)
   462  	expandErr := make(chan error, 2)
   463  	adjacencyMap := make(map[uint64]map[uint64]mapItem)
   464  	go sg.expandOut(ctx, adjacencyMap, next, expandErr)
   465  
   466  	// map to store the min cost and parent of nodes.
   467  	dist := make(map[uint64]nodeInfo)
   468  	dist[srcNode.uid] = nodeInfo{
   469  		parent: 0,
   470  		node:   srcNode,
   471  		mapItem: mapItem{
   472  			cost: 0,
   473  		},
   474  	}
   475  
   476  	var stopExpansion bool
   477  	var totalWeight float64
   478  	for pq.Len() > 0 {
   479  		item := heap.Pop(&pq).(*queueItem)
   480  		if item.uid == sg.Params.To {
   481  			totalWeight = item.cost
   482  			break
   483  		}
   484  		if item.hop > numHops && numHops < maxHops {
   485  			// Explore the next level by calling processGraph and add them
   486  			// to the queue.
   487  			if !stopExpansion {
   488  				next <- true
   489  			}
   490  			select {
   491  			case err = <-expandErr:
   492  				if err != nil {
   493  					if err == errStop {
   494  						stopExpansion = true
   495  					} else {
   496  						return nil, err
   497  					}
   498  				}
   499  			case <-ctx.Done():
   500  				return nil, ctx.Err()
   501  			}
   502  			numHops++
   503  		}
   504  		select {
   505  		case <-ctx.Done():
   506  			return nil, ctx.Err()
   507  		default:
   508  			if !stopExpansion {
   509  				neighbours := adjacencyMap[item.uid]
   510  				for toUid, info := range neighbours {
   511  					cost := info.cost
   512  					d, ok := dist[toUid]
   513  					if ok && d.cost <= item.cost+cost {
   514  						continue
   515  					}
   516  					if !ok {
   517  						// This is the first time we're seeing this node. So
   518  						// create a new node and add it to the heap and map.
   519  						node := &queueItem{
   520  							uid:  toUid,
   521  							cost: item.cost + cost,
   522  							hop:  item.hop + 1,
   523  						}
   524  						heap.Push(&pq, node)
   525  						dist[toUid] = nodeInfo{
   526  							parent: item.uid,
   527  							node:   node,
   528  							mapItem: mapItem{
   529  								cost:  item.cost + cost,
   530  								attr:  info.attr,
   531  								facet: info.facet,
   532  							},
   533  						}
   534  					} else {
   535  						// We've already seen this node. So, just update the cost
   536  						// and fix the priority in the heap and map.
   537  						node := dist[toUid].node
   538  						node.cost = item.cost + cost
   539  						node.hop = item.hop + 1
   540  						heap.Fix(&pq, node.index)
   541  						// Update the map with new values.
   542  						dist[toUid] = nodeInfo{
   543  							parent: item.uid,
   544  							node:   node,
   545  							mapItem: mapItem{
   546  								cost:  item.cost + cost,
   547  								attr:  info.attr,
   548  								facet: info.facet,
   549  							},
   550  						}
   551  					}
   552  				}
   553  			}
   554  		}
   555  	}
   556  
   557  	next <- false
   558  	// Go through the distance map to find the path.
   559  	var result []uint64
   560  	cur := sg.Params.To
   561  	for i := 0; cur != sg.Params.From && i < len(dist); i++ {
   562  		result = append(result, cur)
   563  		cur = dist[cur].parent
   564  	}
   565  	// Put the path in DestUIDs of the root.
   566  	if cur != sg.Params.From {
   567  		sg.DestUIDs = &pb.List{}
   568  		return nil, nil
   569  	}
   570  
   571  	result = append(result, cur)
   572  	l := len(result)
   573  	// Reverse the list.
   574  	for i := 0; i < l/2; i++ {
   575  		result[i], result[l-i-1] = result[l-i-1], result[i]
   576  	}
   577  	sg.DestUIDs.Uids = result
   578  
   579  	shortestSg := createPathSubgraph(ctx, dist, totalWeight, result)
   580  	return []*SubGraph{shortestSg}, nil
   581  }
   582  
   583  func createPathSubgraph(ctx context.Context, dist map[uint64]nodeInfo, totalWeight float64,
   584  	result []uint64) *SubGraph {
   585  	shortestSg := new(SubGraph)
   586  	shortestSg.Params = params{
   587  		Alias:    "_path_",
   588  		shortest: true,
   589  	}
   590  	shortestSg.pathMeta = &pathMetadata{
   591  		weight: totalWeight,
   592  	}
   593  	curUid := result[0]
   594  	shortestSg.SrcUIDs = &pb.List{Uids: []uint64{curUid}}
   595  	shortestSg.DestUIDs = &pb.List{Uids: []uint64{curUid}}
   596  	shortestSg.uidMatrix = []*pb.List{{Uids: []uint64{curUid}}}
   597  
   598  	curNode := shortestSg
   599  	for i := 0; i < len(result)-1; i++ {
   600  		curUid := result[i]
   601  		childUid := result[i+1]
   602  		node := new(SubGraph)
   603  		nodeInfo := dist[childUid]
   604  		node.Params = params{
   605  			shortest: true,
   606  		}
   607  		if nodeInfo.facet != nil {
   608  			// For consistent later processing.
   609  			node.Params.Facet = &pb.FacetParams{}
   610  		}
   611  		node.Attr = nodeInfo.attr
   612  		node.facetsMatrix = []*pb.FacetsList{{FacetsList: []*pb.Facets{nodeInfo.facet}}}
   613  		node.SrcUIDs = &pb.List{Uids: []uint64{curUid}}
   614  		node.DestUIDs = &pb.List{Uids: []uint64{childUid}}
   615  		node.uidMatrix = []*pb.List{{Uids: []uint64{childUid}}}
   616  
   617  		curNode.Children = append(curNode.Children, node)
   618  		curNode = node
   619  	}
   620  
   621  	node := new(SubGraph)
   622  	node.Params = params{
   623  		shortest: true,
   624  	}
   625  	uid := result[len(result)-1]
   626  	node.SrcUIDs = &pb.List{Uids: []uint64{uid}}
   627  	node.uidMatrix = []*pb.List{{Uids: []uint64{uid}}}
   628  	curNode.Children = append(curNode.Children, node)
   629  
   630  	return shortestSg
   631  }
   632  
   633  func createkroutesubgraph(ctx context.Context, kroutes []route) []*SubGraph {
   634  	var res []*SubGraph
   635  	for _, it := range kroutes {
   636  		shortestSg := new(SubGraph)
   637  		shortestSg.Params = params{
   638  			Alias:    "_path_",
   639  			shortest: true,
   640  		}
   641  		shortestSg.pathMeta = &pathMetadata{
   642  			weight: it.totalWeight,
   643  		}
   644  		curUid := it.route[0].uid
   645  		shortestSg.SrcUIDs = &pb.List{Uids: []uint64{curUid}}
   646  		shortestSg.DestUIDs = &pb.List{Uids: []uint64{curUid}}
   647  		shortestSg.uidMatrix = []*pb.List{{Uids: []uint64{curUid}}}
   648  
   649  		curNode := shortestSg
   650  		i := 0
   651  		for ; i < len(it.route)-1; i++ {
   652  			curUid := it.route[i].uid
   653  			childUid := it.route[i+1].uid
   654  			node := new(SubGraph)
   655  			node.Params = params{
   656  				shortest: true,
   657  			}
   658  			if it.route[i+1].facet != nil {
   659  				// For consistent later processing.
   660  				node.Params.Facet = &pb.FacetParams{}
   661  			}
   662  			node.Attr = it.route[i+1].attr
   663  			node.facetsMatrix = []*pb.FacetsList{{FacetsList: []*pb.Facets{it.route[i+1].facet}}}
   664  			node.SrcUIDs = &pb.List{Uids: []uint64{curUid}}
   665  			node.DestUIDs = &pb.List{Uids: []uint64{childUid}}
   666  			node.uidMatrix = []*pb.List{{Uids: []uint64{childUid}}}
   667  
   668  			curNode.Children = append(curNode.Children, node)
   669  			curNode = node
   670  		}
   671  
   672  		node := new(SubGraph)
   673  		node.Params = params{
   674  			shortest: true,
   675  		}
   676  		uid := it.route[i].uid
   677  		node.SrcUIDs = &pb.List{Uids: []uint64{uid}}
   678  		node.uidMatrix = []*pb.List{{Uids: []uint64{uid}}}
   679  		curNode.Children = append(curNode.Children, node)
   680  
   681  		res = append(res, shortestSg)
   682  	}
   683  	return res
   684  }