github.com/cayleygraph/cayley@v0.7.7/schema/loader.go (about)

     1  package schema
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  
     9  	"github.com/cayleygraph/cayley/graph/path"
    10  	"github.com/cayleygraph/quad"
    11  
    12  	"github.com/cayleygraph/cayley/graph"
    13  	"github.com/cayleygraph/cayley/graph/iterator"
    14  )
    15  
    16  var (
    17  	errNotFound               = errors.New("not found")
    18  	errRequiredFieldIsMissing = errors.New("required field is missing")
    19  )
    20  
    21  // Optimize flags controls an optimization step performed before queries.
    22  var Optimize = true
    23  
    24  // IsNotFound check if error is related to a missing object (either because of wrong ID or because of type constrains).
    25  func IsNotFound(err error) bool {
    26  	return err == errNotFound || err == errRequiredFieldIsMissing
    27  }
    28  
    29  // LoadTo will load a sub-graph of objects starting from ids (or from any nodes, if empty)
    30  // to a destination Go object. Destination can be a struct, slice or channel.
    31  //
    32  // Mapping to quads is done via Go struct tag "quad" or "json" as a fallback.
    33  //
    34  // A simplest mapping is an "@id" tag which saves node ID (subject of a quad) into tagged field.
    35  //
    36  //	type Node struct{
    37  //		ID quad.IRI `json:"@id"` // or `quad:"@id"`
    38  // 	}
    39  //
    40  // Field with an "@id" tag is omitted, but in case of Go->quads mapping new ID will be generated
    41  // using GenerateID callback, which can be changed to provide a custom mappings.
    42  //
    43  // All other tags are interpreted as a predicate name for a specific field:
    44  //
    45  //	type Person struct{
    46  //		ID quad.IRI `json:"@id"`
    47  //		Name string `json:"name"`
    48  // 	}
    49  //	p := Person{"bob","Bob"}
    50  //	// is equivalent to triple:
    51  //	// <bob> <name> "Bob"
    52  //
    53  // Predicate IRIs in RDF can have a long namespaces, but they can be written in short
    54  // form. They will be expanded automatically if namespace prefix is registered within
    55  // QuadStore or globally via "voc" package.
    56  // There is also a special predicate name "@type" which is mapped to "rdf:type" IRI.
    57  //
    58  //	voc.RegisterPrefix("ex:", "http://example.org/")
    59  //	type Person struct{
    60  //		ID quad.IRI `json:"@id"`
    61  //		Type quad.IRI `json:"@type"`
    62  //		Name string `json:"ex:name"` // will be expanded to http://example.org/name
    63  // 	}
    64  //	p := Person{"bob",quad.IRI("Person"),"Bob"}
    65  //	// is equivalent to triples:
    66  //	// <bob> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <Person>
    67  //	// <bob> <http://example.org/name> "Bob"
    68  //
    69  // Predicate link direction can be reversed with a special tag syntax (not available for "json" tag):
    70  //
    71  // 	type Person struct{
    72  //		ID quad.IRI `json:"@id"`
    73  //		Name string `json:"name"` // same as `quad:"name"` or `quad:"name > *"`
    74  //		Parents []quad.IRI `quad:"isParentOf < *"`
    75  // 	}
    76  //	p := Person{"bob","Bob",[]quad.IRI{"alice","fred"}}
    77  //	// is equivalent to triples:
    78  //	// <bob> <name> "Bob"
    79  //	// <alice> <isParentOf> <bob>
    80  //	// <fred> <isParentOf> <bob>
    81  //
    82  // All fields in structs are interpreted as required (except slices), thus struct will not be
    83  // loaded if one of fields is missing. An "optional" tag can be specified to relax this requirement.
    84  // Also, "required" can be specified for slices to alter default value.
    85  //
    86  //	type Person struct{
    87  //		ID quad.IRI `json:"@id"`
    88  //		Name string `json:"name"` // required field
    89  //		ThirdName string `quad:"thirdName,optional"` // can be empty
    90  //		FollowedBy []quad.IRI `quad:"follows"`
    91  // 	}
    92  func (c *Config) LoadTo(ctx context.Context, qs graph.QuadStore, dst interface{}, ids ...quad.Value) error {
    93  	return c.LoadToDepth(ctx, qs, dst, -1, ids...)
    94  }
    95  
    96  // LoadToDepth is the same as LoadTo, but stops at a specified depth.
    97  // Negative value means unlimited depth, and zero means top level only.
    98  func (c *Config) LoadToDepth(ctx context.Context, qs graph.QuadStore, dst interface{}, depth int, ids ...quad.Value) error {
    99  	if dst == nil {
   100  		return fmt.Errorf("nil destination object")
   101  	}
   102  	var it graph.Iterator
   103  	if len(ids) != 0 {
   104  		fixed := iterator.NewFixed()
   105  		for _, id := range ids {
   106  			fixed.Add(qs.ValueOf(id))
   107  		}
   108  		it = fixed
   109  	}
   110  	var rv reflect.Value
   111  	if v, ok := dst.(reflect.Value); ok {
   112  		rv = v
   113  	} else {
   114  		rv = reflect.ValueOf(dst)
   115  	}
   116  	return c.LoadIteratorToDepth(ctx, qs, rv, depth, it)
   117  }
   118  
   119  // LoadPathTo is the same as LoadTo, but starts loading objects from a given path.
   120  func (c *Config) LoadPathTo(ctx context.Context, qs graph.QuadStore, dst interface{}, p *path.Path) error {
   121  	return c.LoadIteratorTo(ctx, qs, reflect.ValueOf(dst), p.BuildIterator())
   122  }
   123  
   124  // LoadIteratorTo is a lower level version of LoadTo.
   125  //
   126  // It expects an iterator of nodes to be passed explicitly and
   127  // destination value to be obtained via reflect package manually.
   128  //
   129  // Nodes iterator can be nil, All iterator will be used in this case.
   130  func (c *Config) LoadIteratorTo(ctx context.Context, qs graph.QuadStore, dst reflect.Value, list graph.Iterator) error {
   131  	return c.LoadIteratorToDepth(ctx, qs, dst, -1, list)
   132  }
   133  
   134  // LoadIteratorToDepth is the same as LoadIteratorTo, but stops at a specified depth.
   135  // Negative value means unlimited depth, and zero means top level only.
   136  func (c *Config) LoadIteratorToDepth(ctx context.Context, qs graph.QuadStore, dst reflect.Value, depth int, list graph.Iterator) error {
   137  	if depth >= 0 {
   138  		// 0 depth means "current level only" for user, but it's easier to make depth=0 a stop condition
   139  		depth++
   140  	}
   141  	l := c.newLoader(qs)
   142  	return l.loadIteratorToDepth(ctx, dst, depth, list)
   143  }
   144  
   145  type loader struct {
   146  	c  *Config
   147  	qs graph.QuadStore
   148  
   149  	pathForType     map[reflect.Type]*path.Path
   150  	pathForTypeRoot map[reflect.Type]*path.Path
   151  
   152  	seen map[quad.Value]reflect.Value
   153  }
   154  
   155  func (c *Config) newLoader(qs graph.QuadStore) *loader {
   156  	return &loader{
   157  		c:  c,
   158  		qs: qs,
   159  
   160  		pathForType:     make(map[reflect.Type]*path.Path),
   161  		pathForTypeRoot: make(map[reflect.Type]*path.Path),
   162  
   163  		seen: make(map[quad.Value]reflect.Value),
   164  	}
   165  }
   166  
   167  func (l *loader) makePathForType(rt reflect.Type, tagPref string, rootOnly bool) (*path.Path, error) {
   168  	for rt.Kind() == reflect.Ptr {
   169  		rt = rt.Elem()
   170  	}
   171  	if rt.Kind() != reflect.Struct {
   172  		return nil, fmt.Errorf("expected struct, got %v", rt)
   173  	}
   174  	if tagPref == "" {
   175  		m := l.pathForType
   176  		if rootOnly {
   177  			m = l.pathForTypeRoot
   178  		}
   179  		if p, ok := m[rt]; ok {
   180  			return p, nil
   181  		}
   182  	}
   183  
   184  	p := path.StartMorphism()
   185  
   186  	if iri := getTypeIRI(rt); iri != quad.IRI("") {
   187  		p = p.Has(l.c.iri(iriType), iri)
   188  	}
   189  
   190  	// TODO(dennwc): rewrite to shapes
   191  
   192  	allOptional := true
   193  	var alt *path.Path
   194  	for i := 0; i < rt.NumField(); i++ {
   195  		f := rt.Field(i)
   196  		if f.Anonymous {
   197  			pa, err := l.makePathForType(f.Type, tagPref+f.Name+".", rootOnly)
   198  			if err != nil {
   199  				return nil, err
   200  			}
   201  			p = p.Follow(pa)
   202  			continue
   203  		}
   204  		name := f.Name
   205  		rule, err := l.c.fieldRule(f)
   206  		if err != nil {
   207  			return nil, err
   208  		} else if rule == nil { // skip
   209  			continue
   210  		}
   211  		ft := f.Type
   212  		if ft.Kind() == reflect.Ptr {
   213  			ft = ft.Elem()
   214  		}
   215  		if err = checkFieldType(ft); err != nil {
   216  			return nil, err
   217  		}
   218  		switch rule := rule.(type) {
   219  		case idRule:
   220  			p = p.Tag(tagPref + name)
   221  		case constraintRule:
   222  			allOptional = false
   223  			var nodes []quad.Value
   224  			if rule.Val != "" {
   225  				nodes = []quad.Value{rule.Val}
   226  			}
   227  			if rule.Rev {
   228  				p = p.HasReverse(rule.Pred, nodes...)
   229  			} else {
   230  				p = p.Has(rule.Pred, nodes...)
   231  			}
   232  		case saveRule:
   233  			tag := tagPref + name
   234  			if rule.Opt {
   235  				if !rootOnly {
   236  					if rule.Rev {
   237  						p = p.SaveOptionalReverse(rule.Pred, tag)
   238  						if allOptional {
   239  							ap := path.StartMorphism().HasReverse(rule.Pred)
   240  							if alt == nil {
   241  								alt = ap
   242  							} else {
   243  								alt = alt.Or(ap)
   244  							}
   245  						}
   246  					} else {
   247  						p = p.SaveOptional(rule.Pred, tag)
   248  						if allOptional {
   249  							ap := path.StartMorphism().Has(rule.Pred)
   250  							if alt == nil {
   251  								alt = ap
   252  							} else {
   253  								alt = alt.Or(ap)
   254  							}
   255  						}
   256  					}
   257  				}
   258  			} else if rootOnly { // do not save field, enforce constraint only
   259  				allOptional = false
   260  				if rule.Rev {
   261  					p = p.HasReverse(rule.Pred)
   262  				} else {
   263  					p = p.Has(rule.Pred)
   264  				}
   265  			} else {
   266  				allOptional = false
   267  				if rule.Rev {
   268  					p = p.SaveReverse(rule.Pred, tag)
   269  				} else {
   270  					p = p.Save(rule.Pred, tag)
   271  				}
   272  			}
   273  		}
   274  	}
   275  	if allOptional {
   276  		p = p.And(alt.Unique())
   277  	}
   278  	if tagPref != "" {
   279  		return p, nil
   280  	}
   281  	m := l.pathForType
   282  	if rootOnly {
   283  		m = l.pathForTypeRoot
   284  	}
   285  	m[rt] = p
   286  	return p, nil
   287  }
   288  
// loadToValue fills the struct behind dst (pointers are dereferenced) from m,
// which maps tag names (tagPref followed by the field name) to the refs
// collected for them by the query.
//
// Field rules are read from ctx (stored under fieldsCtxKey by
// loadIteratorToDepth) when present; otherwise they are computed for dst's
// type. depth is the remaining recursion budget: nested struct values are
// loaded with depth-1, and at depth 0 required fields are not checked.
// It returns errRequiredFieldIsMissing if a non-optional saved field has no
// values in m.
func (l *loader) loadToValue(ctx context.Context, dst reflect.Value, depth int, m map[string][]graph.Ref, tagPref string) error {
	if ctx == nil {
		ctx = context.TODO()
	}
	for dst.Kind() == reflect.Ptr {
		dst = dst.Elem()
	}
	rt := dst.Type()
	if rt.Kind() != reflect.Struct {
		return fmt.Errorf("expected struct, got %v", rt)
	}
	var fields fieldRules
	if v := ctx.Value(fieldsCtxKey{}); v != nil {
		fields = v.(fieldRules)
	} else {
		nfields, err := l.c.rulesFor(rt)
		if err != nil {
			return err
		}
		fields = nfields
	}
	if depth != 0 { // do not check required fields if depth limit is reached
		for name, field := range fields {
			if r, ok := field.(saveRule); ok && !r.Opt {
				if vals := m[name]; len(vals) == 0 {
					return errRequiredFieldIsMissing
				}
			}
		}
	}
	for i := 0; i < rt.NumField(); i++ {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		f := rt.Field(i)
		name := f.Name
		if err := checkFieldType(f.Type); err != nil {
			return err
		}
		df := dst.Field(i)
		if f.Anonymous {
			// Embedded struct: its fields share m, under names prefixed with
			// the embedded field name.
			if err := l.loadToValue(ctx, df, depth, m, tagPref+name+"."); err != nil {
				return fmt.Errorf("load anonymous field %s failed: %v", f.Name, err)
			}
			continue
		}
		rules := fields[tagPref+name]
		if rules == nil {
			continue
		}
		arr, ok := m[tagPref+name]
		if !ok || len(arr) == 0 {
			continue
		}
		// Strip pointer and slice wrappers to find the element type. native
		// records whether isNative holds at any level; ptr ends up true when
		// the innermost wrapper around the element is a pointer.
		ft := f.Type
		native := isNative(ft)
		ptr := ft.Kind() == reflect.Ptr
		for ft.Kind() == reflect.Ptr || ft.Kind() == reflect.Slice {
			ft = ft.Elem()
			native = native || isNative(ft)
			switch ft.Kind() {
			case reflect.Ptr:
				ptr = true
			case reflect.Slice:
				ptr = false
			}
		}
		// Non-native struct elements are loaded recursively as sub-objects.
		recursive := !native && ft.Kind() == reflect.Struct
		for _, fv := range arr {
			var sv reflect.Value
			if recursive {
				if ptr {
					// Reuse an object already loaded for this node when it can
					// be assigned to the field directly.
					// NOTE(review): for slice fields (e.g. []*T) the field type
					// is the slice itself, so this check fails and the node is
					// loaded again below.
					fv := l.qs.NameOf(fv)
					var ok bool
					sv, ok = l.seen[fv]
					if ok && sv.Type().AssignableTo(f.Type) {
						df.Set(sv)
						continue
					}
				}
				sv = reflect.New(ft).Elem()
				err := l.loadIteratorToDepth(ctx, sv, depth-1, iterator.NewFixed(fv))
				if err == errRequiredFieldIsMissing {
					// The node does not satisfy the element type; skip it.
					continue
				} else if err != nil {
					return err
				}
			} else {
				// Plain values come from the node value itself; refs without
				// a value are skipped.
				fv := l.qs.NameOf(fv)
				if fv == nil {
					continue
				}
				sv = reflect.ValueOf(fv)
			}
			if err := DefaultConverter.SetValue(df, sv); err != nil {
				return fmt.Errorf("field %s: %v", f.Name, err)
			}
		}
	}
	return nil
}
   392  
   393  func (l *loader) iteratorForType(root graph.Iterator, rt reflect.Type, rootOnly bool) (graph.Iterator, error) {
   394  	p, err := l.makePathForType(rt, "", rootOnly)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  	return l.iteratorFromPath(root, p)
   399  }
   400  
   401  func mergeMap(dst map[string][]graph.Ref, m map[string]graph.Ref) {
   402  loop:
   403  	for k, v := range m {
   404  		sl := dst[k]
   405  		for _, sv := range sl {
   406  			if keysEqual(sv, v) {
   407  				continue loop
   408  			}
   409  		}
   410  		dst[k] = append(sl, v)
   411  	}
   412  }
   413  
   414  func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, depth int, list graph.Iterator) error {
   415  	if ctx == nil {
   416  		ctx = context.TODO()
   417  	}
   418  	if dst.Kind() == reflect.Ptr {
   419  		dst = dst.Elem()
   420  	}
   421  	et := dst.Type()
   422  	slice, chanl := false, false
   423  	if dst.Kind() == reflect.Slice {
   424  		et = et.Elem()
   425  		slice = true
   426  	} else if dst.Kind() == reflect.Chan {
   427  		et = et.Elem()
   428  		chanl = true
   429  		defer dst.Close()
   430  	}
   431  	fields, err := l.c.rulesFor(et)
   432  	if err != nil {
   433  		return err
   434  	}
   435  
   436  	ctxDone := func() bool {
   437  		select {
   438  		case <-ctx.Done():
   439  			return true
   440  		default:
   441  		}
   442  		return false
   443  	}
   444  
   445  	if ctxDone() {
   446  		return ctx.Err()
   447  	}
   448  
   449  	rootOnly := depth == 0
   450  	it, err := l.iteratorForType(list, et, rootOnly)
   451  	if err != nil {
   452  		return err
   453  	}
   454  	defer it.Close()
   455  
   456  	ctx = context.WithValue(ctx, fieldsCtxKey{}, fields)
   457  	for it.Next(ctx) {
   458  		if ctxDone() {
   459  			return ctx.Err()
   460  		}
   461  		id := l.qs.NameOf(it.Result())
   462  		if id != nil {
   463  			if sv, ok := l.seen[id]; ok {
   464  				if slice {
   465  					dst.Set(reflect.Append(dst, sv.Elem()))
   466  				} else if chanl {
   467  					dst.Send(sv.Elem())
   468  				} else if dst.Kind() != reflect.Ptr {
   469  					dst.Set(sv.Elem())
   470  					return nil
   471  				} else {
   472  					dst.Set(sv)
   473  					return nil
   474  				}
   475  				continue
   476  			}
   477  		}
   478  		mp := make(map[string]graph.Ref)
   479  		it.TagResults(mp)
   480  		if len(mp) == 0 {
   481  			continue
   482  		}
   483  		cur := dst
   484  		if slice || chanl {
   485  			cur = reflect.New(et)
   486  		}
   487  		mo := make(map[string][]graph.Ref, len(mp))
   488  		for k, v := range mp {
   489  			mo[k] = []graph.Ref{v}
   490  		}
   491  		for it.NextPath(ctx) {
   492  			if ctxDone() {
   493  				return ctx.Err()
   494  			}
   495  			mp = make(map[string]graph.Ref)
   496  			it.TagResults(mp)
   497  			if len(mp) == 0 {
   498  				continue
   499  			}
   500  			// TODO(dennwc): replace with something more efficient
   501  			mergeMap(mo, mp)
   502  		}
   503  		if id != nil {
   504  			sv := cur
   505  			if sv.Kind() != reflect.Ptr && sv.CanAddr() {
   506  				sv = sv.Addr()
   507  			}
   508  			l.seen[id] = sv
   509  		}
   510  		err := l.loadToValue(ctx, cur, depth, mo, "")
   511  		if err == errRequiredFieldIsMissing {
   512  			if !slice && !chanl {
   513  				return err
   514  			}
   515  			continue
   516  		} else if err != nil {
   517  			return err
   518  		}
   519  		if slice {
   520  			dst.Set(reflect.Append(dst, cur.Elem()))
   521  		} else if chanl {
   522  			dst.Send(cur.Elem())
   523  		} else {
   524  			return nil
   525  		}
   526  	}
   527  	if err := it.Err(); err != nil {
   528  		return err
   529  	}
   530  	if slice || chanl {
   531  		return nil
   532  	}
   533  	if list != nil { // TODO(dennwc): optional optimization: do this only if iterator is not "all nodes"
   534  		// distinguish between missing object and type constraints
   535  		list.Reset()
   536  		and := iterator.NewAnd(list, l.qs.NodesAllIterator())
   537  		defer and.Close()
   538  		if and.Next(ctx) {
   539  			return errRequiredFieldIsMissing
   540  		}
   541  	}
   542  	return errNotFound
   543  }
   544  
   545  func (l *loader) iteratorFromPath(root graph.Iterator, p *path.Path) (graph.Iterator, error) {
   546  	it := p.BuildIteratorOn(l.qs)
   547  	if root != nil {
   548  		it = iterator.NewAnd(root, it)
   549  	}
   550  	if Optimize {
   551  		it, _ = it.Optimize()
   552  	}
   553  	return it, nil
   554  }