github.com/cayleygraph/cayley@v0.7.7/graph/iterator/recursive.go (about)

     1  package iterator
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  
     7  	"github.com/cayleygraph/cayley/graph"
     8  	"github.com/cayleygraph/quad"
     9  )
    10  
    11  var _ graph.IteratorFuture = &Recursive{}
    12  
// recursiveBaseTag is the internal tag attached to the per-depth base iterator
// before it is handed to the morphism. recursiveNext reads it back from the
// morphism's results to learn which base value produced each result, and
// removes it from the tags reported to callers.
const recursiveBaseTag = "__base_recursive"
    14  
// Recursive iterator takes a base iterator and a morphism to be applied recursively, for each result.
//
// It is the legacy graph.Iterator form of the iterator: it keeps the
// shape-based *recursive implementation and embeds the graph.Iterator that
// graph.NewLegacy builds from it.
type Recursive struct {
	it *recursive // underlying shape; returned by AsShape
	graph.Iterator
}
    20  
// seenAt records where a value was encountered during recursive iteration.
//
// When stored in recursiveNext.seen, val is the base value (from the previous
// depth) whose expansion produced the keyed value. When used as
// recursiveNext.result, val is the current result value itself.
type seenAt struct {
	depth int                  // recursion depth at which the value was produced
	tags  map[string]graph.Ref // tags reported by the morphism iterator, without recursiveBaseTag
	val   graph.Ref            // see the type comment: base value or current result
}
    26  
// DefaultMaxRecursiveSteps is the depth limit newRecursive substitutes when it
// is called with a maxDepth of 0.
var DefaultMaxRecursiveSteps = 50
    28  
    29  func NewRecursive(sub graph.Iterator, morphism Morphism, maxDepth int) *Recursive {
    30  	it := &Recursive{
    31  		it: newRecursive(graph.AsShape(sub), func(it graph.IteratorShape) graph.IteratorShape {
    32  			return graph.AsShape(morphism(graph.AsLegacy(it)))
    33  		}, maxDepth),
    34  	}
    35  	it.Iterator = graph.NewLegacy(it.it, it)
    36  	return it
    37  }
    38  
// AddDepthTag registers s as a depth tag on the underlying shape.
func (it *Recursive) AddDepthTag(s string) {
	it.it.AddDepthTag(s)
}
    42  
    43  func (it *Recursive) AsShape() graph.IteratorShape {
    44  	it.Close()
    45  	return it.it
    46  }
    47  
    48  var _ graph.IteratorShapeCompat = &recursive{}
    49  
// Recursive iterator takes a base iterator and a morphism to be applied recursively, for each result.
//
// recursive is the shape (iterator description); the actual traversal state
// lives in the scanners created by Iterate and Lookup.
type recursive struct {
	subIt     graph.IteratorShape // base iterator whose results seed the recursion
	morphism  Morphism2           // step applied to the values found at each depth
	maxDepth  int                 // depth limit handed to the scanner; newRecursive maps 0 to DefaultMaxRecursiveSteps
	depthTags []string            // tag names under which the scanner reports the result depth
}
    57  
    58  func newRecursive(it graph.IteratorShape, morphism Morphism2, maxDepth int) *recursive {
    59  	if maxDepth == 0 {
    60  		maxDepth = DefaultMaxRecursiveSteps
    61  	}
    62  	return &recursive{
    63  		subIt:    it,
    64  		morphism: morphism,
    65  		maxDepth: maxDepth,
    66  	}
    67  }
    68  
    69  func (it *recursive) Iterate() graph.Scanner {
    70  	return newRecursiveNext(it.subIt.Iterate(), it.morphism, it.maxDepth, it.depthTags)
    71  }
    72  
    73  func (it *recursive) Lookup() graph.Index {
    74  	return newRecursiveContains(newRecursiveNext(it.subIt.Iterate(), it.morphism, it.maxDepth, it.depthTags))
    75  }
    76  
    77  func (it *recursive) AsLegacy() graph.Iterator {
    78  	it2 := &Recursive{it: it}
    79  	it2.Iterator = graph.NewLegacy(it, it2)
    80  	return it2
    81  }
    82  
// AddDepthTag appends s to the list of depth tags. Scanners created by Iterate
// and Lookup receive this list and report each result's depth under every tag
// in it.
func (it *recursive) AddDepthTag(s string) {
	it.depthTags = append(it.depthTags, s)
}
    86  
// SubIterators returns the base iterator shape. Iterators produced by the
// morphism are not included.
func (it *recursive) SubIterators() []graph.IteratorShape {
	return []graph.IteratorShape{it.subIt}
}
    90  
    91  func (it *recursive) Optimize(ctx context.Context) (graph.IteratorShape, bool) {
    92  	newIt, optimized := it.subIt.Optimize(ctx)
    93  	if optimized {
    94  		it.subIt = newIt
    95  	}
    96  	return it, false
    97  }
    98  
    99  func (it *recursive) Stats(ctx context.Context) (graph.IteratorCosts, error) {
   100  	base := newFixed()
   101  	base.Add(Int64Node(20))
   102  	fanoutit := it.morphism(base)
   103  	fanoutStats, err := fanoutit.Stats(ctx)
   104  	subitStats, err2 := it.subIt.Stats(ctx)
   105  	if err == nil {
   106  		err = err2
   107  	}
   108  	size := int64(math.Pow(float64(subitStats.Size.Size*fanoutStats.Size.Size), 5))
   109  	return graph.IteratorCosts{
   110  		NextCost:     subitStats.NextCost + fanoutStats.NextCost,
   111  		ContainsCost: (subitStats.NextCost+fanoutStats.NextCost)*(size/10) + subitStats.ContainsCost,
   112  		Size: graph.Size{
   113  			Size:  size,
   114  			Exact: false,
   115  		},
   116  	}, err
   117  }
   118  
// String returns the fixed name of the iterator shape.
func (it *recursive) String() string {
	return "Recursive"
}
   122  
// recursiveNext is the scanner behind the recursive shape. It first drains the
// base scanner, then applies the morphism to the values found at each depth,
// yielding every value the first time it is reached.
type recursiveNext struct {
	subIt  graph.Scanner // scanner over the base iterator; drained while depth is 0
	result seenAt        // depth and value of the current result
	err    error         // error reported by Err and Close

	morphism      Morphism2                              // applied to each depth's base iterator to produce the next depth
	seen          map[interface{}]seenAt                 // results already yielded, keyed by graph.ToKey of the value
	nextIt        graph.Scanner                          // scanner over the current depth's morphism output; starts as Null
	depth         int                                    // current recursion depth; 0 until the base values have been expanded
	maxDepth      int                                    // depth limit; only enforced when positive
	pathMap       map[interface{}][]map[string]graph.Ref // base value key -> tag sets of every base-iterator path yielding it
	pathIndex     int                                    // index into the pathMap entry of containsValue; reset by Next, advanced by NextPath
	containsValue graph.Ref                              // base-iterator value from which the current result was reached
	depthTags     []string                               // tag names that receive the current result's depth
	depthCache    []graph.Ref                            // values collected at the current depth; they seed the next depth
	baseIt        *fixed                                 // fixed iterator holding the values being expanded at the current depth
}
   141  
   142  func newRecursiveNext(it graph.Scanner, morphism Morphism2, maxDepth int, depthTags []string) *recursiveNext {
   143  	return &recursiveNext{
   144  		subIt:     it,
   145  		morphism:  morphism,
   146  		maxDepth:  maxDepth,
   147  		depthTags: depthTags,
   148  
   149  		seen:    make(map[interface{}]seenAt),
   150  		nextIt:  &Null{},
   151  		baseIt:  newFixed(),
   152  		pathMap: make(map[interface{}][]map[string]graph.Ref),
   153  	}
   154  }
   155  
   156  func (it *recursiveNext) TagResults(dst map[string]graph.Ref) {
   157  	for _, tag := range it.depthTags {
   158  		dst[tag] = graph.PreFetched(quad.Int(it.result.depth))
   159  	}
   160  
   161  	if it.containsValue != nil {
   162  		paths := it.pathMap[graph.ToKey(it.containsValue)]
   163  		if len(paths) != 0 {
   164  			for k, v := range paths[it.pathIndex] {
   165  				dst[k] = v
   166  			}
   167  		}
   168  	}
   169  	if it.nextIt != nil {
   170  		it.nextIt.TagResults(dst)
   171  		delete(dst, recursiveBaseTag)
   172  	}
   173  }
   174  
   175  func (it *recursiveNext) Next(ctx context.Context) bool {
   176  	it.pathIndex = 0
   177  	if it.depth == 0 {
   178  		for it.subIt.Next(ctx) {
   179  			res := it.subIt.Result()
   180  			it.depthCache = append(it.depthCache, it.subIt.Result())
   181  			tags := make(map[string]graph.Ref)
   182  			it.subIt.TagResults(tags)
   183  			key := graph.ToKey(res)
   184  			it.pathMap[key] = append(it.pathMap[key], tags)
   185  			for it.subIt.NextPath(ctx) {
   186  				tags := make(map[string]graph.Ref)
   187  				it.subIt.TagResults(tags)
   188  				it.pathMap[key] = append(it.pathMap[key], tags)
   189  			}
   190  		}
   191  	}
   192  
   193  	for {
   194  		if !it.nextIt.Next(ctx) {
   195  			if it.maxDepth > 0 && it.depth >= it.maxDepth {
   196  				return false
   197  			} else if len(it.depthCache) == 0 {
   198  				return false
   199  			}
   200  			it.depth++
   201  			it.baseIt = newFixed(it.depthCache...)
   202  			it.depthCache = nil
   203  			if it.nextIt != nil {
   204  				it.nextIt.Close()
   205  			}
   206  			it.nextIt = it.morphism(TagShape(it.baseIt, recursiveBaseTag)).Iterate()
   207  			continue
   208  		}
   209  		val := it.nextIt.Result()
   210  		results := make(map[string]graph.Ref)
   211  		it.nextIt.TagResults(results)
   212  		key := graph.ToKey(val)
   213  		if _, seen := it.seen[key]; !seen {
   214  			base := results[recursiveBaseTag]
   215  			delete(results, recursiveBaseTag)
   216  			it.seen[key] = seenAt{
   217  				val:   base,
   218  				depth: it.depth,
   219  				tags:  results,
   220  			}
   221  			it.result.depth = it.depth
   222  			it.result.val = val
   223  			it.containsValue = it.getBaseValue(val)
   224  			it.depthCache = append(it.depthCache, val)
   225  			return true
   226  		}
   227  	}
   228  }
   229  
// Err returns the error stored on the scanner, if any.
func (it *recursiveNext) Err() error {
	return it.err
}
   233  
// Result returns the value of the current result, as set by the last
// successful Next (or by recursiveContains.Contains on a cache hit).
func (it *recursiveNext) Result() graph.Ref {
	return it.result.val
}
   237  
   238  func (it *recursiveNext) getBaseValue(val graph.Ref) graph.Ref {
   239  	var at seenAt
   240  	var ok bool
   241  	if at, ok = it.seen[graph.ToKey(val)]; !ok {
   242  		panic("trying to getBaseValue of something unseen")
   243  	}
   244  	for at.depth != 1 {
   245  		if at.depth == 0 {
   246  			panic("seen chain is broken")
   247  		}
   248  		at = it.seen[graph.ToKey(at.val)]
   249  	}
   250  	return at.val
   251  }
   252  
   253  func (it *recursiveNext) NextPath(ctx context.Context) bool {
   254  	if it.pathIndex+1 >= len(it.pathMap[graph.ToKey(it.containsValue)]) {
   255  		return false
   256  	}
   257  	it.pathIndex++
   258  	return true
   259  }
   260  
   261  func (it *recursiveNext) Close() error {
   262  	err := it.subIt.Close()
   263  	if err != nil {
   264  		return err
   265  	}
   266  	err = it.nextIt.Close()
   267  	if err != nil {
   268  		return err
   269  	}
   270  	it.seen = nil
   271  	return it.err
   272  }
   273  
// String returns the fixed name of the scanner.
func (it *recursiveNext) String() string {
	return "RecursiveNext"
}
   277  
// recursiveContains is the index behind the recursive shape. It answers
// Contains queries from the values its recursiveNext scanner has already
// seen, advancing that scanner when the value has not been seen yet.
type recursiveContains struct {
	next *recursiveNext       // scanner that is driven on demand and whose seen map acts as the cache
	tags map[string]graph.Ref // tags recorded for the value matched from the seen cache
}
   283  
   284  func newRecursiveContains(next *recursiveNext) *recursiveContains {
   285  	return &recursiveContains{
   286  		next: next,
   287  	}
   288  }
   289  
   290  func (it *recursiveContains) TagResults(dst map[string]graph.Ref) {
   291  	it.next.TagResults(dst)
   292  	for k, v := range it.tags {
   293  		dst[k] = v
   294  	}
   295  }
   296  
// Err returns the error of the underlying scanner.
func (it *recursiveContains) Err() error {
	return it.next.Err()
}
   300  
// Result returns the current result of the underlying scanner.
func (it *recursiveContains) Result() graph.Ref {
	return it.next.Result()
}
   304  
   305  func (it *recursiveContains) Contains(ctx context.Context, val graph.Ref) bool {
   306  	it.next.pathIndex = 0
   307  	key := graph.ToKey(val)
   308  	if at, ok := it.next.seen[key]; ok {
   309  		it.next.containsValue = it.next.getBaseValue(val)
   310  		it.next.result.depth = at.depth
   311  		it.next.result.val = val
   312  		it.tags = at.tags
   313  		return true
   314  	}
   315  	for it.next.Next(ctx) {
   316  		if graph.ToKey(it.next.Result()) == key {
   317  			return true
   318  		}
   319  	}
   320  	return false
   321  }
   322  
// NextPath delegates to the underlying scanner's NextPath.
func (it *recursiveContains) NextPath(ctx context.Context) bool {
	return it.next.NextPath(ctx)
}
   326  
// Close closes the underlying scanner and returns its error.
func (it *recursiveContains) Close() error {
	return it.next.Close()
}
   330  
// String returns the index name wrapped around the underlying scanner's name.
func (it *recursiveContains) String() string {
	return "RecursiveContains(" + it.next.String() + ")"
}