github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"go/types"
     7  	"sort"
     8  	"strings"
     9  	"text/template"
    10  
    11  	"github.com/mstephano/gqlgen-schemagen/codegen/config"
    12  	"github.com/mstephano/gqlgen-schemagen/codegen/templates"
    13  	"github.com/mstephano/gqlgen-schemagen/plugin"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  )
    16  
    17  //go:embed models.gotpl
    18  var modelTemplate string
    19  
    20  type (
    21  	BuildMutateHook = func(b *ModelBuild) *ModelBuild
    22  	FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
    23  )
    24  
    25  // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
    26  func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
    27  	var err error
    28  	f, err = GoFieldHook(td, fd, f)
    29  	if err != nil {
    30  		return f, err
    31  	}
    32  	return GoTagFieldHook(td, fd, f)
    33  }
    34  
    35  // DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild.
    36  func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    37  	return b
    38  }
    39  
    40  type ModelBuild struct {
    41  	PackageName string
    42  	Interfaces  []*Interface
    43  	Models      []*Object
    44  	Enums       []*Enum
    45  	Scalars     []string
    46  }
    47  
    48  type Interface struct {
    49  	Description string
    50  	Name        string
    51  	Fields      []*Field
    52  	Implements  []string
    53  }
    54  
    55  type Object struct {
    56  	Description string
    57  	Name        string
    58  	Fields      []*Field
    59  	Implements  []string
    60  }
    61  
    62  type Field struct {
    63  	Description string
    64  	// Name is the field's name as it appears in the schema
    65  	Name string
    66  	// GoName is the field's name as it appears in the generated Go code
    67  	GoName string
    68  	Type   types.Type
    69  	Tag    string
    70  }
    71  
    72  type Enum struct {
    73  	Description string
    74  	Name        string
    75  	Values      []*EnumValue
    76  }
    77  
    78  type EnumValue struct {
    79  	Description string
    80  	Name        string
    81  }
    82  
    83  func New() plugin.Plugin {
    84  	return &Plugin{
    85  		MutateHook: DefaultBuildMutateHook,
    86  		FieldHook:  DefaultFieldMutateHook,
    87  	}
    88  }
    89  
    90  type Plugin struct {
    91  	MutateHook BuildMutateHook
    92  	FieldHook  FieldMutateHook
    93  }
    94  
    95  var _ plugin.ConfigMutator = &Plugin{}
    96  
    97  func (m *Plugin) Name() string {
    98  	return "modelgen"
    99  }
   100  
   101  func (m *Plugin) MutateConfig(cfg *config.Config) error {
   102  	b := &ModelBuild{
   103  		PackageName: cfg.Model.Package,
   104  	}
   105  
   106  	for _, schemaType := range cfg.Schema.Types {
   107  		if cfg.Models.UserDefined(schemaType.Name) {
   108  			continue
   109  		}
   110  		switch schemaType.Kind {
   111  		case ast.Interface, ast.Union:
   112  			var fields []*Field
   113  			var err error
   114  			if !cfg.OmitGetters {
   115  				fields, err = m.generateFields(cfg, schemaType)
   116  				if err != nil {
   117  					return err
   118  				}
   119  			}
   120  
   121  			it := &Interface{
   122  				Description: schemaType.Description,
   123  				Name:        schemaType.Name,
   124  				Implements:  schemaType.Interfaces,
   125  				Fields:      fields,
   126  			}
   127  
   128  			b.Interfaces = append(b.Interfaces, it)
   129  		case ast.Object, ast.InputObject:
   130  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
   131  				continue
   132  			}
   133  
   134  			fields, err := m.generateFields(cfg, schemaType)
   135  			if err != nil {
   136  				return err
   137  			}
   138  
   139  			it := &Object{
   140  				Description: schemaType.Description,
   141  				Name:        schemaType.Name,
   142  				Fields:      fields,
   143  			}
   144  
   145  			// If Interface A implements interface B, and Interface C also implements interface B
   146  			// then both A and C have methods of B.
   147  			// The reason for checking unique is to prevent the same method B from being generated twice.
   148  			uniqueMap := map[string]bool{}
   149  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   150  				if !uniqueMap[implementor.Name] {
   151  					it.Implements = append(it.Implements, implementor.Name)
   152  					uniqueMap[implementor.Name] = true
   153  				}
   154  				// for interface implements
   155  				for _, iface := range implementor.Interfaces {
   156  					if !uniqueMap[iface] {
   157  						it.Implements = append(it.Implements, iface)
   158  						uniqueMap[iface] = true
   159  					}
   160  				}
   161  			}
   162  
   163  			b.Models = append(b.Models, it)
   164  		case ast.Enum:
   165  			it := &Enum{
   166  				Name:        schemaType.Name,
   167  				Description: schemaType.Description,
   168  			}
   169  
   170  			for _, v := range schemaType.EnumValues {
   171  				it.Values = append(it.Values, &EnumValue{
   172  					Name:        v.Name,
   173  					Description: v.Description,
   174  				})
   175  			}
   176  
   177  			b.Enums = append(b.Enums, it)
   178  		case ast.Scalar:
   179  			b.Scalars = append(b.Scalars, schemaType.Name)
   180  		}
   181  	}
   182  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   183  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   184  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   185  
   186  	// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
   187  	// check for cyclical relationships and recursive structs
   188  	if !cfg.StructFieldsAlwaysPointers {
   189  		findAndHandleCyclicalRelationships(b)
   190  	}
   191  
   192  	for _, it := range b.Enums {
   193  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   194  	}
   195  	for _, it := range b.Models {
   196  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   197  	}
   198  	for _, it := range b.Interfaces {
   199  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   200  	}
   201  	for _, it := range b.Scalars {
   202  		cfg.Models.Add(it, "github.com/mstephano/gqlgen-schemagen/graphql.String")
   203  	}
   204  
   205  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   206  		return nil
   207  	}
   208  
   209  	if m.MutateHook != nil {
   210  		b = m.MutateHook(b)
   211  	}
   212  
   213  	getInterfaceByName := func(name string) *Interface {
   214  		// Allow looking up interfaces, so template can generate getters for each field
   215  		for _, i := range b.Interfaces {
   216  			if i.Name == name {
   217  				return i
   218  			}
   219  		}
   220  
   221  		return nil
   222  	}
   223  	gettersGenerated := make(map[string]map[string]struct{})
   224  	generateGetter := func(model *Object, field *Field) string {
   225  		if model == nil || field == nil {
   226  			return ""
   227  		}
   228  
   229  		// Let templates check if a given getter has been generated already
   230  		typeGetters, exists := gettersGenerated[model.Name]
   231  		if !exists {
   232  			typeGetters = make(map[string]struct{})
   233  			gettersGenerated[model.Name] = typeGetters
   234  		}
   235  
   236  		_, exists = typeGetters[field.GoName]
   237  		typeGetters[field.GoName] = struct{}{}
   238  		if exists {
   239  			return ""
   240  		}
   241  
   242  		_, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer)
   243  		var structFieldTypeIsPointer bool
   244  		for _, f := range model.Fields {
   245  			if f.GoName == field.GoName {
   246  				_, structFieldTypeIsPointer = f.Type.(*types.Pointer)
   247  				break
   248  			}
   249  		}
   250  		goType := templates.CurrentImports.LookupType(field.Type)
   251  		if strings.HasPrefix(goType, "[]") {
   252  			getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType)
   253  			getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
   254  			getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
   255  			getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
   256  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   257  				getter += "&"
   258  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   259  				getter += "*"
   260  			}
   261  			getter += "concrete) }\n"
   262  			getter += "\treturn interfaceSlice\n"
   263  			getter += "}"
   264  			return getter
   265  		} else {
   266  			getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)
   267  
   268  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   269  				getter += "&"
   270  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   271  				getter += "*"
   272  			}
   273  
   274  			getter += fmt.Sprintf("this.%s }", field.GoName)
   275  			return getter
   276  		}
   277  	}
   278  	funcMap := template.FuncMap{
   279  		"getInterfaceByName": getInterfaceByName,
   280  		"generateGetter":     generateGetter,
   281  	}
   282  
   283  	err := templates.Render(templates.Options{
   284  		PackageName:     cfg.Model.Package,
   285  		Filename:        cfg.Model.Filename,
   286  		Data:            b,
   287  		GeneratedHeader: true,
   288  		Packages:        cfg.Packages,
   289  		Template:        modelTemplate,
   290  		Funcs:           funcMap,
   291  	})
   292  	if err != nil {
   293  		return err
   294  	}
   295  
   296  	// We may have generated code in a package we already loaded, so we reload all packages
   297  	// to allow packages to be compared correctly
   298  	cfg.ReloadAllPackages()
   299  
   300  	return nil
   301  }
   302  
   303  func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
   304  	binder := cfg.NewBinder()
   305  	fields := make([]*Field, 0)
   306  
   307  	for _, field := range schemaType.Fields {
   308  		var typ types.Type
   309  		fieldDef := cfg.Schema.Types[field.Type.Name()]
   310  
   311  		if cfg.Models.UserDefined(field.Type.Name()) {
   312  			var err error
   313  			typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  		} else {
   318  			switch fieldDef.Kind {
   319  			case ast.Scalar:
   320  				// no user defined model, referencing a default scalar
   321  				typ = types.NewNamed(
   322  					types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   323  					nil,
   324  					nil,
   325  				)
   326  
   327  			case ast.Interface, ast.Union:
   328  				// no user defined model, referencing a generated interface type
   329  				typ = types.NewNamed(
   330  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   331  					types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   332  					nil,
   333  				)
   334  
   335  			case ast.Enum:
   336  				// no user defined model, must reference a generated enum
   337  				typ = types.NewNamed(
   338  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   339  					nil,
   340  					nil,
   341  				)
   342  
   343  			case ast.Object, ast.InputObject:
   344  				// no user defined model, must reference a generated struct
   345  				typ = types.NewNamed(
   346  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   347  					types.NewStruct(nil, nil),
   348  					nil,
   349  				)
   350  
   351  			default:
   352  				panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   353  			}
   354  		}
   355  
   356  		name := templates.ToGo(field.Name)
   357  		if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   358  			name = nameOveride
   359  		}
   360  
   361  		typ = binder.CopyModifiersFromAst(field.Type, typ)
   362  
   363  		if cfg.StructFieldsAlwaysPointers {
   364  			if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   365  				typ = types.NewPointer(typ)
   366  			}
   367  		}
   368  
   369  		f := &Field{
   370  			Name:        field.Name,
   371  			GoName:      name,
   372  			Type:        typ,
   373  			Description: field.Description,
   374  			Tag:         `json:"` + field.Name + `"`,
   375  		}
   376  
   377  		if m.FieldHook != nil {
   378  			mf, err := m.FieldHook(schemaType, field, f)
   379  			if err != nil {
   380  				return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
   381  			}
   382  			f = mf
   383  		}
   384  
   385  		fields = append(fields, f)
   386  	}
   387  
   388  	return fields, nil
   389  }
   390  
   391  // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
   392  // name is used when no value argument is present.
   393  func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   394  	args := make([]string, 0)
   395  	for _, goTag := range fd.Directives.ForNames("goTag") {
   396  		key := ""
   397  		value := fd.Name
   398  
   399  		if arg := goTag.Arguments.ForName("key"); arg != nil {
   400  			if k, err := arg.Value.Value(nil); err == nil {
   401  				key = k.(string)
   402  			}
   403  		}
   404  
   405  		if arg := goTag.Arguments.ForName("value"); arg != nil {
   406  			if v, err := arg.Value.Value(nil); err == nil {
   407  				value = v.(string)
   408  			}
   409  		}
   410  
   411  		args = append(args, key+":\""+value+"\"")
   412  	}
   413  
   414  	if len(args) > 0 {
   415  		f.Tag = f.Tag + " " + strings.Join(args, " ")
   416  	}
   417  
   418  	return f, nil
   419  }
   420  
   421  // GoFieldHook applies the goField directive to the generated Field f.
   422  func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   423  	args := make([]string, 0)
   424  	_ = args
   425  	for _, goField := range fd.Directives.ForNames("goField") {
   426  		if arg := goField.Arguments.ForName("name"); arg != nil {
   427  			if k, err := arg.Value.Value(nil); err == nil {
   428  				f.GoName = k.(string)
   429  			}
   430  		}
   431  	}
   432  	return f, nil
   433  }
   434  
   435  func isStruct(t types.Type) bool {
   436  	_, is := t.Underlying().(*types.Struct)
   437  	return is
   438  }
   439  
   440  // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
   441  // with pointers. These relationships will produce compilation errors if they are not pointers.
   442  // Also handles recursive structs.
   443  func findAndHandleCyclicalRelationships(b *ModelBuild) {
   444  	for ii, structA := range b.Models {
   445  		for _, fieldA := range structA.Fields {
   446  			if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
   447  				fmt.Print()
   448  			}
   449  			if !isStruct(fieldA.Type) {
   450  				continue
   451  			}
   452  
   453  			// the field Type string will be in the form "github.com/mstephano/gqlgen-schemagen/codegen/testserver/followschema.LoopA"
   454  			// we only want the part after the last dot: "LoopA"
   455  			// this could lead to false positives, as we are only checking the name of the struct type, but these
   456  			// should be extremely rare, if it is even possible at all.
   457  			fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
   458  			fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
   459  
   460  			// find this struct type amongst the generated structs
   461  			for jj, structB := range b.Models {
   462  				if structB.Name != fieldAStructName {
   463  					continue
   464  				}
   465  
   466  				// check if structB contains a cyclical reference back to structA
   467  				var cyclicalReferenceFound bool
   468  				for _, fieldB := range structB.Fields {
   469  					if !isStruct(fieldB.Type) {
   470  						continue
   471  					}
   472  
   473  					fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
   474  					fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
   475  					if fieldBStructName == structA.Name {
   476  						cyclicalReferenceFound = true
   477  						fieldB.Type = types.NewPointer(fieldB.Type)
   478  						// keep looping in case this struct has additional fields of this type
   479  					}
   480  				}
   481  
   482  				// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
   483  				if cyclicalReferenceFound && ii != jj {
   484  					fieldA.Type = types.NewPointer(fieldA.Type)
   485  					break
   486  				}
   487  			}
   488  		}
   489  	}
   490  }