github.com/Desuuuu/genqlient@v0.5.3/generate/genqlient_directive.go (about)

     1  package generate
     2  
import (
	"fmt"
	"sort"
	"strings"

	"github.com/vektah/gqlparser/v2/ast"
	"github.com/vektah/gqlparser/v2/parser"
)
    10  
    11  // Represents the genqlient directive, described in detail in
    12  // docs/genqlient_directive.graphql.
    13  type genqlientDirective struct {
    14  	pos       *ast.Position
    15  	Omitempty *bool
    16  	Pointer   *bool
    17  	Struct    *bool
    18  	Flatten   *bool
    19  	Bind      string
    20  	TypeName  string
    21  	// FieldDirectives contains the directives to be
    22  	// applied to specific fields via the "for" option.
    23  	// Map from type-name -> field-name -> directive.
    24  	FieldDirectives map[string]map[string]*genqlientDirective
    25  }
    26  
    27  func newGenqlientDirective(pos *ast.Position) *genqlientDirective {
    28  	return &genqlientDirective{
    29  		pos:             pos,
    30  		FieldDirectives: make(map[string]map[string]*genqlientDirective),
    31  	}
    32  }
    33  
    34  // Helper for String, returns the directive but without the @genqlient().
    35  func (dir *genqlientDirective) argsString() string {
    36  	var parts []string
    37  	if dir.Omitempty != nil {
    38  		parts = append(parts, fmt.Sprintf("omitempty: %v", *dir.Omitempty))
    39  	}
    40  	if dir.Pointer != nil {
    41  		parts = append(parts, fmt.Sprintf("pointer: %v", *dir.Pointer))
    42  	}
    43  	if dir.Struct != nil {
    44  		parts = append(parts, fmt.Sprintf("struct: %v", *dir.Struct))
    45  	}
    46  	if dir.Flatten != nil {
    47  		parts = append(parts, fmt.Sprintf("flatten: %v", *dir.Flatten))
    48  	}
    49  	if dir.Bind != "" {
    50  		parts = append(parts, fmt.Sprintf("bind: %v", dir.Bind))
    51  	}
    52  	if dir.TypeName != "" {
    53  		parts = append(parts, fmt.Sprintf("typename: %v", dir.TypeName))
    54  	}
    55  	return strings.Join(parts, ", ")
    56  }
    57  
    58  // String is useful for debugging.
    59  func (dir *genqlientDirective) String() string {
    60  	lines := []string{fmt.Sprintf("@genqlient(%s)", dir.argsString())}
    61  	for typeName, dirs := range dir.FieldDirectives {
    62  		for fieldName, fieldDir := range dirs {
    63  			lines = append(lines, fmt.Sprintf("@genqlient(for: %s.%s, %s)",
    64  				typeName, fieldName, fieldDir.argsString()))
    65  		}
    66  	}
    67  	return strings.Join(lines, "\n")
    68  }
    69  
    70  func (dir *genqlientDirective) GetPointer(def bool) bool {
    71  	if dir.Pointer == nil {
    72  		return def
    73  	}
    74  	return *dir.Pointer
    75  }
    76  
    77  func (dir *genqlientDirective) GetOmitempty() bool { return dir.Omitempty != nil && *dir.Omitempty }
    78  func (dir *genqlientDirective) GetStruct() bool    { return dir.Struct != nil && *dir.Struct }
    79  func (dir *genqlientDirective) GetFlatten() bool   { return dir.Flatten != nil && *dir.Flatten }
    80  
    81  func setBool(optionName string, dst **bool, v *ast.Value, pos *ast.Position) error {
    82  	if *dst != nil {
    83  		return errorf(pos, "conflicting values for %v", optionName)
    84  	}
    85  	ei, err := v.Value(nil) // no vars allowed
    86  	if err != nil {
    87  		return errorf(pos, "invalid boolean value %v: %v", v, err)
    88  	}
    89  	if b, ok := ei.(bool); ok {
    90  		*dst = &b
    91  		return nil
    92  	}
    93  	return errorf(pos, "expected boolean, got non-boolean value %T(%v)", ei, ei)
    94  }
    95  
    96  func setString(optionName string, dst *string, v *ast.Value, pos *ast.Position) error {
    97  	if *dst != "" {
    98  		return errorf(pos, "conflicting values for %v", optionName)
    99  	}
   100  	ei, err := v.Value(nil) // no vars allowed
   101  	if err != nil {
   102  		return errorf(pos, "invalid string value %v: %v", v, err)
   103  	}
   104  	if b, ok := ei.(string); ok {
   105  		*dst = b
   106  		return nil
   107  	}
   108  	return errorf(pos, "expected string, got non-string value %T(%v)", ei, ei)
   109  }
   110  
   111  // add adds to this genqlientDirective struct the settings from then given
   112  // GraphQL directive.
   113  //
   114  // If there are multiple genqlient directives are applied to the same node,
   115  // e.g.
   116  //	# @genqlient(...)
   117  //	# @genqlient(...)
   118  // add will be called several times.  In this case, conflicts between the
   119  // options are an error.
   120  func (dir *genqlientDirective) add(graphQLDirective *ast.Directive, pos *ast.Position) error {
   121  	if graphQLDirective.Name != "genqlient" {
   122  		// Actually we just won't get here; we only get here if the line starts
   123  		// with "# @genqlient", unless there's some sort of bug.
   124  		return errorf(pos, "the only valid comment-directive is @genqlient, got %v", graphQLDirective.Name)
   125  	}
   126  
   127  	// First, see if this directive has a "for" option;
   128  	// if it does, the rest of our work will operate on the
   129  	// appropriate place in FieldDirectives.
   130  	var err error
   131  	forField := ""
   132  	for _, arg := range graphQLDirective.Arguments {
   133  		if arg.Name == "for" {
   134  			if forField != "" {
   135  				return errorf(pos, `@genqlient directive had "for:" twice`)
   136  			}
   137  			err = setString("for", &forField, arg.Value, pos)
   138  			if err != nil {
   139  				return err
   140  			}
   141  		}
   142  	}
   143  	if forField != "" {
   144  		forParts := strings.Split(forField, ".")
   145  		if len(forParts) != 2 {
   146  			return errorf(pos, `for must be of the form "MyType.myField"`)
   147  		}
   148  		typeName, fieldName := forParts[0], forParts[1]
   149  
   150  		fieldDir := newGenqlientDirective(pos)
   151  		if dir.FieldDirectives[typeName] == nil {
   152  			dir.FieldDirectives[typeName] = make(map[string]*genqlientDirective)
   153  		}
   154  		dir.FieldDirectives[typeName][fieldName] = fieldDir
   155  
   156  		// Now, the rest of the function will operate on fieldDir.
   157  		dir = fieldDir
   158  	}
   159  
   160  	// Now parse the rest of the arguments.
   161  	for _, arg := range graphQLDirective.Arguments {
   162  		switch arg.Name {
   163  		// TODO(benkraft): Use reflect and struct tags?
   164  		case "omitempty":
   165  			err = setBool("omitempty", &dir.Omitempty, arg.Value, pos)
   166  		case "pointer":
   167  			err = setBool("pointer", &dir.Pointer, arg.Value, pos)
   168  		case "struct":
   169  			err = setBool("struct", &dir.Struct, arg.Value, pos)
   170  		case "flatten":
   171  			err = setBool("flatten", &dir.Flatten, arg.Value, pos)
   172  		case "bind":
   173  			err = setString("bind", &dir.Bind, arg.Value, pos)
   174  		case "typename":
   175  			err = setString("typename", &dir.TypeName, arg.Value, pos)
   176  		case "for":
   177  			// handled above
   178  		default:
   179  			return errorf(pos, "unknown argument %v for @genqlient", arg.Name)
   180  		}
   181  		if err != nil {
   182  			return err
   183  		}
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  func (dir *genqlientDirective) validate(node interface{}, schema *ast.Schema) error {
   190  	// TODO(benkraft): This function has a lot of duplicated checks, figure out
   191  	// how to organize them better to avoid the duplication.
   192  	for typeName, byField := range dir.FieldDirectives {
   193  		typ, ok := schema.Types[typeName]
   194  		if !ok {
   195  			return errorf(dir.pos, `for got invalid type-name "%s"`, typeName)
   196  		}
   197  		for fieldName, fieldDir := range byField {
   198  			var field *ast.FieldDefinition
   199  			for _, typeField := range typ.Fields {
   200  				if typeField.Name == fieldName {
   201  					field = typeField
   202  					break
   203  				}
   204  			}
   205  			if field == nil {
   206  				return errorf(fieldDir.pos,
   207  					`for got invalid field-name "%s" for type "%s"`,
   208  					fieldName, typeName)
   209  			}
   210  
   211  			// All options except struct and flatten potentially apply.  (I
   212  			// mean in theory you could apply them here, but since they require
   213  			// per-use validation, it would be a bit tricky, and the use case
   214  			// is not clear.)
   215  			if fieldDir.Struct != nil || fieldDir.Flatten != nil {
   216  				return errorf(fieldDir.pos, "struct and flatten can't be used via for")
   217  			}
   218  
   219  			if fieldDir.Omitempty != nil && field.Type.NonNull {
   220  				return errorf(fieldDir.pos, "omitempty may only be used on optional arguments")
   221  			}
   222  
   223  			if fieldDir.TypeName != "" && fieldDir.Bind != "" && fieldDir.Bind != "-" {
   224  				return errorf(fieldDir.pos, "typename and bind may not be used together")
   225  			}
   226  		}
   227  	}
   228  
   229  	switch node := node.(type) {
   230  	case *ast.OperationDefinition:
   231  		if dir.Bind != "" {
   232  			return errorf(dir.pos, "bind may not be applied to the entire operation")
   233  		}
   234  
   235  		// Anything else is valid on the entire operation; it will just apply
   236  		// to whatever it is relevant to.
   237  		return nil
   238  	case *ast.FragmentDefinition:
   239  		if dir.Bind != "" {
   240  			// TODO(benkraft): Implement this if people find it useful.
   241  			return errorf(dir.pos, "bind is not implemented for named fragments")
   242  		}
   243  
   244  		if dir.Struct != nil {
   245  			return errorf(dir.pos, "struct is only applicable to fields, not frragment-definitions")
   246  		}
   247  
   248  		// Like operations, anything else will just apply to the entire
   249  		// fragment.
   250  		return nil
   251  	case *ast.VariableDefinition:
   252  		if dir.Omitempty != nil && node.Type.NonNull {
   253  			return errorf(dir.pos, "omitempty may only be used on optional arguments")
   254  		}
   255  
   256  		if dir.Struct != nil {
   257  			return errorf(dir.pos, "struct is only applicable to fields, not variable-definitions")
   258  		}
   259  
   260  		if dir.Flatten != nil {
   261  			return errorf(dir.pos, "flatten is only applicable to fields, not variable-definitions")
   262  		}
   263  
   264  		if len(dir.FieldDirectives) > 0 {
   265  			return errorf(dir.pos, "for is only applicable to operations and arguments")
   266  		}
   267  
   268  		if dir.TypeName != "" && dir.Bind != "" && dir.Bind != "-" {
   269  			return errorf(dir.pos, "typename and bind may not be used together")
   270  		}
   271  
   272  		return nil
   273  	case *ast.Field:
   274  		if dir.Omitempty != nil {
   275  			return errorf(dir.pos, "omitempty is not applicable to variables, not fields")
   276  		}
   277  
   278  		typ := schema.Types[node.Definition.Type.Name()]
   279  		if dir.Struct != nil {
   280  			if err := validateStructOption(typ, node.SelectionSet, dir.pos); err != nil {
   281  				return err
   282  			}
   283  		}
   284  
   285  		if dir.Flatten != nil {
   286  			if _, err := validateFlattenOption(typ, node.SelectionSet, dir.pos); err != nil {
   287  				return err
   288  			}
   289  		}
   290  
   291  		if len(dir.FieldDirectives) > 0 {
   292  			return errorf(dir.pos, "for is only applicable to operations and arguments")
   293  		}
   294  
   295  		if dir.TypeName != "" && dir.Bind != "" && dir.Bind != "-" {
   296  			return errorf(dir.pos, "typename and bind may not be used together")
   297  		}
   298  
   299  		return nil
   300  	default:
   301  		return errorf(dir.pos, "invalid @genqlient directive location: %T", node)
   302  	}
   303  }
   304  
   305  func validateStructOption(
   306  	typ *ast.Definition,
   307  	selectionSet ast.SelectionSet,
   308  	pos *ast.Position,
   309  ) error {
   310  	if typ.Kind != ast.Interface && typ.Kind != ast.Union {
   311  		return errorf(pos, "struct is only applicable to interface-typed fields")
   312  	}
   313  
   314  	// Make sure that all the requested fields apply to the interface itself
   315  	// (not just certain implementations).
   316  	for _, selection := range selectionSet {
   317  		switch selection.(type) {
   318  		case *ast.Field:
   319  			// fields are fine.
   320  		case *ast.InlineFragment, *ast.FragmentSpread:
   321  			// Fragments aren't allowed. In principle we could allow them under
   322  			// the condition that the fragment applies to the whole interface
   323  			// (not just one implementation; and so on recursively), and for
   324  			// fragment spreads additionally that the fragment has the same
   325  			// option applied to it, but it seems more trouble than it's worth
   326  			// right now.
   327  			return errorf(pos, "struct is not allowed for types with fragments")
   328  		}
   329  	}
   330  	return nil
   331  }
   332  
   333  func validateFlattenOption(
   334  	typ *ast.Definition,
   335  	selectionSet ast.SelectionSet,
   336  	pos *ast.Position,
   337  ) (index int, err error) {
   338  	index = -1
   339  	if len(selectionSet) == 0 {
   340  		return -1, errorf(pos, "flatten is not allowed for leaf fields")
   341  	}
   342  
   343  	for i, selection := range selectionSet {
   344  		switch selection := selection.(type) {
   345  		case *ast.Field:
   346  			// If the field is auto-added __typename, ignore it for flattening
   347  			// purposes.
   348  			if selection.Name == "__typename" && selection.Position == nil {
   349  				continue
   350  			}
   351  			// Type-wise, it's no harder to implement flatten for fields, but
   352  			// it requires new logic in UnmarshalJSON.  We can add that if it
   353  			// proves useful relative to its complexity.
   354  			return -1, errorf(pos, "flatten is not yet supported for fields (only fragment spreads)")
   355  
   356  		case *ast.InlineFragment:
   357  			// Inline fragments aren't allowed. In principle there's nothing
   358  			// stopping us from allowing them (under the same type-match
   359  			// conditions as fragment spreads), but there's little value to it.
   360  			return -1, errorf(pos, "flatten is not allowed for selections with inline fragments")
   361  
   362  		case *ast.FragmentSpread:
   363  			if index != -1 {
   364  				return -1, errorf(pos, "flatten is not allowed for fields with multiple selections")
   365  			} else if !fragmentMatches(typ, selection.Definition.Definition) {
   366  				// We don't let you flatten
   367  				//  field { # type: FieldType
   368  				//		...Fragment # type: FragmentType
   369  				//	}
   370  				// unless FragmentType implements FieldType, because otherwise
   371  				// what do we do if we get back a type that doesn't implement
   372  				// FragmentType?
   373  				return -1, errorf(pos,
   374  					"flatten is not allowed for fields with fragment-spreads "+
   375  						"unless the field-type implements the fragment-type; "+
   376  						"field-type %s does not implement fragment-type %s",
   377  					typ.Name, selection.Definition.Definition.Name)
   378  			}
   379  			index = i
   380  		}
   381  	}
   382  	return index, nil
   383  }
   384  
   385  func fillDefaultBool(target **bool, defaults ...*bool) {
   386  	if *target != nil {
   387  		return
   388  	}
   389  
   390  	for _, val := range defaults {
   391  		if val != nil {
   392  			*target = val
   393  			return
   394  		}
   395  	}
   396  }
   397  
   398  func fillDefaultString(target *string, defaults ...string) {
   399  	if *target != "" {
   400  		return
   401  	}
   402  
   403  	for _, val := range defaults {
   404  		if val != "" {
   405  			*target = val
   406  			return
   407  		}
   408  	}
   409  }
   410  
   411  // merge updates the receiver, which is a directive applied to some node, with
   412  // the information from the directive applied to the fragment or operation
   413  // containing that node.  (The update is in-place.)
   414  //
   415  // Note this has slightly different semantics than .add(), see inline for
   416  // details.
   417  //
   418  // parent is as described in parsePrecedingComment.  operationDirective is the
   419  // directive applied to this operation or fragment.
   420  func (dir *genqlientDirective) mergeOperationDirective(
   421  	node interface{},
   422  	parentIfInputField *ast.Definition,
   423  	operationDirective *genqlientDirective,
   424  ) {
   425  	// We'll set forField to the `@genqlient(for: "<this field>", ...)`
   426  	// directive from our operation/fragment, if any.
   427  	var forField *genqlientDirective
   428  	switch field := node.(type) {
   429  	case *ast.Field: // query field
   430  		typeName := field.ObjectDefinition.Name
   431  		forField = operationDirective.FieldDirectives[typeName][field.Name]
   432  	case *ast.FieldDefinition: // input-type field
   433  		forField = operationDirective.FieldDirectives[parentIfInputField.Name][field.Name]
   434  	}
   435  	// Just to simplify nil-checking in the code below:
   436  	if forField == nil {
   437  		forField = newGenqlientDirective(nil)
   438  	}
   439  
   440  	// Now fill defaults; in general local directive wins over the "for" field
   441  	// directive wins over the operation directive.
   442  	fillDefaultBool(&dir.Omitempty, forField.Omitempty, operationDirective.Omitempty)
   443  	fillDefaultBool(&dir.Pointer, forField.Pointer, operationDirective.Pointer)
   444  	// struct and flatten aren't settable via "for".
   445  	fillDefaultBool(&dir.Struct, operationDirective.Struct)
   446  	fillDefaultBool(&dir.Flatten, operationDirective.Flatten)
   447  	fillDefaultString(&dir.Bind, forField.Bind, operationDirective.Bind)
   448  	// typename isn't settable on the operation (when set there it replies to
   449  	// the response-type).
   450  	fillDefaultString(&dir.TypeName, forField.TypeName)
   451  }
   452  
   453  // parsePrecedingComment looks at the comment right before this node, and
   454  // returns the genqlient directive applied to it (or an empty one if there is
   455  // none), the remaining human-readable comment (or "" if there is none), and an
   456  // error if the directive is invalid.
   457  //
   458  // queryOptions are the options to be applied to this entire query (or
   459  // fragment); the local options will be merged into those.  It should be nil if
   460  // we are parsing the directive on the entire query.
   461  //
   462  // parentIfInputField need only be set if node is an input-type field; it
   463  // should be the type containing this field.  (We can get this from gqlparser
   464  // in other cases, but not input-type fields.)
   465  func (g *generator) parsePrecedingComment(
   466  	node interface{},
   467  	parentIfInputField *ast.Definition,
   468  	pos *ast.Position,
   469  	queryOptions *genqlientDirective,
   470  ) (comment string, directive *genqlientDirective, err error) {
   471  	directive = newGenqlientDirective(pos)
   472  	hasDirective := false
   473  
   474  	// For directives on genqlient-generated nodes, we don't actually need to
   475  	// parse anything.  (But we do need to merge below.)
   476  	var commentLines []string
   477  	if pos != nil && pos.Src != nil {
   478  		sourceLines := strings.Split(pos.Src.Input, "\n")
   479  		for i := pos.Line - 1; i > 0; i-- {
   480  			line := strings.TrimSpace(sourceLines[i-1])
   481  			trimmed := strings.TrimSpace(strings.TrimPrefix(line, "#"))
   482  			if strings.HasPrefix(line, "# @genqlient") {
   483  				hasDirective = true
   484  				var graphQLDirective *ast.Directive
   485  				graphQLDirective, err = parseDirective(trimmed, pos)
   486  				if err != nil {
   487  					return "", nil, err
   488  				}
   489  				err = directive.add(graphQLDirective, pos)
   490  				if err != nil {
   491  					return "", nil, err
   492  				}
   493  			} else if strings.HasPrefix(line, "#") {
   494  				commentLines = append(commentLines, trimmed)
   495  			} else {
   496  				break
   497  			}
   498  		}
   499  	}
   500  
   501  	if hasDirective { // (else directive is empty)
   502  		err = directive.validate(node, g.schema)
   503  		if err != nil {
   504  			return "", nil, err
   505  		}
   506  	}
   507  
   508  	if queryOptions != nil {
   509  		// If we are part of an operation/fragment, merge its options in.
   510  		directive.mergeOperationDirective(node, parentIfInputField, queryOptions)
   511  
   512  		// TODO(benkraft): Really we should do all the validation after
   513  		// merging, probably?  But this is the only check that can fail only
   514  		// after merging, and it's a bit tricky because the "does not apply"
   515  		// checks may need to happen before merging so we know where the
   516  		// directive "is".
   517  		if directive.TypeName != "" && directive.Bind != "" && directive.Bind != "-" {
   518  			return "", nil, errorf(directive.pos, "typename and bind may not be used together")
   519  		}
   520  	}
   521  
   522  	reverse(commentLines)
   523  
   524  	return strings.TrimSpace(strings.Join(commentLines, "\n")), directive, nil
   525  }
   526  
   527  func parseDirective(line string, pos *ast.Position) (*ast.Directive, error) {
   528  	// HACK: parse the "directive" by making a fake query containing it.
   529  	fakeQuery := fmt.Sprintf("query %v { field }", line)
   530  	doc, err := parser.ParseQuery(&ast.Source{Input: fakeQuery})
   531  	if err != nil {
   532  		return nil, errorf(pos, "invalid genqlient directive: %v", err)
   533  	}
   534  	return doc.Operations[0].Directives[0], nil
   535  }