github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"go/types"
     7  	"sort"
     8  	"strings"
     9  	"text/template"
    10  
    11  	"github.com/spread-ai/gqlgen/codegen/config"
    12  	"github.com/spread-ai/gqlgen/codegen/templates"
    13  	"github.com/spread-ai/gqlgen/plugin"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  )
    16  
    17  //go:embed models.gotpl
    18  var modelTemplate string
    19  
    20  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    21  
    22  type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
    23  
    24  // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook.
    25  func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
    26  	return GoTagFieldHook(td, fd, f)
    27  }
    28  
    29  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    30  	return b
    31  }
    32  
    33  type ModelBuild struct {
    34  	PackageName string
    35  	Interfaces  []*Interface
    36  	Models      []*Object
    37  	Enums       []*Enum
    38  	Scalars     []string
    39  }
    40  
    41  type Interface struct {
    42  	Description string
    43  	Name        string
    44  	Fields      []*Field
    45  	Implements  []string
    46  }
    47  
    48  type Object struct {
    49  	Description string
    50  	Name        string
    51  	Fields      []*Field
    52  	Implements  []string
    53  }
    54  
    55  type Field struct {
    56  	Description string
    57  	// Name is the field's name as it appears in the schema
    58  	Name string
    59  	// GoName is the field's name as it appears in the generated Go code
    60  	GoName string
    61  	Type   types.Type
    62  	Tag    string
    63  }
    64  
    65  type Enum struct {
    66  	Description string
    67  	Name        string
    68  	Values      []*EnumValue
    69  }
    70  
    71  type EnumValue struct {
    72  	Description string
    73  	Name        string
    74  }
    75  
    76  func New() plugin.Plugin {
    77  	return &Plugin{
    78  		MutateHook: defaultBuildMutateHook,
    79  		FieldHook:  defaultFieldMutateHook,
    80  	}
    81  }
    82  
    83  type Plugin struct {
    84  	MutateHook BuildMutateHook
    85  	FieldHook  FieldMutateHook
    86  }
    87  
    88  var _ plugin.ConfigMutator = &Plugin{}
    89  
    90  func (m *Plugin) Name() string {
    91  	return "modelgen"
    92  }
    93  
    94  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    95  
    96  	b := &ModelBuild{
    97  		PackageName: cfg.Model.Package,
    98  	}
    99  
   100  	for _, schemaType := range cfg.Schema.Types {
   101  		if cfg.Models.UserDefined(schemaType.Name) {
   102  			continue
   103  		}
   104  		switch schemaType.Kind {
   105  		case ast.Interface, ast.Union:
   106  			var fields []*Field
   107  			var err error
   108  			if !cfg.OmitGetters {
   109  				fields, err = m.generateFields(cfg, schemaType)
   110  				if err != nil {
   111  					return err
   112  				}
   113  			}
   114  
   115  			it := &Interface{
   116  				Description: schemaType.Description,
   117  				Name:        schemaType.Name,
   118  				Implements:  schemaType.Interfaces,
   119  				Fields:      fields,
   120  			}
   121  
   122  			b.Interfaces = append(b.Interfaces, it)
   123  		case ast.Object, ast.InputObject:
   124  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
   125  				continue
   126  			}
   127  
   128  			fields, err := m.generateFields(cfg, schemaType)
   129  			if err != nil {
   130  				return err
   131  			}
   132  
   133  			it := &Object{
   134  				Description: schemaType.Description,
   135  				Name:        schemaType.Name,
   136  				Fields:      fields,
   137  			}
   138  
   139  			// If Interface A implements interface B, and Interface C also implements interface B
   140  			// then both A and C have methods of B.
   141  			// The reason for checking unique is to prevent the same method B from being generated twice.
   142  			uniqueMap := map[string]bool{}
   143  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   144  				if !uniqueMap[implementor.Name] {
   145  					it.Implements = append(it.Implements, implementor.Name)
   146  					uniqueMap[implementor.Name] = true
   147  				}
   148  				// for interface implements
   149  				for _, iface := range implementor.Interfaces {
   150  					if !uniqueMap[iface] {
   151  						it.Implements = append(it.Implements, iface)
   152  						uniqueMap[iface] = true
   153  					}
   154  				}
   155  			}
   156  
   157  			b.Models = append(b.Models, it)
   158  		case ast.Enum:
   159  			it := &Enum{
   160  				Name:        schemaType.Name,
   161  				Description: schemaType.Description,
   162  			}
   163  
   164  			for _, v := range schemaType.EnumValues {
   165  				it.Values = append(it.Values, &EnumValue{
   166  					Name:        v.Name,
   167  					Description: v.Description,
   168  				})
   169  			}
   170  
   171  			b.Enums = append(b.Enums, it)
   172  		case ast.Scalar:
   173  			b.Scalars = append(b.Scalars, schemaType.Name)
   174  		}
   175  	}
   176  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   177  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   178  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   179  
   180  	// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
   181  	// check for cyclical relationships and recursive structs
   182  	if !cfg.StructFieldsAlwaysPointers {
   183  		findAndHandleCyclicalRelationships(b)
   184  	}
   185  
   186  	for _, it := range b.Enums {
   187  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   188  	}
   189  	for _, it := range b.Models {
   190  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   191  	}
   192  	for _, it := range b.Interfaces {
   193  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   194  	}
   195  	for _, it := range b.Scalars {
   196  		cfg.Models.Add(it, "github.com/spread-ai/gqlgen/graphql.String")
   197  	}
   198  
   199  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   200  		return nil
   201  	}
   202  
   203  	if m.MutateHook != nil {
   204  		b = m.MutateHook(b)
   205  	}
   206  
   207  	getInterfaceByName := func(name string) *Interface {
   208  		// Allow looking up interfaces, so template can generate getters for each field
   209  		for _, i := range b.Interfaces {
   210  			if i.Name == name {
   211  				return i
   212  			}
   213  		}
   214  
   215  		return nil
   216  	}
   217  	gettersGenerated := make(map[string]map[string]struct{})
   218  	generateGetter := func(model *Object, field *Field) string {
   219  		if model == nil || field == nil {
   220  			return ""
   221  		}
   222  
   223  		// Let templates check if a given getter has been generated already
   224  		typeGetters, exists := gettersGenerated[model.Name]
   225  		if !exists {
   226  			typeGetters = make(map[string]struct{})
   227  			gettersGenerated[model.Name] = typeGetters
   228  		}
   229  
   230  		_, exists = typeGetters[field.GoName]
   231  		typeGetters[field.GoName] = struct{}{}
   232  		if exists {
   233  			return ""
   234  		}
   235  
   236  		_, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer)
   237  		var structFieldTypeIsPointer bool
   238  		for _, f := range model.Fields {
   239  			if f.GoName == field.GoName {
   240  				_, structFieldTypeIsPointer = f.Type.(*types.Pointer)
   241  				break
   242  			}
   243  		}
   244  		goType := templates.CurrentImports.LookupType(field.Type)
   245  		if strings.HasPrefix(goType, "[]") {
   246  			getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType)
   247  			getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
   248  			getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
   249  			getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
   250  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   251  				getter += "&"
   252  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   253  				getter += "*"
   254  			}
   255  			getter += "concrete) }\n"
   256  			getter += "\treturn interfaceSlice\n"
   257  			getter += "}"
   258  			return getter
   259  		} else {
   260  			getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)
   261  
   262  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   263  				getter += "&"
   264  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   265  				getter += "*"
   266  			}
   267  
   268  			getter += fmt.Sprintf("this.%s }", field.GoName)
   269  			return getter
   270  		}
   271  	}
   272  	funcMap := template.FuncMap{
   273  		"getInterfaceByName": getInterfaceByName,
   274  		"generateGetter":     generateGetter,
   275  	}
   276  
   277  	err := templates.Render(templates.Options{
   278  		PackageName:     cfg.Model.Package,
   279  		Filename:        cfg.Model.Filename,
   280  		Data:            b,
   281  		GeneratedHeader: true,
   282  		Packages:        cfg.Packages,
   283  		Template:        modelTemplate,
   284  		Funcs:           funcMap,
   285  	})
   286  	if err != nil {
   287  		return err
   288  	}
   289  
   290  	// We may have generated code in a package we already loaded, so we reload all packages
   291  	// to allow packages to be compared correctly
   292  	cfg.ReloadAllPackages()
   293  
   294  	return nil
   295  }
   296  
   297  func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
   298  	binder := cfg.NewBinder()
   299  	fields := make([]*Field, 0)
   300  
   301  	for _, field := range schemaType.Fields {
   302  		var typ types.Type
   303  		fieldDef := cfg.Schema.Types[field.Type.Name()]
   304  
   305  		if cfg.Models.UserDefined(field.Type.Name()) {
   306  			var err error
   307  			typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   308  			if err != nil {
   309  				return nil, err
   310  			}
   311  		} else {
   312  			switch fieldDef.Kind {
   313  			case ast.Scalar:
   314  				// no user defined model, referencing a default scalar
   315  				typ = types.NewNamed(
   316  					types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   317  					nil,
   318  					nil,
   319  				)
   320  
   321  			case ast.Interface, ast.Union:
   322  				// no user defined model, referencing a generated interface type
   323  				typ = types.NewNamed(
   324  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   325  					types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   326  					nil,
   327  				)
   328  
   329  			case ast.Enum:
   330  				// no user defined model, must reference a generated enum
   331  				typ = types.NewNamed(
   332  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   333  					nil,
   334  					nil,
   335  				)
   336  
   337  			case ast.Object, ast.InputObject:
   338  				// no user defined model, must reference a generated struct
   339  				typ = types.NewNamed(
   340  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   341  					types.NewStruct(nil, nil),
   342  					nil,
   343  				)
   344  
   345  			default:
   346  				panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   347  			}
   348  		}
   349  
   350  		name := templates.ToGo(field.Name)
   351  		if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   352  			name = nameOveride
   353  		}
   354  
   355  		typ = binder.CopyModifiersFromAst(field.Type, typ)
   356  
   357  		if cfg.StructFieldsAlwaysPointers {
   358  			if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   359  				typ = types.NewPointer(typ)
   360  			}
   361  		}
   362  
   363  		f := &Field{
   364  			Name:        field.Name,
   365  			GoName:      name,
   366  			Type:        typ,
   367  			Description: field.Description,
   368  			Tag:         `json:"` + field.Name + `"`,
   369  		}
   370  
   371  		if m.FieldHook != nil {
   372  			mf, err := m.FieldHook(schemaType, field, f)
   373  			if err != nil {
   374  				return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
   375  			}
   376  			f = mf
   377  		}
   378  
   379  		fields = append(fields, f)
   380  	}
   381  
   382  	return fields, nil
   383  }
   384  
   385  // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
   386  // name is used when no value argument is present.
   387  func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   388  	args := make([]string, 0)
   389  	for _, goTag := range fd.Directives.ForNames("goTag") {
   390  		key := ""
   391  		value := fd.Name
   392  
   393  		if arg := goTag.Arguments.ForName("key"); arg != nil {
   394  			if k, err := arg.Value.Value(nil); err == nil {
   395  				key = k.(string)
   396  			}
   397  		}
   398  
   399  		if arg := goTag.Arguments.ForName("value"); arg != nil {
   400  			if v, err := arg.Value.Value(nil); err == nil {
   401  				value = v.(string)
   402  			}
   403  		}
   404  
   405  		if key == "json" {
   406  			if value == "omitempty" {
   407  				f.Tag = strings.ReplaceAll(f.Tag, `json:"`+f.Name+`"`, `json:"`+f.Name+`,omitempty"`)
   408  			} else {
   409  				f.Tag = strings.ReplaceAll(f.Tag, `json:"`+f.Name+`"`, "")
   410  				args = append(args, key+":\""+value+"\"")
   411  			}
   412  		} else {
   413  			args = append(args, key+":\""+value+"\"")
   414  		}
   415  	}
   416  
   417  	if len(args) > 0 {
   418  		f.Tag = f.Tag + " " + strings.Join(args, " ")
   419  	}
   420  
   421  	return f, nil
   422  }
   423  
   424  func isStruct(t types.Type) bool {
   425  	_, is := t.Underlying().(*types.Struct)
   426  	return is
   427  }
   428  
   429  // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
   430  // with pointers. These relationships will produce compilation errors if they are not pointers.
   431  // Also handles recursive structs.
   432  func findAndHandleCyclicalRelationships(b *ModelBuild) {
   433  	for ii, structA := range b.Models {
   434  		for _, fieldA := range structA.Fields {
   435  			if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
   436  				fmt.Print()
   437  			}
   438  			if !isStruct(fieldA.Type) {
   439  				continue
   440  			}
   441  
   442  			// the field Type string will be in the form "github.com/spread-ai/gqlgen/codegen/testserver/followschema.LoopA"
   443  			// we only want the part after the last dot: "LoopA"
   444  			// this could lead to false positives, as we are only checking the name of the struct type, but these
   445  			// should be extremely rare, if it is even possible at all.
   446  			fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
   447  			fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
   448  
   449  			// find this struct type amongst the generated structs
   450  			for jj, structB := range b.Models {
   451  				if structB.Name != fieldAStructName {
   452  					continue
   453  				}
   454  
   455  				// check if structB contains a cyclical reference back to structA
   456  				var cyclicalReferenceFound bool
   457  				for _, fieldB := range structB.Fields {
   458  					if !isStruct(fieldB.Type) {
   459  						continue
   460  					}
   461  
   462  					fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
   463  					fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
   464  					if fieldBStructName == structA.Name {
   465  						cyclicalReferenceFound = true
   466  						fieldB.Type = types.NewPointer(fieldB.Type)
   467  						// keep looping in case this struct has additional fields of this type
   468  					}
   469  				}
   470  
   471  				// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
   472  				if cyclicalReferenceFound && ii != jj {
   473  					fieldA.Type = types.NewPointer(fieldA.Type)
   474  					break
   475  				}
   476  			}
   477  		}
   478  	}
   479  }