github.com/humans-group/gqlgen@v0.7.2/codegen/object.go (about)

     1  package codegen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"text/template"
     9  	"unicode"
    10  
    11  	"github.com/vektah/gqlparser/ast"
    12  )
    13  
    // GoFieldType records how a GraphQL field is bound on the Go side: to a
// method, to a variable (struct field), or not recorded at all.
type GoFieldType int

// Binding kinds for Field.GoFieldType. The zero value is GoFieldUndefined.
const (
	GoFieldUndefined GoFieldType = 0 // no binding kind recorded
	GoFieldMethod    GoFieldType = 1 // bound to a Go method
	GoFieldVariable  GoFieldType = 2 // bound to a Go variable / struct field
)
    21  
    // Object is the code-generation model for one GraphQL object type.
type Object struct {
	*NamedType

	// Fields of the object; each Field links back here via Field.Object.
	Fields []Field
	// Extra GraphQL type names emitted after the object's own name by
	// Implementors (presumably the interfaces/unions it satisfies — confirm).
	Satisfies []string
	// NOTE(review): not read in this file; presumably the interface types
	// this object implements.
	Implements []*NamedType
	// NOTE(review): not read in this file; presumably the Go resolver
	// interface generated for this object.
	ResolverInterface *Ref
	// Root objects do not pass an "obj" parent argument to their resolvers.
	Root bool
	// When set, none of the object's fields report IsConcurrent.
	DisableConcurrency bool
	// Stream objects' resolvers return a receive-only channel of results.
	Stream bool
}
    33  
    // Field is the code-generation model for one field of an Object.
type Field struct {
	*Type                            // Type information for the field's value (Signature, Modifiers, IsScalar, ASTType, ...)
	Description      string          // Description of a field
	GQLName          string          // The name of the field in graphql
	GoFieldType      GoFieldType     // How the field is bound in go (method, variable or undefined)
	GoReceiverName   string          // The name of method & var receiver in go, if any
	GoFieldName      string          // The name of the method or var in go, if any; empty means the field is served by a resolver
	Args             []FieldArgument // A list of arguments to be passed to this field
	// Whether a Resolver method should be emitted for this field.
	// NOTE(review): not read in this file; presumably forces a resolver even
	// when a Go binding exists — confirm against the binder.
	ForceResolver    bool
	MethodHasContext bool            // If this is bound to a go method, does the method also take a context
	// Error-return flag for a bound Go method. The original comment read
	// "does that method have an error as the second argument".
	// NOTE(review): the name suggests the opposite (true when the method
	// returns no error); not read in this file — confirm against the binder.
	NoErr            bool
	Object           *Object         // A link back to the parent object
	Default          interface{}     // The default value
}
    48  
    // FieldArgument is the code-generation model for one argument of a Field.
type FieldArgument struct {
	*Type // Type information for the argument (Signature() is used when rendering it)

	GQLName   string      // The name of the argument in graphql
	GoVarName string      // The name of the var in go
	Object    *Object     // A link back to the parent object; may be nil (see Stream)
	Default   interface{} // The default value
}
    57  
    // Objects is a list of object models, searchable with ByName.
type Objects []*Object
    59  
    60  func (o *Object) Implementors() string {
    61  	satisfiedBy := strconv.Quote(o.GQLType)
    62  	for _, s := range o.Satisfies {
    63  		satisfiedBy += ", " + strconv.Quote(s)
    64  	}
    65  	return "[]string{" + satisfiedBy + "}"
    66  }
    67  
    68  func (o *Object) HasResolvers() bool {
    69  	for _, f := range o.Fields {
    70  		if f.IsResolver() {
    71  			return true
    72  		}
    73  	}
    74  	return false
    75  }
    76  
    77  func (o *Object) IsConcurrent() bool {
    78  	for _, f := range o.Fields {
    79  		if f.IsConcurrent() {
    80  			return true
    81  		}
    82  	}
    83  	return false
    84  }
    85  
    86  func (o *Object) IsReserved() bool {
    87  	return strings.HasPrefix(o.GQLType, "__")
    88  }
    89  
    90  func (f *Field) IsResolver() bool {
    91  	return f.GoFieldName == ""
    92  }
    93  
    94  func (f *Field) IsReserved() bool {
    95  	return strings.HasPrefix(f.GQLName, "__")
    96  }
    97  
    98  func (f *Field) IsMethod() bool {
    99  	return f.GoFieldType == GoFieldMethod
   100  }
   101  
   102  func (f *Field) IsVariable() bool {
   103  	return f.GoFieldType == GoFieldVariable
   104  }
   105  
   106  func (f *Field) IsConcurrent() bool {
   107  	if f.Object.DisableConcurrency {
   108  		return false
   109  	}
   110  	return f.MethodHasContext || f.IsResolver()
   111  }
   112  
   113  func (f *Field) GoNameExported() string {
   114  	return lintName(ucFirst(f.GQLName))
   115  }
   116  
   117  func (f *Field) GoNameUnexported() string {
   118  	return lintName(f.GQLName)
   119  }
   120  
   121  func (f *Field) ShortInvocation() string {
   122  	if !f.IsResolver() {
   123  		return ""
   124  	}
   125  
   126  	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
   127  }
   128  
   129  func (f *Field) ArgsFunc() string {
   130  	if len(f.Args) == 0 {
   131  		return ""
   132  	}
   133  
   134  	return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
   135  }
   136  
   137  func (f *Field) ResolverType() string {
   138  	if !f.IsResolver() {
   139  		return ""
   140  	}
   141  
   142  	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
   143  }
   144  
   145  func (f *Field) ShortResolverDeclaration() string {
   146  	if !f.IsResolver() {
   147  		return ""
   148  	}
   149  	res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
   150  
   151  	if !f.Object.Root {
   152  		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
   153  	}
   154  	for _, arg := range f.Args {
   155  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   156  	}
   157  
   158  	result := f.Signature()
   159  	if f.Object.Stream {
   160  		result = "<-chan " + result
   161  	}
   162  
   163  	res += fmt.Sprintf(") (%s, error)", result)
   164  	return res
   165  }
   166  
   167  func (f *Field) ResolverDeclaration() string {
   168  	if !f.IsResolver() {
   169  		return ""
   170  	}
   171  	res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
   172  
   173  	if !f.Object.Root {
   174  		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
   175  	}
   176  	for _, arg := range f.Args {
   177  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   178  	}
   179  
   180  	result := f.Signature()
   181  	if f.Object.Stream {
   182  		result = "<-chan " + result
   183  	}
   184  
   185  	res += fmt.Sprintf(") (%s, error)", result)
   186  	return res
   187  }
   188  
   189  func (f *Field) ComplexitySignature() string {
   190  	res := fmt.Sprintf("func(childComplexity int")
   191  	for _, arg := range f.Args {
   192  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   193  	}
   194  	res += ") int"
   195  	return res
   196  }
   197  
   198  func (f *Field) ComplexityArgs() string {
   199  	var args []string
   200  	for _, arg := range f.Args {
   201  		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
   202  	}
   203  
   204  	return strings.Join(args, ", ")
   205  }
   206  
   207  func (f *Field) CallArgs() string {
   208  	var args []string
   209  
   210  	if f.IsResolver() {
   211  		args = append(args, "rctx")
   212  
   213  		if !f.Object.Root {
   214  			args = append(args, "obj")
   215  		}
   216  	} else {
   217  		if f.MethodHasContext {
   218  			args = append(args, "ctx")
   219  		}
   220  	}
   221  
   222  	for _, arg := range f.Args {
   223  		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
   224  	}
   225  
   226  	return strings.Join(args, ", ")
   227  }
   228  
   229  // should be in the template, but its recursive and has a bunch of args
   230  func (f *Field) WriteJson() string {
   231  	return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
   232  }
   233  
   234  func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
   235  	switch {
   236  	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
   237  		return tpl(`
   238  			if {{.val}} == nil {
   239  				{{- if .nonNull }}
   240  					if !ec.HasError(rctx) {
   241  						ec.Errorf(ctx, "must not be null")
   242  					}
   243  				{{- end }}
   244  				return graphql.Null
   245  			}
   246  			{{.next }}`, map[string]interface{}{
   247  			"val":     val,
   248  			"nonNull": astType.NonNull,
   249  			"next":    f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
   250  		})
   251  
   252  	case len(remainingMods) > 0 && remainingMods[0] == modList:
   253  		if isPtr {
   254  			val = "*" + val
   255  		}
   256  		var arr = "arr" + strconv.Itoa(depth)
   257  		var index = "idx" + strconv.Itoa(depth)
   258  		var usePtr bool
   259  		if len(remainingMods) == 1 && !isPtr {
   260  			usePtr = true
   261  		}
   262  
   263  		return tpl(`
   264  			{{.arr}} := make(graphql.Array, len({{.val}}))
   265  			{{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
   266  			{{ if not .isScalar }}
   267  				isLen1 := len({{.val}}) == 1
   268  				if !isLen1 {
   269  					wg.Add(len({{.val}}))
   270  				}
   271  			{{ end }}
   272  			for {{.index}} := range {{.val}} {
   273  				{{- if not .isScalar }}
   274  					{{.index}} := {{.index}}
   275  					rctx := &graphql.ResolverContext{
   276  						Index: &{{.index}},
   277  						Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
   278  					}
   279  					ctx := graphql.WithResolverContext(ctx, rctx)
   280  					f := func({{.index}} int) {
   281  						if !isLen1 {
   282  							defer wg.Done()
   283  						}
   284  						{{.arr}}[{{.index}}] = func() graphql.Marshaler {
   285  							{{ .next }}
   286  						}()
   287  					}
   288  					if isLen1 {
   289  						f({{.index}})
   290  					} else {
   291  						go f({{.index}})
   292  					}
   293  				{{ else }}
   294  					{{.arr}}[{{.index}}] = func() graphql.Marshaler {
   295  						{{ .next }}
   296  					}()
   297  				{{- end}}
   298  			}
   299  			{{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
   300  			return {{.arr}}`, map[string]interface{}{
   301  			"val":      val,
   302  			"arr":      arr,
   303  			"index":    index,
   304  			"top":      depth == 1,
   305  			"arrayLen": len(val),
   306  			"isScalar": f.IsScalar,
   307  			"usePtr":   usePtr,
   308  			"next":     f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
   309  		})
   310  
   311  	case f.IsScalar:
   312  		if isPtr {
   313  			val = "*" + val
   314  		}
   315  		return f.Marshal(val)
   316  
   317  	default:
   318  		if !isPtr {
   319  			val = "&" + val
   320  		}
   321  		return tpl(`
   322  			return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
   323  			"type": f.GQLType,
   324  			"val":  val,
   325  		})
   326  	}
   327  }
   328  
   329  func (f *FieldArgument) Stream() bool {
   330  	return f.Object != nil && f.Object.Stream
   331  }
   332  
   333  func (os Objects) ByName(name string) *Object {
   334  	for i, o := range os {
   335  		if strings.EqualFold(o.GQLType, name) {
   336  			return os[i]
   337  		}
   338  	}
   339  	return nil
   340  }
   341  
   // tpl parses src as a text/template named "inline", executes it against
// vars and returns the rendered text. It panics on a parse or execution
// error; the callers in this file pass fixed template strings.
func tpl(src string, vars map[string]interface{}) string {
	parsed := template.Must(template.New("inline").Parse(src))

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
   350  
   // ucFirst returns s with its first rune upper-cased. The empty string is
// returned unchanged.
func ucFirst(s string) string {
	if len(s) == 0 {
		return s
	}

	runes := []rune(s)
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
   360  
   // copy from https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679

// lintName returns a different name if it should be different.
//
// It splits name into words at every lower->non-lower transition and at
// underscores. Runs of underscores after the first rune are removed, except
// that a single underscore is kept between two digits. Each word that is a
// key of commonInitialisms is written in upper case, or in lower case when
// it is the first word and the name starts with a lower-case rune. Every
// later all-lowercase word has its first rune upper-cased. Names that are
// "_" or contain only lower-case letters (this includes "") are returned
// unchanged.
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}

	// Split camelCase at any lower->upper transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			// Close the gap in place and shrink the slice by the n removed runes.
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}

		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		w = i
	}
	return string(runes)
}
   430  
   // commonInitialisms is the set of initialisms lintName writes in a single
// case (for example "Id" becomes "ID"). Every key maps to true.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	words := []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	}
	set := make(map[string]bool, len(words))
	for _, word := range words {
		set[word] = true
	}
	return set
}()