github.com/cayleygraph/cayley@v0.7.7/query/graphql/graphql.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"strconv"
    10  	"strings"
    11  	"unicode"
    12  
    13  	"github.com/dennwc/graphql/language/ast"
    14  	"github.com/dennwc/graphql/language/lexer"
    15  	"github.com/dennwc/graphql/language/parser"
    16  
    17  	"github.com/cayleygraph/cayley/graph"
    18  	"github.com/cayleygraph/cayley/graph/path"
    19  	"github.com/cayleygraph/cayley/query"
    20  	"github.com/cayleygraph/quad"
    21  )
    22  
    23  const Name = "graphql"
    24  
    25  // GraphQL charset: [_A-Za-z][_0-9A-Za-z]*
    26  // (https://facebook.github.io/graphql/#sec-Names)
    27  
    28  // IRI charset: [^#x00-#x20<>"{}|^`\]
    29  // (https://www.w3.org/TR/turtle/#grammar-production-IRIREF)
    30  
    31  func allowedNameRune(r rune) bool {
    32  	// will include <> in the IRI value
    33  	return r > 0x20 && !strings.ContainsRune("\"{}()|^`", r) && !unicode.IsSpace(r)
    34  }
    35  
    36  func init() {
    37  	lexer.AllowNameRunes = allowedNameRune
    38  
    39  	query.RegisterLanguage(query.Language{
    40  		Name: Name,
    41  		Session: func(qs graph.QuadStore) query.Session {
    42  			return NewSession(qs)
    43  		},
    44  		REPL: func(qs graph.QuadStore) query.REPLSession {
    45  			return NewSession(qs)
    46  		},
    47  		HTTPError: httpError,
    48  		HTTPQuery: httpQuery,
    49  	})
    50  }
    51  
    52  func NewSession(qs graph.QuadStore) *Session {
    53  	return &Session{qs: qs}
    54  }
    55  
    56  type Session struct {
    57  	qs graph.QuadStore
    58  }
    59  
    60  func (s *Session) Execute(ctx context.Context, qu string, opt query.Options) (query.Iterator, error) {
    61  	switch opt.Collation {
    62  	case query.Raw, query.JSON, query.REPL:
    63  	default:
    64  		return nil, &query.ErrUnsupportedCollation{Collation: opt.Collation}
    65  	}
    66  	q, err := Parse(strings.NewReader(qu))
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	return &results{
    71  		s:   s,
    72  		q:   q,
    73  		col: opt.Collation,
    74  	}, nil
    75  }
    76  
    77  type results struct {
    78  	s   *Session
    79  	q   *Query
    80  	col query.Collation
    81  	res map[string]interface{}
    82  	err error
    83  }
    84  
    85  func (it *results) Next(ctx context.Context) bool {
    86  	if it.q == nil {
    87  		return false
    88  	}
    89  	it.res, it.err = it.q.Execute(ctx, it.s.qs)
    90  	it.q = nil
    91  	return it.err == nil && len(it.res) != 0
    92  }
    93  
    94  func (it *results) Result() interface{} {
    95  	if len(it.res) == 0 {
    96  		return nil
    97  	}
    98  	if it.col != query.REPL {
    99  		return it.res
   100  	}
   101  	data, _ := json.MarshalIndent(it.res, "", "   ")
   102  	return string(data)
   103  }
   104  
   105  func (it *results) Err() error {
   106  	return it.err
   107  }
   108  
   109  func (it *results) Close() error {
   110  	it.q = nil
   111  	return nil
   112  }
   113  
   114  // Configurable keywords and special field names.
   115  var (
   116  	ValueKey = "id"
   117  	LimitKey = "first"
   118  	SkipKey  = "offset"
   119  	AnyKey   = "*"
   120  )
   121  
   122  type Query struct {
   123  	fields []field
   124  }
   125  
   126  type has struct {
   127  	Via    quad.IRI
   128  	Rev    bool
   129  	Values []quad.Value
   130  	Labels []quad.Value
   131  }
   132  
   133  type field struct {
   134  	Via       quad.IRI
   135  	Alias     string
   136  	Rev       bool
   137  	Opt       bool
   138  	Labels    []quad.Value
   139  	Has       []has
   140  	Fields    []field
   141  	AllFields bool // fetch all fields
   142  	UnNest    bool // all fields will be saved to parent object
   143  }
   144  
   145  func (f field) isSave() bool { return len(f.Has)+len(f.Fields) == 0 && !f.AllFields }
   146  
   147  type object struct {
   148  	id     graph.Ref
   149  	fields map[string]interface{}
   150  }
   151  
   152  func buildIterator(qs graph.QuadStore, p *path.Path) graph.Iterator {
   153  	it, _ := p.BuildIterator().Optimize()
   154  	return it
   155  }
   156  
   157  func iterateObject(ctx context.Context, qs graph.QuadStore, f *field, p *path.Path) (out []map[string]interface{}, _ error) {
   158  	if len(f.Labels) != 0 {
   159  		p = p.LabelContext(f.Labels)
   160  	} else {
   161  		p = p.LabelContext()
   162  	}
   163  	var (
   164  		limit = -1
   165  		skip  = 0
   166  	)
   167  
   168  	for _, h := range f.Has {
   169  		switch h.Via {
   170  		case quad.IRI(ValueKey): // special key - "id"
   171  			p = p.Is(h.Values...)
   172  		case quad.IRI(LimitKey), quad.IRI(SkipKey): // limit and skip
   173  			if len(h.Values) != 1 {
   174  				return nil, fmt.Errorf("unexpected arguments: %v (%d)", h.Values, len(h.Values))
   175  			}
   176  			n, ok := h.Values[0].(quad.Int)
   177  			if !ok {
   178  				return nil, fmt.Errorf("unexpected value type for %v: %T", string(h.Via), h.Values[0])
   179  			}
   180  			if h.Via == quad.IRI(LimitKey) {
   181  				limit = int(n)
   182  			} else {
   183  				skip = int(n)
   184  				if skip < 0 {
   185  					skip = 0
   186  				}
   187  			}
   188  		default: // everything else - Has constraint
   189  			if len(h.Labels) != 0 {
   190  				p = p.LabelContext(h.Labels)
   191  			}
   192  			if h.Rev {
   193  				p = p.HasReverse(h.Via, h.Values...)
   194  			} else {
   195  				p = p.Has(h.Via, h.Values...)
   196  			}
   197  			if len(h.Labels) != 0 {
   198  				p = p.LabelContext()
   199  			}
   200  		}
   201  	}
   202  	tail := func() {
   203  		if skip > 0 {
   204  			p = p.Skip(int64(skip))
   205  		}
   206  		if limit >= 0 {
   207  			p = p.Limit(int64(limit))
   208  		}
   209  	}
   210  	if f.AllFields {
   211  		tail()
   212  
   213  		it := buildIterator(qs, p)
   214  		defer it.Close()
   215  
   216  		// we don't care about alternative paths to nodes here, so we will not call NextPath
   217  		// and we haven't tagged anything, so we will not call TagResult either
   218  		for i := 0; limit < 0 || i < limit; i++ {
   219  			select {
   220  			case <-ctx.Done():
   221  				return out, ctx.Err()
   222  			default:
   223  			}
   224  			if !it.Next(ctx) {
   225  				break
   226  			}
   227  			nv := it.Result()
   228  			obj := make(map[string]interface{})
   229  			obj[ValueKey] = qs.NameOf(nv)
   230  			func() {
   231  				sit := qs.QuadIterator(quad.Subject, nv)
   232  				defer sit.Close()
   233  				for sit.Next(ctx) {
   234  					q := qs.Quad(sit.Result())
   235  					if p, ok := q.Predicate.(quad.IRI); ok {
   236  						obj[string(p)] = q.Object
   237  					} else {
   238  						obj[quad.ToString(q.Predicate)] = q.Object
   239  					}
   240  				}
   241  			}()
   242  			out = append(out, obj)
   243  		}
   244  		return out, it.Err()
   245  	}
   246  	unnest := make(map[string]bool)
   247  	for _, f2 := range f.Fields {
   248  		if f2.UnNest {
   249  			unnest[f2.Alias] = true
   250  		}
   251  		if !f2.isSave() {
   252  			continue
   253  		}
   254  		if f2.Via == quad.IRI(ValueKey) {
   255  			p = p.Tag(f2.Alias)
   256  			continue
   257  		}
   258  		if len(f2.Labels) != 0 {
   259  			p = p.LabelContext(f2.Labels)
   260  		}
   261  		if f2.Opt {
   262  			if f2.Rev {
   263  				p = p.SaveOptionalReverse(f2.Via, f2.Alias)
   264  			} else {
   265  				p = p.SaveOptional(f2.Via, f2.Alias)
   266  			}
   267  		} else {
   268  			if f2.Rev {
   269  				p = p.SaveReverse(f2.Via, f2.Alias)
   270  			} else {
   271  				p = p.Save(f2.Via, f2.Alias)
   272  			}
   273  		}
   274  		if len(f2.Labels) != 0 {
   275  			p = p.LabelContext()
   276  		}
   277  	}
   278  	tail()
   279  
   280  	// first, collect result node ids and any tags associated with it (flat values)
   281  	it := buildIterator(qs, p)
   282  	defer it.Close()
   283  
   284  	var results []object
   285  	for i := 0; limit < 0 || i < limit; i++ {
   286  		select {
   287  		case <-ctx.Done():
   288  			return out, ctx.Err()
   289  		default:
   290  		}
   291  		if !it.Next(ctx) {
   292  			break
   293  		}
   294  		fields := make(map[string][]graph.Ref)
   295  
   296  		tags := make(map[string]graph.Ref)
   297  		it.TagResults(tags)
   298  		for k, v := range tags {
   299  			fields[k] = []graph.Ref{v}
   300  		}
   301  		for it.NextPath(ctx) {
   302  			select {
   303  			case <-ctx.Done():
   304  				return out, ctx.Err()
   305  			default:
   306  			}
   307  			tags = make(map[string]graph.Ref)
   308  			it.TagResults(tags)
   309  		dedup:
   310  			for k, v := range tags {
   311  				vals := fields[k]
   312  				for _, v2 := range vals {
   313  					if graph.ToKey(v) == graph.ToKey(v2) {
   314  						continue dedup
   315  					}
   316  				}
   317  				fields[k] = append(vals, v)
   318  			}
   319  		}
   320  		obj := object{id: it.Result()}
   321  		if len(fields) > 0 {
   322  			obj.fields = make(map[string]interface{}, len(fields))
   323  			for k, arr := range fields {
   324  				vals, err := graph.ValuesOf(ctx, qs, arr)
   325  				if err != nil {
   326  					return nil, err
   327  				}
   328  				if len(vals) == 1 {
   329  					obj.fields[k] = vals[0]
   330  				} else {
   331  					obj.fields[k] = vals
   332  				}
   333  			}
   334  		}
   335  		results = append(results, obj)
   336  	}
   337  	if err := it.Err(); err != nil {
   338  		return out, err
   339  	}
   340  
   341  	// next, load complex objects inside fields
   342  	for _, r := range results {
   343  		obj := r.fields
   344  		if obj == nil {
   345  			obj = make(map[string]interface{})
   346  		}
   347  		for _, f2 := range f.Fields {
   348  			if f2.isSave() {
   349  				continue // skip flat values
   350  			}
   351  			// start from saved id for a field node
   352  			p2 := path.StartPathNodes(qs, r.id)
   353  			if len(f2.Labels) != 0 {
   354  				p2 = p2.LabelContext(f2.Labels)
   355  			}
   356  			if f2.Rev {
   357  				p2 = p2.In(f2.Via)
   358  			} else {
   359  				p2 = p2.Out(f2.Via)
   360  			}
   361  			if len(f2.Labels) != 0 {
   362  				p2 = p2.LabelContext()
   363  			}
   364  			arr, err := iterateObject(ctx, qs, &f2, p2)
   365  			if err != nil {
   366  				return out, err
   367  			}
   368  			if f2.UnNest {
   369  				if len(arr) > 1 {
   370  					return nil, fmt.Errorf("cannot unnest more than one object on %q; use (%s: 1) to force",
   371  						f2.Alias, LimitKey)
   372  				} else if len(arr) == 0 {
   373  					continue
   374  				}
   375  				for k, v := range arr[0] {
   376  					obj[k] = v
   377  				}
   378  			} else {
   379  				var v interface{}
   380  				if len(arr) == 1 {
   381  					v = arr[0]
   382  				} else if len(arr) > 1 {
   383  					v = arr
   384  				}
   385  				obj[f2.Alias] = v
   386  			}
   387  		}
   388  		out = append(out, obj)
   389  	}
   390  	return out, nil
   391  }
   392  
   393  func (q *Query) Execute(ctx context.Context, qs graph.QuadStore) (map[string]interface{}, error) {
   394  	out := make(map[string]interface{})
   395  	for _, f := range q.fields {
   396  		arr, err := iterateObject(ctx, qs, &f, path.StartPath(qs))
   397  		if err != nil {
   398  			return out, err
   399  		}
   400  		var v interface{}
   401  		if len(arr) == 1 {
   402  			v = arr[0]
   403  		} else if len(arr) > 1 {
   404  			v = arr
   405  		}
   406  		out[f.Alias] = v
   407  	}
   408  	return out, nil
   409  }
   410  
   411  func Parse(r io.Reader) (*Query, error) {
   412  	data, err := ioutil.ReadAll(r)
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  	doc, err := parser.Parse(parser.ParseParams{Source: string(data)})
   417  	if err != nil {
   418  		return nil, err
   419  	}
   420  	if len(doc.Definitions) != 1 {
   421  		return nil, fmt.Errorf("unsupported query type")
   422  	}
   423  	def, ok := doc.Definitions[0].(*ast.OperationDefinition)
   424  	if !ok {
   425  		return nil, fmt.Errorf("unsupported query type: %T", doc.Definitions[0])
   426  	} else if def.Operation != "query" {
   427  		return nil, fmt.Errorf("unsupported operation: %s", def.Operation)
   428  	}
   429  	fields, all, err := setToFields(def.SelectionSet, nil)
   430  	if err != nil {
   431  		return nil, err
   432  	} else if all {
   433  		return nil, fmt.Errorf("expand all is not supported at top level")
   434  	}
   435  	return &Query{fields: fields}, nil
   436  }
   437  
   438  func setToFields(set *ast.SelectionSet, labels []quad.Value) (out []field, all bool, _ error) {
   439  	if set == nil {
   440  		return
   441  	}
   442  	for _, s := range set.Selections {
   443  		switch sel := s.(type) {
   444  		case *ast.Field:
   445  			fld, err := convField(sel, labels)
   446  			if err != nil {
   447  				return nil, false, err
   448  			}
   449  			if fld.Via == quad.IRI(AnyKey) {
   450  				if len(set.Selections) != 1 {
   451  					return nil, false, fmt.Errorf("expand all cannot be used with other fields")
   452  				} else if len(fld.Has) != 0 || len(fld.Fields) != 0 {
   453  					return nil, false, fmt.Errorf("filters inside expand all are not supported")
   454  				}
   455  				return nil, true, nil
   456  			}
   457  			out = append(out, fld)
   458  		default:
   459  			return nil, false, fmt.Errorf("unknown selection type: %T", s)
   460  		}
   461  	}
   462  	return
   463  }
   464  
   465  func stringToVia(s string) (_ quad.IRI, rev bool) {
   466  	if len(s) > 0 && s[0] == '~' {
   467  		rev = true
   468  		s = s[1:]
   469  	}
   470  	if len(s) > 2 && s[0] == '<' && s[len(s)-1] == '>' {
   471  		s = s[1 : len(s)-1]
   472  	}
   473  	return quad.IRI(s), rev
   474  }
   475  
   476  func argsToHas(dst []has, args []*ast.Argument, rev bool, labels []quad.Value) (out []has, err error) {
   477  	out = dst
   478  	for _, arg := range args {
   479  		var vals []quad.Value
   480  		vals, err = convValue(arg.Value)
   481  		if err != nil {
   482  			return
   483  		}
   484  		h := has{Values: vals, Labels: labels}
   485  		h.Via, h.Rev = stringToVia(arg.Name.Value)
   486  		h.Rev = h.Rev != rev
   487  		out = append(out, h)
   488  	}
   489  	return
   490  }
   491  
   492  func convField(fld *ast.Field, labels []quad.Value) (out field, err error) {
   493  	out.Labels = labels
   494  	name := fld.Name.Value
   495  	if fld.Alias != nil && fld.Alias.Value != "" {
   496  		out.Alias = fld.Alias.Value
   497  	} else {
   498  		out.Alias = name
   499  	}
   500  	out.Via, out.Rev = stringToVia(name)
   501  	// first check for "label" directive - it will affect all traversals
   502  	for _, d := range fld.Directives {
   503  		if d.Name == nil {
   504  			continue
   505  		}
   506  		switch d.Name.Value {
   507  		case "label":
   508  			if len(d.Arguments) == 0 {
   509  				out.Labels = nil
   510  			} else if len(d.Arguments) > 1 {
   511  				return out, fmt.Errorf("label directive should have 0 or 1 argument")
   512  			} else if a := d.Arguments[0]; a.Name == nil || a.Name.Value != "v" {
   513  				return out, fmt.Errorf("label directive should have 'v' argument")
   514  			} else {
   515  				vals, err := convValue(a.Value)
   516  				if err != nil {
   517  					return out, fmt.Errorf("error parsing label: %v", err)
   518  				}
   519  				out.Labels = vals
   520  			}
   521  		}
   522  	}
   523  	for _, d := range fld.Directives {
   524  		if d.Name == nil {
   525  			continue
   526  		}
   527  		switch d.Name.Value {
   528  		case "rev", "reverse":
   529  			if len(d.Arguments) == 0 {
   530  				out.Rev = out.Rev != true
   531  			} else {
   532  				out.Has, err = argsToHas(out.Has, d.Arguments, true, out.Labels)
   533  				if err != nil {
   534  					return
   535  				}
   536  			}
   537  		case "opt", "optional":
   538  			out.Opt = true
   539  		case "label":
   540  			// already processed
   541  		case "unnest":
   542  			out.UnNest = true
   543  		default:
   544  			return out, fmt.Errorf("unknown directive: %q", d.Name.Value)
   545  		}
   546  	}
   547  	out.Fields, out.AllFields, err = setToFields(fld.SelectionSet, out.Labels)
   548  	if err != nil {
   549  		return
   550  	}
   551  	out.Has, err = argsToHas(out.Has, fld.Arguments, false, out.Labels)
   552  	if err != nil {
   553  		return
   554  	}
   555  	return
   556  }
   557  
   558  func convValue(v ast.Value) (out []quad.Value, _ error) {
   559  	switch v := v.(type) {
   560  	case *ast.EnumValue:
   561  		s := v.Value
   562  		if len(s) > 2 && s[0] == '<' && s[len(s)-1] == '>' {
   563  			s = s[1 : len(s)-1]
   564  		}
   565  		if len(s) > 2 && s[0] == '_' && s[1] == ':' {
   566  			return []quad.Value{quad.BNode(s[2:])}, nil
   567  		}
   568  		return []quad.Value{quad.IRI(s)}, nil
   569  	case *ast.StringValue:
   570  		return []quad.Value{quad.StringToValue(v.Value)}, nil
   571  	case *ast.IntValue:
   572  		pv, _ := strconv.Atoi(v.Value)
   573  		return []quad.Value{quad.Int(pv)}, nil
   574  	case *ast.FloatValue:
   575  		pv, _ := strconv.ParseFloat(v.Value, 64)
   576  		return []quad.Value{quad.Float(pv)}, nil
   577  	case *ast.BooleanValue:
   578  		return []quad.Value{quad.Bool(v.Value)}, nil
   579  	case *ast.ListValue:
   580  		for _, sv := range v.Values {
   581  			cv, err := convValue(sv)
   582  			if err != nil {
   583  				return nil, err
   584  			} else if len(cv) != 1 {
   585  				return nil, fmt.Errorf("unexpected value array in list: %v (%d)", cv, len(cv))
   586  			}
   587  			out = append(out, cv[0])
   588  		}
   589  		return
   590  	default:
   591  		return nil, fmt.Errorf("unsupported value type: %T", v)
   592  	}
   593  }