github.com/cayleygraph/cayley@v0.7.7/graph/sql/iterator.go (about)

     1  // Copyright 2017 The Cayley Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sql
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"github.com/cayleygraph/cayley/graph"
    24  	"github.com/cayleygraph/cayley/graph/shape"
    25  	"github.com/cayleygraph/quad"
    26  )
    27  
    28  var _ shape.Optimizer = (*QuadStore)(nil)
    29  
    30  func (qs *QuadStore) OptimizeShape(s shape.Shape) (shape.Shape, bool) {
    31  	return qs.opt.OptimizeShape(s)
    32  }
    33  
    34  func (qs *QuadStore) prepareQuery(s Shape) (string, []interface{}) {
    35  	args := s.Args()
    36  	vals := make([]interface{}, 0, len(args))
    37  	for _, a := range args {
    38  		vals = append(vals, a.SQLValue())
    39  	}
    40  	b := NewBuilder(qs.flavor.QueryDialect)
    41  	qu := s.SQL(b)
    42  	return qu, vals
    43  }
    44  
    45  func (qs *QuadStore) QueryRow(ctx context.Context, s Shape) *sql.Row {
    46  	qu, vals := qs.prepareQuery(s)
    47  	return qs.db.QueryRowContext(ctx, qu, vals...)
    48  }
    49  
    50  func (qs *QuadStore) Query(ctx context.Context, s Shape) (*sql.Rows, error) {
    51  	qu, vals := qs.prepareQuery(s)
    52  	rows, err := qs.db.QueryContext(ctx, qu, vals...)
    53  	if err != nil {
    54  		return nil, fmt.Errorf("sql query failed: %v\nquery: %v", err, qu)
    55  	}
    56  	return rows, nil
    57  }
    58  
    59  var _ graph.IteratorFuture = (*Iterator)(nil)
    60  
    61  func (qs *QuadStore) NewIterator(s Select) *Iterator {
    62  	it := &Iterator{
    63  		it: qs.newIterator(s),
    64  	}
    65  	it.Iterator = graph.NewLegacy(it.it, it)
    66  	return it
    67  }
    68  
    69  type Iterator struct {
    70  	it *iterator2
    71  	graph.Iterator
    72  }
    73  
    74  func (it *Iterator) AsShape() graph.IteratorShape {
    75  	it.Close()
    76  	return it.it
    77  }
    78  
    79  var _ graph.IteratorShapeCompat = (*iterator2)(nil)
    80  
    81  func (qs *QuadStore) newIterator(s Select) *iterator2 {
    82  	return &iterator2{
    83  		qs:    qs,
    84  		query: s,
    85  	}
    86  }
    87  
    88  type iterator2 struct {
    89  	qs    *QuadStore
    90  	query Select
    91  	err   error
    92  }
    93  
    94  func (it *iterator2) Iterate() graph.Scanner {
    95  	return newIteratorNext(it.qs, it.query)
    96  }
    97  
    98  func (it *iterator2) Lookup() graph.Index {
    99  	return newIteratorContains(it.qs, it.query)
   100  }
   101  
   102  func (it *iterator2) AsLegacy() graph.Iterator {
   103  	it2 := &Iterator{it: it}
   104  	it2.Iterator = graph.NewLegacy(it, it2)
   105  	return it2
   106  }
   107  
   108  func (it *iterator2) Stats(ctx context.Context) (graph.IteratorCosts, error) {
   109  	sz, err := it.getSize(ctx)
   110  	return graph.IteratorCosts{
   111  		NextCost:     1,
   112  		ContainsCost: 10,
   113  		Size:         sz,
   114  	}, err
   115  }
   116  
   117  func (it *iterator2) estimateSize(ctx context.Context) int64 {
   118  	if it.query.Limit > 0 {
   119  		return it.query.Limit
   120  	}
   121  	st, err := it.qs.Stats(ctx, false)
   122  	if err != nil && it.err == nil {
   123  		it.err = err
   124  	}
   125  	return st.Quads.Size
   126  }
   127  
   128  func (it *iterator2) getSize(ctx context.Context) (graph.Size, error) {
   129  	sz, err := it.qs.querySize(ctx, it.query)
   130  	if err != nil {
   131  		it.err = err
   132  		return graph.Size{Size: it.estimateSize(ctx), Exact: false}, err
   133  	}
   134  	return sz, nil
   135  }
   136  
   137  func (it *iterator2) Optimize(ctx context.Context) (graph.IteratorShape, bool) {
   138  	return it, false
   139  }
   140  
   141  func (it *iterator2) SubIterators() []graph.IteratorShape {
   142  	return nil
   143  }
   144  
   145  func (it *iterator2) String() string {
   146  	return it.query.SQL(NewBuilder(it.qs.flavor.QueryDialect))
   147  }
   148  
   149  func newIteratorBase(qs *QuadStore, s Select) iteratorBase {
   150  	return iteratorBase{
   151  		qs:    qs,
   152  		query: s,
   153  	}
   154  }
   155  
   156  type iteratorBase struct {
   157  	qs    *QuadStore
   158  	query Select
   159  
   160  	cols []string
   161  	cind map[quad.Direction]int
   162  
   163  	err  error
   164  	res  graph.Ref
   165  	tags map[string]graph.Ref
   166  }
   167  
   168  func (it *iteratorBase) TagResults(m map[string]graph.Ref) {
   169  	for tag, val := range it.tags {
   170  		m[tag] = val
   171  	}
   172  }
   173  
   174  func (it *iteratorBase) Result() graph.Ref {
   175  	return it.res
   176  }
   177  
   178  func (it *iteratorBase) ensureColumns() {
   179  	if it.cols != nil {
   180  		return
   181  	}
   182  	it.cols = it.query.Columns()
   183  	it.cind = make(map[quad.Direction]int, len(quad.Directions)+1)
   184  	for i, name := range it.cols {
   185  		if !strings.HasPrefix(name, tagPref) {
   186  			continue
   187  		}
   188  		if name == tagNode {
   189  			it.cind[quad.Any] = i
   190  			continue
   191  		}
   192  		name = name[len(tagPref):]
   193  		for _, d := range quad.Directions {
   194  			if name == d.String() {
   195  				it.cind[d] = i
   196  				break
   197  			}
   198  		}
   199  	}
   200  }
   201  
   202  func (it *iteratorBase) scanValue(r *sql.Rows) bool {
   203  	it.ensureColumns()
   204  	nodes := make([]NodeHash, len(it.cols))
   205  	pointers := make([]interface{}, len(nodes))
   206  	for i := range pointers {
   207  		pointers[i] = &nodes[i]
   208  	}
   209  	if err := r.Scan(pointers...); err != nil {
   210  		it.err = err
   211  		return false
   212  	}
   213  	it.tags = make(map[string]graph.Ref)
   214  	for i, name := range it.cols {
   215  		if !strings.Contains(name, tagPref) {
   216  			it.tags[name] = nodes[i].ValueHash
   217  		}
   218  	}
   219  	if len(it.cind) > 1 {
   220  		var q QuadHashes
   221  		for _, d := range quad.Directions {
   222  			i, ok := it.cind[d]
   223  			if !ok {
   224  				it.err = fmt.Errorf("cannot find quad %v in query output (columns: %v)", d, it.cols)
   225  				return false
   226  			}
   227  			q.Set(d, nodes[i].ValueHash)
   228  		}
   229  		it.res = q
   230  		return true
   231  	}
   232  	i, ok := it.cind[quad.Any]
   233  	if !ok {
   234  		it.err = fmt.Errorf("cannot find node hash in query output (columns: %v, cind: %v)", it.cols, it.cind)
   235  		return false
   236  	}
   237  	it.res = nodes[i]
   238  	return true
   239  }
   240  
   241  func (it *iteratorBase) Err() error {
   242  	return it.err
   243  }
   244  
   245  func (it *iteratorBase) String() string {
   246  	return it.query.SQL(NewBuilder(it.qs.flavor.QueryDialect))
   247  }
   248  
   249  func newIteratorNext(qs *QuadStore, s Select) *iteratorNext {
   250  	return &iteratorNext{
   251  		iteratorBase: newIteratorBase(qs, s),
   252  	}
   253  }
   254  
   255  type iteratorNext struct {
   256  	iteratorBase
   257  	cursor *sql.Rows
   258  	// TODO(dennwc): nextPath workaround; remove when we get rid of NextPath in general
   259  	nextPathRes  graph.Ref
   260  	nextPathTags map[string]graph.Ref
   261  }
   262  
// Next advances to the next result row, running the query (opening the
// cursor) when no cursor is open. It returns false once the rows are
// exhausted or an error occurred; Err distinguishes the two.
//
// In nextPath mode, rows whose main key equals that of the current result
// are skipped here, since they are meant to be consumed via NextPath. A row
// that NextPath already read ahead (nextPathRes) is returned first, without
// advancing the cursor.
func (it *iteratorNext) Next(ctx context.Context) bool {
	if it.err != nil {
		return false
	}
	if it.cursor == nil {
		// Lazily run the query on first use (or again after Close cleared the cursor).
		it.cursor, it.err = it.qs.Query(ctx, it.query)
	}
	// TODO(dennwc): this loop exists only because of nextPath workaround
	for {
		if it.err != nil {
			return false
		}
		if it.nextPathRes != nil {
			// NextPath already read the first row of the next key; hand it out.
			it.res = it.nextPathRes
			it.tags = it.nextPathTags
			it.nextPathRes = nil
			it.nextPathTags = nil
			return true
		}
		if !it.cursor.Next() {
			// Exhausted or failed; the cursor's Err tells which.
			it.err = it.cursor.Err()
			it.cursor.Close()
			return false
		}

		// Remember the previous result so rows with the same main key can be skipped.
		prev := it.res
		if !it.scanValue(it.cursor) {
			return false
		}
		if !it.query.nextPath {
			return true
		}
		if prev == nil || prev.Key() != it.res.Key() {
			return true
		}
		// skip the same main key if in nextPath mode
		// the user should receive those results via NextPath of the iterator
	}
}
   302  
   303  func (it *iteratorNext) NextPath(ctx context.Context) bool {
   304  	if it.err != nil {
   305  		return false
   306  	}
   307  	if !it.query.nextPath {
   308  		return false
   309  	}
   310  	if !it.cursor.Next() {
   311  		it.err = it.cursor.Err()
   312  		it.cursor.Close()
   313  		return false
   314  	}
   315  	prev := it.res
   316  	if !it.scanValue(it.cursor) {
   317  		return false
   318  	}
   319  	if prev.Key() == it.res.Key() {
   320  		return true
   321  	}
   322  	// different main keys - return false, but keep this results for the Next
   323  	it.nextPathRes = it.res
   324  	it.nextPathTags = it.tags
   325  	it.res = nil
   326  	it.tags = nil
   327  	return false
   328  }
   329  
   330  func (it *iteratorNext) Close() error {
   331  	if it.cursor != nil {
   332  		it.cursor.Close()
   333  		it.cursor = nil
   334  	}
   335  	return nil
   336  }
   337  
   338  func newIteratorContains(qs *QuadStore, s Select) *iteratorContains {
   339  	return &iteratorContains{
   340  		iteratorBase: newIteratorBase(qs, s),
   341  	}
   342  }
   343  
   344  type iteratorContains struct {
   345  	iteratorBase
   346  	// TODO(dennwc): nextPath workaround; remove when we get rid of NextPath in general
   347  	nextPathRows *sql.Rows
   348  }
   349  
   350  func (it *iteratorContains) Contains(ctx context.Context, v graph.Ref) bool {
   351  	it.ensureColumns()
   352  	sel := it.query
   353  	sel.Where = append([]Where{}, sel.Where...)
   354  	switch v := v.(type) {
   355  	case NodeHash:
   356  		i, ok := it.cind[quad.Any]
   357  		if !ok {
   358  			return false
   359  		}
   360  		f := it.query.Fields[i]
   361  		sel.WhereEq(f.Table, f.Name, v)
   362  	case QuadHashes:
   363  		for _, d := range quad.Directions {
   364  			i, ok := it.cind[d]
   365  			if !ok {
   366  				return false
   367  			}
   368  			h := v.Get(d)
   369  			if !h.Valid() {
   370  				continue
   371  			}
   372  			f := it.query.Fields[i]
   373  			sel.WhereEq(f.Table, f.Name, NodeHash{h})
   374  		}
   375  	default:
   376  		return false
   377  	}
   378  
   379  	rows, err := it.qs.Query(ctx, sel)
   380  	if err != nil {
   381  		it.err = err
   382  		return false
   383  	}
   384  	if it.query.nextPath {
   385  		if it.nextPathRows != nil {
   386  			_ = it.nextPathRows.Close()
   387  		}
   388  		it.nextPathRows = rows
   389  	} else {
   390  		defer rows.Close()
   391  	}
   392  	if !rows.Next() {
   393  		it.err = rows.Err()
   394  		return false
   395  	}
   396  	return it.scanValue(rows)
   397  }
   398  
   399  func (it *iteratorContains) NextPath(ctx context.Context) bool {
   400  	if it.err != nil {
   401  		return false
   402  	}
   403  	if !it.query.nextPath {
   404  		return false
   405  	}
   406  	if !it.nextPathRows.Next() {
   407  		it.err = it.nextPathRows.Err()
   408  		return false
   409  	}
   410  	return it.scanValue(it.nextPathRows)
   411  }
   412  
   413  func (it *iteratorContains) Close() error {
   414  	if it.nextPathRows != nil {
   415  		return it.nextPathRows.Close()
   416  	}
   417  	return nil
   418  }