github.com/dgraph-io/dgraph@v1.2.8/graphql/resolve/query_rewriter.go (about)

     1  /*
     2   * Copyright 2019 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package resolve
    18  
    19  import (
    20  	"fmt"
    21  	"sort"
    22  	"strconv"
    23  
    24  	"github.com/dgraph-io/dgraph/gql"
    25  	"github.com/dgraph-io/dgraph/graphql/schema"
    26  	"github.com/dgraph-io/dgraph/protos/pb"
    27  	"github.com/pkg/errors"
    28  )
    29  
    30  type queryRewriter struct{}
    31  
    32  // NewQueryRewriter returns a new QueryRewriter.
    33  func NewQueryRewriter() QueryRewriter {
    34  	return &queryRewriter{}
    35  }
    36  
    37  // Rewrite rewrites a GraphQL query into a Dgraph GraphQuery.
    38  func (qr *queryRewriter) Rewrite(gqlQuery schema.Query) (*gql.GraphQuery, error) {
    39  
    40  	switch gqlQuery.QueryType() {
    41  	case schema.GetQuery:
    42  
    43  		// TODO: The only error that can occur in query rewriting is if an ID argument
    44  		// can't be parsed as a uid: e.g. the query was something like:
    45  		//
    46  		// getT(id: "HI") { ... }
    47  		//
    48  		// But that's not a rewriting error!  It should be caught by validation
    49  		// way up when the query first comes in.  All other possible problems with
    50  		// the query are caught by validation.
    51  		// ATM, I'm not sure how to hook into the GraphQL validator to get that to happen
    52  		xid, uid, err := gqlQuery.IDArgValue()
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  
    57  		dgQuery := rewriteAsGet(gqlQuery, uid, xid)
    58  		addTypeFilter(dgQuery, gqlQuery.Type())
    59  
    60  		return dgQuery, nil
    61  
    62  	case schema.FilterQuery:
    63  		return rewriteAsQuery(gqlQuery), nil
    64  	default:
    65  		return nil, errors.Errorf("unimplemented query type %s", gqlQuery.QueryType())
    66  	}
    67  }
    68  
    69  func intersection(a, b []uint64) []uint64 {
    70  	m := make(map[uint64]bool)
    71  	var c []uint64
    72  
    73  	for _, item := range a {
    74  		m[item] = true
    75  	}
    76  
    77  	for _, item := range b {
    78  		if _, ok := m[item]; ok {
    79  			c = append(c, item)
    80  		}
    81  	}
    82  
    83  	return c
    84  }
    85  
    86  // addUID adds UID for every node that we query. Otherwise we can't tell the
    87  // difference in a query result between a node that's missing and a node that's
    88  // missing a single value.  E.g. if we are asking for an Author and only the
    89  // 'text' of all their posts e.g. getAuthor(id: 0x123) { posts { text } }
    90  // If the author has 10 posts but three of them have a title, but no text,
    91  // then Dgraph would just return 7 posts.  And we'd have no way of knowing if
    92  // there's only 7 posts, or if there's more that are missing 'text'.
    93  // But, for GraphQL, we want to know about those missing values.
    94  func addUID(dgQuery *gql.GraphQuery) {
    95  	if len(dgQuery.Children) == 0 {
    96  		return
    97  	}
    98  	for _, c := range dgQuery.Children {
    99  		addUID(c)
   100  	}
   101  
   102  	uidChild := &gql.GraphQuery{
   103  		Attr:  "uid",
   104  		Alias: "dgraph.uid",
   105  	}
   106  	dgQuery.Children = append(dgQuery.Children, uidChild)
   107  }
   108  
   109  func rewriteAsQueryByIds(field schema.Field, uids []uint64) *gql.GraphQuery {
   110  	dgQuery := &gql.GraphQuery{
   111  		Attr: field.ResponseName(),
   112  		Func: &gql.Function{
   113  			Name: "uid",
   114  			UID:  uids,
   115  		},
   116  	}
   117  
   118  	if ids := idFilter(field, field.Type().IDField()); ids != nil {
   119  		addUIDFunc(dgQuery, intersection(ids, uids))
   120  	}
   121  
   122  	addArgumentsToField(dgQuery, field)
   123  	return dgQuery
   124  }
   125  
   126  // addArgumentsToField adds various different arguments to a field, such as
   127  // filter, order, pagination and selection set.
   128  func addArgumentsToField(dgQuery *gql.GraphQuery, field schema.Field) {
   129  	filter, _ := field.ArgValue("filter").(map[string]interface{})
   130  	addFilter(dgQuery, field.Type(), filter)
   131  	addOrder(dgQuery, field)
   132  	addPagination(dgQuery, field)
   133  	addSelectionSetFrom(dgQuery, field)
   134  	addUID(dgQuery)
   135  }
   136  
   137  func rewriteAsGet(field schema.Field, uid uint64, xid *string) *gql.GraphQuery {
   138  	if xid == nil {
   139  		return rewriteAsQueryByIds(field, []uint64{uid})
   140  	}
   141  
   142  	xidArgName := field.XIDArg()
   143  	eqXidFunc := &gql.Function{
   144  		Name: "eq",
   145  		Args: []gql.Arg{
   146  			{Value: xidArgName},
   147  			{Value: maybeQuoteArg("eq", *xid)},
   148  		},
   149  	}
   150  
   151  	var dgQuery *gql.GraphQuery
   152  	if uid > 0 {
   153  		dgQuery = &gql.GraphQuery{
   154  			Attr: field.ResponseName(),
   155  			Func: &gql.Function{
   156  				Name: "uid",
   157  				UID:  []uint64{uid},
   158  			},
   159  		}
   160  		dgQuery.Filter = &gql.FilterTree{
   161  			Func: eqXidFunc,
   162  		}
   163  
   164  	} else {
   165  		dgQuery = &gql.GraphQuery{
   166  			Attr: field.ResponseName(),
   167  			Func: eqXidFunc,
   168  		}
   169  	}
   170  	addSelectionSetFrom(dgQuery, field)
   171  	addUID(dgQuery)
   172  	return dgQuery
   173  }
   174  
   175  func rewriteAsQuery(field schema.Field) *gql.GraphQuery {
   176  	dgQuery := &gql.GraphQuery{
   177  		Attr: field.ResponseName(),
   178  	}
   179  
   180  	if ids := idFilter(field, field.Type().IDField()); ids != nil {
   181  		addUIDFunc(dgQuery, ids)
   182  	} else {
   183  		addTypeFunc(dgQuery, field.Type().DgraphName())
   184  	}
   185  
   186  	addArgumentsToField(dgQuery, field)
   187  	return dgQuery
   188  }
   189  
   190  func addTypeFilter(q *gql.GraphQuery, typ schema.Type) {
   191  	thisFilter := &gql.FilterTree{
   192  		Func: &gql.Function{
   193  			Name: "type",
   194  			Args: []gql.Arg{{Value: typ.DgraphName()}},
   195  		},
   196  	}
   197  
   198  	if q.Filter == nil {
   199  		q.Filter = thisFilter
   200  	} else {
   201  		q.Filter = &gql.FilterTree{
   202  			Op:    "and",
   203  			Child: []*gql.FilterTree{q.Filter, thisFilter},
   204  		}
   205  	}
   206  }
   207  
   208  func addUIDFunc(q *gql.GraphQuery, uids []uint64) {
   209  	q.Func = &gql.Function{
   210  		Name: "uid",
   211  		UID:  uids,
   212  	}
   213  }
   214  
   215  func addTypeFunc(q *gql.GraphQuery, typ string) {
   216  	q.Func = &gql.Function{
   217  		Name: "type",
   218  		Args: []gql.Arg{{Value: typ}},
   219  	}
   220  
   221  }
   222  
   223  func addSelectionSetFrom(q *gql.GraphQuery, field schema.Field) {
   224  	// Only add dgraph.type as a child if this field is an interface type and has some children.
   225  	// dgraph.type would later be used in completeObject as different objects in the resulting
   226  	// JSON would return different fields based on their concrete type.
   227  	if field.InterfaceType() && len(field.SelectionSet()) > 0 {
   228  		q.Children = append(q.Children, &gql.GraphQuery{
   229  			Attr: "dgraph.type",
   230  		})
   231  	}
   232  	for _, f := range field.SelectionSet() {
   233  		// We skip typename because we can generate the information from schema or
   234  		// dgraph.type depending upon if the type is interface or not. For interface type
   235  		// we always query dgraph.type and can pick up the value from there.
   236  		if f.Skip() || !f.Include() || f.Name() == schema.Typename {
   237  			continue
   238  		}
   239  
   240  		child := &gql.GraphQuery{}
   241  
   242  		if f.Alias() != "" {
   243  			child.Alias = f.Alias()
   244  		} else {
   245  			child.Alias = f.Name()
   246  		}
   247  
   248  		if f.Type().Name() == schema.IDType {
   249  			child.Attr = "uid"
   250  		} else {
   251  			child.Attr = f.DgraphPredicate()
   252  		}
   253  
   254  		filter, _ := f.ArgValue("filter").(map[string]interface{})
   255  		addFilter(child, f.Type(), filter)
   256  		addOrder(child, f)
   257  		addPagination(child, f)
   258  
   259  		addSelectionSetFrom(child, f)
   260  
   261  		q.Children = append(q.Children, child)
   262  	}
   263  }
   264  
   265  func addOrder(q *gql.GraphQuery, field schema.Field) {
   266  	orderArg := field.ArgValue("order")
   267  	order, ok := orderArg.(map[string]interface{})
   268  	for ok {
   269  		ascArg := order["asc"]
   270  		descArg := order["desc"]
   271  		thenArg := order["then"]
   272  
   273  		if asc, ok := ascArg.(string); ok {
   274  			q.Order = append(q.Order,
   275  				&pb.Order{Attr: field.Type().DgraphPredicate(asc)})
   276  		} else if desc, ok := descArg.(string); ok {
   277  			q.Order = append(q.Order,
   278  				&pb.Order{Attr: field.Type().DgraphPredicate(desc), Desc: true})
   279  		}
   280  
   281  		order, ok = thenArg.(map[string]interface{})
   282  	}
   283  }
   284  
   285  func addPagination(q *gql.GraphQuery, field schema.Field) {
   286  	q.Args = make(map[string]string)
   287  
   288  	first := field.ArgValue("first")
   289  	if first != nil {
   290  		q.Args["first"] = fmt.Sprintf("%v", first)
   291  	}
   292  
   293  	offset := field.ArgValue("offset")
   294  	if offset != nil {
   295  		q.Args["offset"] = fmt.Sprintf("%v", offset)
   296  	}
   297  }
   298  
   299  func convertIDs(idsSlice []interface{}) []uint64 {
   300  	ids := make([]uint64, 0, len(idsSlice))
   301  	for _, id := range idsSlice {
   302  		uid, err := strconv.ParseUint(id.(string), 0, 64)
   303  		if err != nil {
   304  			// Skip sending the is part of the query to Dgraph.
   305  			continue
   306  		}
   307  		ids = append(ids, uid)
   308  	}
   309  	return ids
   310  }
   311  
   312  func idFilter(field schema.Field, idField schema.FieldDefinition) []uint64 {
   313  	filter, ok := field.ArgValue("filter").(map[string]interface{})
   314  	if !ok || idField == nil {
   315  		return nil
   316  	}
   317  
   318  	idsFilter := filter[idField.Name()]
   319  	if idsFilter == nil {
   320  		return nil
   321  	}
   322  	idsSlice := idsFilter.([]interface{})
   323  	return convertIDs(idsSlice)
   324  }
   325  
   326  func addFilter(q *gql.GraphQuery, typ schema.Type, filter map[string]interface{}) {
   327  	if len(filter) == 0 {
   328  		return
   329  	}
   330  
   331  	// There are two cases here.
   332  	// 1. It could be the case of a filter at root.  In this case we would have added a uid
   333  	// function at root. Lets delete the ids key so that it isn't added in the filter.
   334  	// Also, we need to add a dgraph.type filter.
   335  	// 2. This could be a deep filter. In that case we don't need to do anything special.
   336  	idField := typ.IDField()
   337  	idName := ""
   338  	if idField != nil {
   339  		idName = idField.Name()
   340  	}
   341  
   342  	_, hasIDsFilter := filter[idName]
   343  	filterAtRoot := hasIDsFilter && q.Func != nil && q.Func.Name == "uid"
   344  	if filterAtRoot {
   345  		// If id was present as a filter,
   346  		delete(filter, idName)
   347  	}
   348  	q.Filter = buildFilter(typ, filter)
   349  	if filterAtRoot {
   350  		addTypeFilter(q, typ)
   351  	}
   352  }
   353  
   354  // buildFilter builds a Dgraph gql.FilterTree from a GraphQL 'filter' arg.
   355  //
   356  // All the 'filter' args built by the GraphQL layer look like
   357  // filter: { title: { anyofterms: "GraphQL" }, ... }
   358  // or
   359  // filter: { title: { anyofterms: "GraphQL" }, isPublished: true, ... }
   360  // or
   361  // filter: { title: { anyofterms: "GraphQL" }, and: { not: { ... } } }
   362  // etc
   363  //
   364  // typ is the GraphQL type we are filtering on, and is needed to turn for example
   365  // title (the GraphQL field) into Post.title (to Dgraph predicate).
   366  //
   367  // buildFilter turns any one filter object into a conjunction
   368  // eg:
   369  // filter: { title: { anyofterms: "GraphQL" }, isPublished: true }
   370  // into:
   371  // @filter(anyofterms(Post.title, "GraphQL") AND eq(Post.isPublished, true))
   372  //
   373  // Filters with `or:` and `not:` get translated to Dgraph OR and NOT.
   374  //
   375  // TODO: There's cases that don't make much sense like
   376  // filter: { or: { title: { anyofterms: "GraphQL" } } }
   377  // ATM those will probably generate junk that might cause a Dgraph error.  And
   378  // bubble back to the user as a GraphQL error when the query fails. Really,
   379  // they should fail query validation and never get here.
   380  func buildFilter(typ schema.Type, filter map[string]interface{}) *gql.FilterTree {
   381  
   382  	var ands []*gql.FilterTree
   383  	var or *gql.FilterTree
   384  
   385  	// Get a stable ordering so we generate the same thing each time.
   386  	var keys []string
   387  	for key := range filter {
   388  		keys = append(keys, key)
   389  	}
   390  	sort.Strings(keys)
   391  
   392  	// Each key in filter is either "and", "or", "not" or the field name it
   393  	// applies to such as "title" in: `title: { anyofterms: "GraphQL" }``
   394  	for _, field := range keys {
   395  		switch field {
   396  
   397  		// In 'and', 'or' and 'not' cases, filter[field] must be a map[string]interface{}
   398  		// or it would have failed GraphQL validation - e.g. 'filter: { and: 10 }'
   399  		// would have failed validation.
   400  
   401  		case "and":
   402  			// title: { anyofterms: "GraphQL" }, and: { ... }
   403  			//                       we are here ^^
   404  			// ->
   405  			// @filter(anyofterms(Post.title, "GraphQL") AND ... )
   406  			ft := buildFilter(typ, filter[field].(map[string]interface{}))
   407  			ands = append(ands, ft)
   408  		case "or":
   409  			// title: { anyofterms: "GraphQL" }, or: { ... }
   410  			//                       we are here ^^
   411  			// ->
   412  			// @filter(anyofterms(Post.title, "GraphQL") OR ... )
   413  			or = buildFilter(typ, filter[field].(map[string]interface{}))
   414  		case "not":
   415  			// title: { anyofterms: "GraphQL" }, not: { isPublished: true}
   416  			//                       we are here ^^
   417  			// ->
   418  			// @filter(anyofterms(Post.title, "GraphQL") AND NOT eq(Post.isPublished, true))
   419  			not := buildFilter(typ, filter[field].(map[string]interface{}))
   420  			ands = append(ands,
   421  				&gql.FilterTree{
   422  					Op:    "not",
   423  					Child: []*gql.FilterTree{not},
   424  				})
   425  		default:
   426  			// It's a base case like:
   427  			// title: { anyofterms: "GraphQL" } ->  anyofterms(Post.title: "GraphQL")
   428  
   429  			switch dgFunc := filter[field].(type) {
   430  			case map[string]interface{}:
   431  				// title: { anyofterms: "GraphQL" } ->  anyofterms(Post.title, "GraphQL")
   432  				// OR
   433  				// numLikes: { le: 10 } -> le(Post.numLikes, 10)
   434  				fn, val := first(dgFunc)
   435  				ands = append(ands, &gql.FilterTree{
   436  					Func: &gql.Function{
   437  						Name: fn,
   438  						Args: []gql.Arg{
   439  							{Value: typ.DgraphPredicate(field)},
   440  							{Value: maybeQuoteArg(fn, val)},
   441  						},
   442  					},
   443  				})
   444  			case []interface{}:
   445  				// ids: [ 0x123, 0x124 ] -> uid(0x123, 0x124)
   446  				ids := convertIDs(dgFunc)
   447  				ands = append(ands, &gql.FilterTree{
   448  					Func: &gql.Function{
   449  						Name: "uid",
   450  						UID:  ids,
   451  					},
   452  				})
   453  			case interface{}:
   454  				// isPublished: true -> eq(Post.isPublished, true)
   455  				// OR an enum case
   456  				// postType: Question -> eq(Post.postType, "Question")
   457  				fn := "eq"
   458  				ands = append(ands, &gql.FilterTree{
   459  					Func: &gql.Function{
   460  						Name: fn,
   461  						Args: []gql.Arg{
   462  							{Value: typ.DgraphPredicate(field)},
   463  							{Value: fmt.Sprintf("%v", dgFunc)},
   464  						},
   465  					},
   466  				})
   467  			}
   468  		}
   469  	}
   470  
   471  	var andFt *gql.FilterTree
   472  	if len(ands) == 1 {
   473  		andFt = ands[0]
   474  	} else if len(ands) > 1 {
   475  		andFt = &gql.FilterTree{
   476  			Op:    "and",
   477  			Child: ands,
   478  		}
   479  	}
   480  
   481  	if or == nil {
   482  		return andFt
   483  	}
   484  
   485  	return &gql.FilterTree{
   486  		Op:    "or",
   487  		Child: []*gql.FilterTree{andFt, or},
   488  	}
   489  }
   490  
   491  func maybeQuoteArg(fn string, arg interface{}) string {
   492  	switch arg := arg.(type) {
   493  	case string: // dateTime also parsed as string
   494  		if fn == "regexp" {
   495  			return arg
   496  		}
   497  		return fmt.Sprintf("%q", arg)
   498  	default:
   499  		return fmt.Sprintf("%v", arg)
   500  	}
   501  }
   502  
   503  // fst returns the first element it finds in a map - we bump into lots of one-element
   504  // maps like { "anyofterms": "GraphQL" }.  fst helps extract that single mapping.
   505  func first(aMap map[string]interface{}) (string, interface{}) {
   506  	for key, val := range aMap {
   507  		return key, val
   508  	}
   509  	return "", nil
   510  }