github.com/matiasanaya/gqlgen@v0.6.0/codegen/object.go (about)

     1  package codegen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"text/template"
     9  	"unicode"
    10  
    11  	"github.com/vektah/gqlparser/ast"
    12  )
    13  
// GoFieldType records how a GraphQL field is bound to Go code.
type GoFieldType int

const (
	// GoFieldUndefined is the zero value: no Go method or variable binding
	// has been recorded for the field.
	GoFieldUndefined GoFieldType = iota
	// GoFieldMethod means the field is bound to a Go method (see Field.IsMethod).
	GoFieldMethod
	// GoFieldVariable means the field is bound to a Go variable/struct field
	// (see Field.IsVariable).
	GoFieldVariable
)
    21  
// Object describes a GraphQL object type for which code is generated: its
// fields, the type names it satisfies, and flags that shape its resolvers.
type Object struct {
	*NamedType

	Fields             []Field
	Satisfies          []string // names rendered after GQLType by Implementors
	Implements         []*NamedType
	ResolverInterface  *Ref
	Root               bool // resolvers of a Root object take no obj argument (presumably Query/Mutation/Subscription — confirm)
	DisableConcurrency bool // when true, Field.IsConcurrent reports false for this object's fields
	Stream             bool // resolver results are declared as `<-chan T` instead of `T`
}
    33  
// Field describes one GraphQL field of an Object and how it is bound to Go
// code: a Go method, a Go variable, or (when GoFieldName is empty) a resolver.
type Field struct {
	*Type
	Description    string          // Description of a field
	GQLName        string          // The name of the field in graphql
	GoFieldType    GoFieldType     // How the field is bound in go (method, variable, or undefined)
	GoReceiverName string          // The name of method & var receiver in go, if any
	GoFieldName    string          // The name of the method or var in go, if any; empty means the field is resolved via a resolver (see IsResolver)
	Args           []FieldArgument // A list of arguments to be passed to this field
	ForceResolver  bool            // Whether a resolver method should be emitted for this field (not read in this file)
	NoErr          bool            // If this is bound to a go method, presumably true when that method does NOT return an error as its second result — confirm against the binder
	Object         *Object         // A link back to the parent object
	Default        interface{}     // The default value
}
    47  
// FieldArgument describes one argument accepted by a GraphQL field.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql
	GoVarName string      // The name of the var in go
	Object    *Object     // A link back to the parent object; may be nil (see Stream)
	Default   interface{} // The default value
}

// Objects is a list of generated objects; use ByName to look one up.
type Objects []*Object
    58  
    59  func (o *Object) Implementors() string {
    60  	satisfiedBy := strconv.Quote(o.GQLType)
    61  	for _, s := range o.Satisfies {
    62  		satisfiedBy += ", " + strconv.Quote(s)
    63  	}
    64  	return "[]string{" + satisfiedBy + "}"
    65  }
    66  
    67  func (o *Object) HasResolvers() bool {
    68  	for _, f := range o.Fields {
    69  		if f.IsResolver() {
    70  			return true
    71  		}
    72  	}
    73  	return false
    74  }
    75  
    76  func (o *Object) IsConcurrent() bool {
    77  	for _, f := range o.Fields {
    78  		if f.IsConcurrent() {
    79  			return true
    80  		}
    81  	}
    82  	return false
    83  }
    84  
    85  func (o *Object) IsReserved() bool {
    86  	return strings.HasPrefix(o.GQLType, "__")
    87  }
    88  
    89  func (f *Field) IsResolver() bool {
    90  	return f.GoFieldName == ""
    91  }
    92  
    93  func (f *Field) IsReserved() bool {
    94  	return strings.HasPrefix(f.GQLName, "__")
    95  }
    96  
    97  func (f *Field) IsMethod() bool {
    98  	return f.GoFieldType == GoFieldMethod
    99  }
   100  
   101  func (f *Field) IsVariable() bool {
   102  	return f.GoFieldType == GoFieldVariable
   103  }
   104  
   105  func (f *Field) IsConcurrent() bool {
   106  	return f.IsResolver() && !f.Object.DisableConcurrency
   107  }
   108  
   109  func (f *Field) GoNameExported() string {
   110  	return lintName(ucFirst(f.GQLName))
   111  }
   112  
   113  func (f *Field) GoNameUnexported() string {
   114  	return lintName(f.GQLName)
   115  }
   116  
   117  func (f *Field) ShortInvocation() string {
   118  	if !f.IsResolver() {
   119  		return ""
   120  	}
   121  
   122  	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
   123  }
   124  
   125  func (f *Field) ArgsFunc() string {
   126  	if len(f.Args) == 0 {
   127  		return ""
   128  	}
   129  
   130  	return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
   131  }
   132  
   133  func (f *Field) ResolverType() string {
   134  	if !f.IsResolver() {
   135  		return ""
   136  	}
   137  
   138  	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
   139  }
   140  
   141  func (f *Field) ShortResolverDeclaration() string {
   142  	if !f.IsResolver() {
   143  		return ""
   144  	}
   145  	res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
   146  
   147  	if !f.Object.Root {
   148  		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
   149  	}
   150  	for _, arg := range f.Args {
   151  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   152  	}
   153  
   154  	result := f.Signature()
   155  	if f.Object.Stream {
   156  		result = "<-chan " + result
   157  	}
   158  
   159  	res += fmt.Sprintf(") (%s, error)", result)
   160  	return res
   161  }
   162  
   163  func (f *Field) ResolverDeclaration() string {
   164  	if !f.IsResolver() {
   165  		return ""
   166  	}
   167  	res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
   168  
   169  	if !f.Object.Root {
   170  		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
   171  	}
   172  	for _, arg := range f.Args {
   173  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   174  	}
   175  
   176  	result := f.Signature()
   177  	if f.Object.Stream {
   178  		result = "<-chan " + result
   179  	}
   180  
   181  	res += fmt.Sprintf(") (%s, error)", result)
   182  	return res
   183  }
   184  
   185  func (f *Field) ComplexitySignature() string {
   186  	res := fmt.Sprintf("func(childComplexity int")
   187  	for _, arg := range f.Args {
   188  		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
   189  	}
   190  	res += ") int"
   191  	return res
   192  }
   193  
   194  func (f *Field) ComplexityArgs() string {
   195  	var args []string
   196  	for _, arg := range f.Args {
   197  		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
   198  	}
   199  
   200  	return strings.Join(args, ", ")
   201  }
   202  
   203  func (f *Field) CallArgs() string {
   204  	var args []string
   205  
   206  	if f.IsResolver() {
   207  		args = append(args, "rctx")
   208  
   209  		if !f.Object.Root {
   210  			args = append(args, "obj")
   211  		}
   212  	}
   213  
   214  	for _, arg := range f.Args {
   215  		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
   216  	}
   217  
   218  	return strings.Join(args, ", ")
   219  }
   220  
   221  // should be in the template, but its recursive and has a bunch of args
   222  func (f *Field) WriteJson() string {
   223  	return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
   224  }
   225  
   226  func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
   227  	switch {
   228  	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
   229  		return tpl(`
   230  			if {{.val}} == nil {
   231  				{{- if .nonNull }}
   232  					if !ec.HasError(rctx) {
   233  						ec.Errorf(ctx, "must not be null")
   234  					}
   235  				{{- end }}
   236  				return graphql.Null
   237  			}
   238  			{{.next }}`, map[string]interface{}{
   239  			"val":     val,
   240  			"nonNull": astType.NonNull,
   241  			"next":    f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
   242  		})
   243  
   244  	case len(remainingMods) > 0 && remainingMods[0] == modList:
   245  		if isPtr {
   246  			val = "*" + val
   247  		}
   248  		var arr = "arr" + strconv.Itoa(depth)
   249  		var index = "idx" + strconv.Itoa(depth)
   250  		var usePtr bool
   251  		if len(remainingMods) == 1 && !isPtr {
   252  			usePtr = true
   253  		}
   254  
   255  		return tpl(`
   256  			{{.arr}} := make(graphql.Array, len({{.val}}))
   257  			{{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
   258  			{{ if not .isScalar }}
   259  				isLen1 := len({{.val}}) == 1
   260  				if !isLen1 {
   261  					wg.Add(len({{.val}}))
   262  				}
   263  			{{ end }}
   264  			for {{.index}} := range {{.val}} {
   265  				{{- if not .isScalar }}
   266  					{{.index}} := {{.index}}
   267  					rctx := &graphql.ResolverContext{
   268  						Index: &{{.index}},
   269  						Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
   270  					}
   271  					ctx := graphql.WithResolverContext(ctx, rctx)
   272  					f := func({{.index}} int) {
   273  						if !isLen1 {
   274  							defer wg.Done()
   275  						}
   276  						{{.arr}}[{{.index}}] = func() graphql.Marshaler {
   277  							{{ .next }}
   278  						}()
   279  					}
   280  					if isLen1 {
   281  						f({{.index}})
   282  					} else {
   283  						go f({{.index}})
   284  					}
   285  				{{ else }}
   286  					{{.arr}}[{{.index}}] = func() graphql.Marshaler {
   287  						{{ .next }}
   288  					}()
   289  				{{- end}}
   290  			}
   291  			{{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
   292  			return {{.arr}}`, map[string]interface{}{
   293  			"val":      val,
   294  			"arr":      arr,
   295  			"index":    index,
   296  			"top":      depth == 1,
   297  			"arrayLen": len(val),
   298  			"isScalar": f.IsScalar,
   299  			"usePtr":   usePtr,
   300  			"next":     f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
   301  		})
   302  
   303  	case f.IsScalar:
   304  		if isPtr {
   305  			val = "*" + val
   306  		}
   307  		return f.Marshal(val)
   308  
   309  	default:
   310  		if !isPtr {
   311  			val = "&" + val
   312  		}
   313  		return tpl(`
   314  			return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
   315  			"type": f.GQLType,
   316  			"val":  val,
   317  		})
   318  	}
   319  }
   320  
   321  func (f *FieldArgument) Stream() bool {
   322  	return f.Object != nil && f.Object.Stream
   323  }
   324  
   325  func (os Objects) ByName(name string) *Object {
   326  	for i, o := range os {
   327  		if strings.EqualFold(o.GQLType, name) {
   328  			return os[i]
   329  		}
   330  	}
   331  	return nil
   332  }
   333  
   334  func tpl(tpl string, vars map[string]interface{}) string {
   335  	b := &bytes.Buffer{}
   336  	err := template.Must(template.New("inline").Parse(tpl)).Execute(b, vars)
   337  	if err != nil {
   338  		panic(err)
   339  	}
   340  	return b.String()
   341  }
   342  
   343  func ucFirst(s string) string {
   344  	if s == "" {
   345  		return ""
   346  	}
   347  
   348  	r := []rune(s)
   349  	r[0] = unicode.ToUpper(r[0])
   350  	return string(r)
   351  }
   352  
// copy from https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679

// lintName returns a different name if it should be different.
//
// It rewrites name to golint's naming conventions:
//   - "_" and names consisting only of lowercase letters are returned as-is;
//   - words are split at lower->non-lower transitions and at underscores;
//   - runs of underscores are removed, except that one underscore is kept
//     between two digits;
//   - a word found in commonInitialisms is written in upper case, or in lower
//     case when it is the first word and the name started lowercase;
//   - any other all-lowercase word after the first gets its first letter
//     upper-cased.
//
// For example, "userId" becomes "userID" and "foo_bar" becomes "fooBar".
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}

	// Split camelCase at any lower->upper transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	// len(runes) is re-evaluated every iteration because the underscore branch
	// shortens runes in place.
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}

		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		w = i
	}
	return string(runes)
}
   422  
   423  // commonInitialisms is a set of common initialisms.
   424  // Only add entries that are highly unlikely to be non-initialisms.
   425  // For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
   426  var commonInitialisms = map[string]bool{
   427  	"ACL":   true,
   428  	"API":   true,
   429  	"ASCII": true,
   430  	"CPU":   true,
   431  	"CSS":   true,
   432  	"DNS":   true,
   433  	"EOF":   true,
   434  	"GUID":  true,
   435  	"HTML":  true,
   436  	"HTTP":  true,
   437  	"HTTPS": true,
   438  	"ID":    true,
   439  	"IP":    true,
   440  	"JSON":  true,
   441  	"LHS":   true,
   442  	"QPS":   true,
   443  	"RAM":   true,
   444  	"RHS":   true,
   445  	"RPC":   true,
   446  	"SLA":   true,
   447  	"SMTP":  true,
   448  	"SQL":   true,
   449  	"SSH":   true,
   450  	"TCP":   true,
   451  	"TLS":   true,
   452  	"TTL":   true,
   453  	"UDP":   true,
   454  	"UI":    true,
   455  	"UID":   true,
   456  	"UUID":  true,
   457  	"URI":   true,
   458  	"URL":   true,
   459  	"UTF8":  true,
   460  	"VM":    true,
   461  	"XML":   true,
   462  	"XMPP":  true,
   463  	"XSRF":  true,
   464  	"XSS":   true,
   465  }