github.com/99designs/gqlgen@v0.17.45/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"go/types"
     7  	"os"
     8  	"sort"
     9  	"strings"
    10  	"text/template"
    11  
    12  	"github.com/vektah/gqlparser/v2/ast"
    13  
    14  	"github.com/99designs/gqlgen/codegen/config"
    15  	"github.com/99designs/gqlgen/codegen/templates"
    16  	"github.com/99designs/gqlgen/plugin"
    17  )
    18  
    19  //go:embed models.gotpl
    20  var modelTemplate string
    21  
    22  type (
    23  	BuildMutateHook = func(b *ModelBuild) *ModelBuild
    24  	FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
    25  )
    26  
    27  // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
    28  func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
    29  	var err error
    30  	f, err = GoFieldHook(td, fd, f)
    31  	if err != nil {
    32  		return f, err
    33  	}
    34  	return GoTagFieldHook(td, fd, f)
    35  }
    36  
    37  // DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild.
    38  func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    39  	return b
    40  }
    41  
    42  type ModelBuild struct {
    43  	PackageName string
    44  	Interfaces  []*Interface
    45  	Models      []*Object
    46  	Enums       []*Enum
    47  	Scalars     []string
    48  }
    49  
    50  type Interface struct {
    51  	Description string
    52  	Name        string
    53  	Fields      []*Field
    54  	Implements  []string
    55  	OmitCheck   bool
    56  	Models      []*Object
    57  }
    58  
    59  type Object struct {
    60  	Description string
    61  	Name        string
    62  	Fields      []*Field
    63  	Implements  []string
    64  }
    65  
    66  type Field struct {
    67  	Description string
    68  	// Name is the field's name as it appears in the schema
    69  	Name string
    70  	// GoName is the field's name as it appears in the generated Go code
    71  	GoName     string
    72  	Type       types.Type
    73  	Tag        string
    74  	IsResolver bool
    75  	Omittable  bool
    76  }
    77  
    78  type Enum struct {
    79  	Description string
    80  	Name        string
    81  	Values      []*EnumValue
    82  }
    83  
    84  type EnumValue struct {
    85  	Description string
    86  	Name        string
    87  }
    88  
    89  func New() plugin.Plugin {
    90  	return &Plugin{
    91  		MutateHook: DefaultBuildMutateHook,
    92  		FieldHook:  DefaultFieldMutateHook,
    93  	}
    94  }
    95  
    96  type Plugin struct {
    97  	MutateHook BuildMutateHook
    98  	FieldHook  FieldMutateHook
    99  }
   100  
   101  var _ plugin.ConfigMutator = &Plugin{}
   102  
   103  func (m *Plugin) Name() string {
   104  	return "modelgen"
   105  }
   106  
   107  func (m *Plugin) MutateConfig(cfg *config.Config) error {
   108  	b := &ModelBuild{
   109  		PackageName: cfg.Model.Package,
   110  	}
   111  
   112  	for _, schemaType := range cfg.Schema.Types {
   113  		if cfg.Models.UserDefined(schemaType.Name) {
   114  			continue
   115  		}
   116  		switch schemaType.Kind {
   117  		case ast.Interface, ast.Union:
   118  			var fields []*Field
   119  			var err error
   120  			if !cfg.OmitGetters {
   121  				fields, err = m.generateFields(cfg, schemaType)
   122  				if err != nil {
   123  					return err
   124  				}
   125  			}
   126  
   127  			it := &Interface{
   128  				Description: schemaType.Description,
   129  				Name:        schemaType.Name,
   130  				Implements:  schemaType.Interfaces,
   131  				Fields:      fields,
   132  				OmitCheck:   cfg.OmitInterfaceChecks,
   133  			}
   134  
   135  			// if the interface has a key directive as an entity interface, allow it to implement _Entity
   136  			if schemaType.Directives.ForName("key") != nil {
   137  				it.Implements = append(it.Implements, "_Entity")
   138  			}
   139  
   140  			b.Interfaces = append(b.Interfaces, it)
   141  		case ast.Object, ast.InputObject:
   142  			if cfg.IsRoot(schemaType) {
   143  				if !cfg.OmitRootModels {
   144  					b.Models = append(b.Models, &Object{
   145  						Description: schemaType.Description,
   146  						Name:        schemaType.Name,
   147  					})
   148  				}
   149  				continue
   150  			}
   151  
   152  			fields, err := m.generateFields(cfg, schemaType)
   153  			if err != nil {
   154  				return err
   155  			}
   156  
   157  			it := &Object{
   158  				Description: schemaType.Description,
   159  				Name:        schemaType.Name,
   160  				Fields:      fields,
   161  			}
   162  
   163  			// If Interface A implements interface B, and Interface C also implements interface B
   164  			// then both A and C have methods of B.
   165  			// The reason for checking unique is to prevent the same method B from being generated twice.
   166  			uniqueMap := map[string]bool{}
   167  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   168  				if !uniqueMap[implementor.Name] {
   169  					it.Implements = append(it.Implements, implementor.Name)
   170  					uniqueMap[implementor.Name] = true
   171  				}
   172  				// for interface implements
   173  				for _, iface := range implementor.Interfaces {
   174  					if !uniqueMap[iface] {
   175  						it.Implements = append(it.Implements, iface)
   176  						uniqueMap[iface] = true
   177  					}
   178  				}
   179  
   180  			}
   181  
   182  			b.Models = append(b.Models, it)
   183  		case ast.Enum:
   184  			it := &Enum{
   185  				Name:        schemaType.Name,
   186  				Description: schemaType.Description,
   187  			}
   188  
   189  			for _, v := range schemaType.EnumValues {
   190  				it.Values = append(it.Values, &EnumValue{
   191  					Name:        v.Name,
   192  					Description: v.Description,
   193  				})
   194  			}
   195  
   196  			b.Enums = append(b.Enums, it)
   197  		case ast.Scalar:
   198  			b.Scalars = append(b.Scalars, schemaType.Name)
   199  		}
   200  	}
   201  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   202  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   203  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   204  
   205  	// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
   206  	// check for cyclical relationships and recursive structs
   207  	if !cfg.StructFieldsAlwaysPointers {
   208  		findAndHandleCyclicalRelationships(b)
   209  	}
   210  
   211  	for _, it := range b.Enums {
   212  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   213  	}
   214  	for _, it := range b.Models {
   215  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   216  	}
   217  	for _, it := range b.Interfaces {
   218  		// On a given interface we want to keep a reference to all the models that implement it
   219  		for _, model := range b.Models {
   220  			for _, impl := range model.Implements {
   221  				if impl == it.Name {
   222  					// check if this isn't an implementation of an entity interface
   223  					if impl != "_Entity" {
   224  						// If this model has an implementation, add it to the Interface's Models
   225  						it.Models = append(it.Models, model)
   226  					}
   227  				}
   228  			}
   229  		}
   230  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   231  	}
   232  	for _, it := range b.Scalars {
   233  		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
   234  	}
   235  
   236  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   237  		return nil
   238  	}
   239  
   240  	if m.MutateHook != nil {
   241  		b = m.MutateHook(b)
   242  	}
   243  
   244  	getInterfaceByName := func(name string) *Interface {
   245  		// Allow looking up interfaces, so template can generate getters for each field
   246  		for _, i := range b.Interfaces {
   247  			if i.Name == name {
   248  				return i
   249  			}
   250  		}
   251  
   252  		return nil
   253  	}
   254  	gettersGenerated := make(map[string]map[string]struct{})
   255  	generateGetter := func(model *Object, field *Field) string {
   256  		if model == nil || field == nil {
   257  			return ""
   258  		}
   259  
   260  		// Let templates check if a given getter has been generated already
   261  		typeGetters, exists := gettersGenerated[model.Name]
   262  		if !exists {
   263  			typeGetters = make(map[string]struct{})
   264  			gettersGenerated[model.Name] = typeGetters
   265  		}
   266  
   267  		_, exists = typeGetters[field.GoName]
   268  		typeGetters[field.GoName] = struct{}{}
   269  		if exists {
   270  			return ""
   271  		}
   272  
   273  		_, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer)
   274  		var structFieldTypeIsPointer bool
   275  		for _, f := range model.Fields {
   276  			if f.GoName == field.GoName {
   277  				_, structFieldTypeIsPointer = f.Type.(*types.Pointer)
   278  				break
   279  			}
   280  		}
   281  		goType := templates.CurrentImports.LookupType(field.Type)
   282  		if strings.HasPrefix(goType, "[]") {
   283  			getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType)
   284  			getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
   285  			getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
   286  			getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
   287  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   288  				getter += "&"
   289  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   290  				getter += "*"
   291  			}
   292  			getter += "concrete) }\n"
   293  			getter += "\treturn interfaceSlice\n"
   294  			getter += "}"
   295  			return getter
   296  		} else {
   297  			getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)
   298  
   299  			if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
   300  				getter += "&"
   301  			} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
   302  				getter += "*"
   303  			}
   304  
   305  			getter += fmt.Sprintf("this.%s }", field.GoName)
   306  			return getter
   307  		}
   308  	}
   309  	funcMap := template.FuncMap{
   310  		"getInterfaceByName": getInterfaceByName,
   311  		"generateGetter":     generateGetter,
   312  	}
   313  	newModelTemplate := modelTemplate
   314  	if cfg.Model.ModelTemplate != "" {
   315  		newModelTemplate = readModelTemplate(cfg.Model.ModelTemplate)
   316  	}
   317  
   318  	err := templates.Render(templates.Options{
   319  		PackageName:     cfg.Model.Package,
   320  		Filename:        cfg.Model.Filename,
   321  		Data:            b,
   322  		GeneratedHeader: true,
   323  		Packages:        cfg.Packages,
   324  		Template:        newModelTemplate,
   325  		Funcs:           funcMap,
   326  	})
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	// We may have generated code in a package we already loaded, so we reload all packages
   332  	// to allow packages to be compared correctly
   333  	cfg.ReloadAllPackages()
   334  
   335  	return nil
   336  }
   337  
   338  func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
   339  	binder := cfg.NewBinder()
   340  	fields := make([]*Field, 0)
   341  
   342  	var omittableType types.Type
   343  
   344  	for _, field := range schemaType.Fields {
   345  		var typ types.Type
   346  		fieldDef := cfg.Schema.Types[field.Type.Name()]
   347  
   348  		if cfg.Models.UserDefined(field.Type.Name()) {
   349  			var err error
   350  			typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   351  			if err != nil {
   352  				return nil, err
   353  			}
   354  		} else {
   355  			switch fieldDef.Kind {
   356  			case ast.Scalar:
   357  				// no user defined model, referencing a default scalar
   358  				typ = types.NewNamed(
   359  					types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   360  					nil,
   361  					nil,
   362  				)
   363  
   364  			case ast.Interface, ast.Union:
   365  				// no user defined model, referencing a generated interface type
   366  				typ = types.NewNamed(
   367  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   368  					types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   369  					nil,
   370  				)
   371  
   372  			case ast.Enum:
   373  				// no user defined model, must reference a generated enum
   374  				typ = types.NewNamed(
   375  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   376  					nil,
   377  					nil,
   378  				)
   379  
   380  			case ast.Object, ast.InputObject:
   381  				// no user defined model, must reference a generated struct
   382  				typ = types.NewNamed(
   383  					types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   384  					types.NewStruct(nil, nil),
   385  					nil,
   386  				)
   387  
   388  			default:
   389  				panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   390  			}
   391  		}
   392  
   393  		name := templates.ToGo(field.Name)
   394  		if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   395  			name = nameOveride
   396  		}
   397  
   398  		typ = binder.CopyModifiersFromAst(field.Type, typ)
   399  
   400  		if cfg.StructFieldsAlwaysPointers {
   401  			if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   402  				typ = types.NewPointer(typ)
   403  			}
   404  		}
   405  
   406  		f := &Field{
   407  			Name:        field.Name,
   408  			GoName:      name,
   409  			Type:        typ,
   410  			Description: field.Description,
   411  			Tag:         getStructTagFromField(cfg, field),
   412  			Omittable:   cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
   413  		}
   414  
   415  		if m.FieldHook != nil {
   416  			mf, err := m.FieldHook(schemaType, field, f)
   417  			if err != nil {
   418  				return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
   419  			}
   420  			f = mf
   421  		}
   422  
   423  		if f.IsResolver && cfg.OmitResolverFields {
   424  			continue
   425  		}
   426  
   427  		if f.Omittable {
   428  			if schemaType.Kind != ast.InputObject || field.Type.NonNull {
   429  				return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
   430  			}
   431  
   432  			var err error
   433  
   434  			if omittableType == nil {
   435  				omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
   436  				if err != nil {
   437  					return nil, err
   438  				}
   439  			}
   440  
   441  			f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
   442  			if err != nil {
   443  				return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
   444  			}
   445  		}
   446  
   447  		fields = append(fields, f)
   448  	}
   449  
   450  	// appending extra fields at the end of the fields list.
   451  	modelcfg := cfg.Models[schemaType.Name]
   452  	if len(modelcfg.ExtraFields) > 0 {
   453  		ff := make([]*Field, 0, len(modelcfg.ExtraFields))
   454  		for fname, fspec := range modelcfg.ExtraFields {
   455  			ftype := buildType(fspec.Type)
   456  
   457  			tag := `json:"-"`
   458  			if fspec.OverrideTags != "" {
   459  				tag = fspec.OverrideTags
   460  			}
   461  
   462  			ff = append(ff,
   463  				&Field{
   464  					Name:        fname,
   465  					GoName:      fname,
   466  					Type:        ftype,
   467  					Description: fspec.Description,
   468  					Tag:         tag,
   469  				})
   470  		}
   471  
   472  		sort.Slice(ff, func(i, j int) bool {
   473  			return ff[i].Name < ff[j].Name
   474  		})
   475  
   476  		fields = append(fields, ff...)
   477  	}
   478  
   479  	return fields, nil
   480  }
   481  
   482  func getStructTagFromField(cfg *config.Config, field *ast.FieldDefinition) string {
   483  	if !field.Type.NonNull && (cfg.EnableModelJsonOmitemptyTag == nil || *cfg.EnableModelJsonOmitemptyTag) {
   484  		return `json:"` + field.Name + `,omitempty"`
   485  	}
   486  	return `json:"` + field.Name + `"`
   487  }
   488  
   489  // GoTagFieldHook prepends the goTag directive to the generated Field f.
   490  // When applying the Tag to the field, the field
   491  // name is used if no value argument is present.
   492  func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   493  	args := make([]string, 0)
   494  	for _, goTag := range fd.Directives.ForNames("goTag") {
   495  		key := ""
   496  		value := fd.Name
   497  
   498  		if arg := goTag.Arguments.ForName("key"); arg != nil {
   499  			if k, err := arg.Value.Value(nil); err == nil {
   500  				key = k.(string)
   501  			}
   502  		}
   503  
   504  		if arg := goTag.Arguments.ForName("value"); arg != nil {
   505  			if v, err := arg.Value.Value(nil); err == nil {
   506  				value = v.(string)
   507  			}
   508  		}
   509  
   510  		args = append(args, key+":\""+value+"\"")
   511  	}
   512  
   513  	if len(args) > 0 {
   514  		f.Tag = removeDuplicateTags(f.Tag + " " + strings.Join(args, " "))
   515  	}
   516  
   517  	return f, nil
   518  }
   519  
   520  // splitTagsBySpace split tags by space, except when space is inside quotes
   521  func splitTagsBySpace(tagsString string) []string {
   522  	var tags []string
   523  	var currentTag string
   524  	inQuotes := false
   525  
   526  	for _, c := range tagsString {
   527  		if c == '"' {
   528  			inQuotes = !inQuotes
   529  		}
   530  		if c == ' ' && !inQuotes {
   531  			tags = append(tags, currentTag)
   532  			currentTag = ""
   533  		} else {
   534  			currentTag += string(c)
   535  		}
   536  	}
   537  	tags = append(tags, currentTag)
   538  
   539  	return tags
   540  }
   541  
   542  // containsInvalidSpace checks if the tagsString contains invalid space
   543  func containsInvalidSpace(valuesString string) bool {
   544  	// get rid of quotes
   545  	valuesString = strings.ReplaceAll(valuesString, "\"", "")
   546  	if strings.Contains(valuesString, ",") {
   547  		// split by comma,
   548  		values := strings.Split(valuesString, ",")
   549  		for _, value := range values {
   550  			if strings.TrimSpace(value) != value {
   551  				return true
   552  			}
   553  		}
   554  		return false
   555  	}
   556  	if strings.Contains(valuesString, ";") {
   557  		// split by semicolon, which is common in gorm
   558  		values := strings.Split(valuesString, ";")
   559  		for _, value := range values {
   560  			if strings.TrimSpace(value) != value {
   561  				return true
   562  			}
   563  		}
   564  		return false
   565  	}
   566  	// single value
   567  	if strings.TrimSpace(valuesString) != valuesString {
   568  		return true
   569  	}
   570  	return false
   571  }
   572  
   573  func removeDuplicateTags(t string) string {
   574  	processed := make(map[string]bool)
   575  	tt := splitTagsBySpace(t)
   576  	returnTags := ""
   577  
   578  	// iterate backwards through tags so appended goTag directives are prioritized
   579  	for i := len(tt) - 1; i >= 0; i-- {
   580  		ti := tt[i]
   581  		// check if ti contains ":", and not contains any empty space. if not, tag is in wrong format
   582  		// correct example: json:"name"
   583  		if !strings.Contains(ti, ":") {
   584  			panic(fmt.Errorf("wrong format of tags: %s. goTag directive should be in format: @goTag(key: \"something\", value:\"value\"), ", t))
   585  		}
   586  
   587  		kv := strings.Split(ti, ":")
   588  		if len(kv) == 0 || processed[kv[0]] {
   589  			continue
   590  		}
   591  
   592  		key := kv[0]
   593  		value := strings.Join(kv[1:], ":")
   594  		processed[key] = true
   595  		if len(returnTags) > 0 {
   596  			returnTags = " " + returnTags
   597  		}
   598  
   599  		isContained := containsInvalidSpace(value)
   600  		if isContained {
   601  			panic(fmt.Errorf("tag value should not contain any leading or trailing spaces: %s", value))
   602  		}
   603  
   604  		returnTags = key + ":" + value + returnTags
   605  	}
   606  
   607  	return returnTags
   608  }
   609  
   610  // GoFieldHook applies the goField directive to the generated Field f.
   611  func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   612  	args := make([]string, 0)
   613  	_ = args
   614  	for _, goField := range fd.Directives.ForNames("goField") {
   615  		if arg := goField.Arguments.ForName("name"); arg != nil {
   616  			if k, err := arg.Value.Value(nil); err == nil {
   617  				f.GoName = k.(string)
   618  			}
   619  		}
   620  
   621  		if arg := goField.Arguments.ForName("forceResolver"); arg != nil {
   622  			if k, err := arg.Value.Value(nil); err == nil {
   623  				f.IsResolver = k.(bool)
   624  			}
   625  		}
   626  
   627  		if arg := goField.Arguments.ForName("omittable"); arg != nil {
   628  			if k, err := arg.Value.Value(nil); err == nil {
   629  				f.Omittable = k.(bool)
   630  			}
   631  		}
   632  	}
   633  	return f, nil
   634  }
   635  
   636  func isStruct(t types.Type) bool {
   637  	_, is := t.Underlying().(*types.Struct)
   638  	return is
   639  }
   640  
   641  // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
   642  // with pointers. These relationships will produce compilation errors if they are not pointers.
   643  // Also handles recursive structs.
   644  func findAndHandleCyclicalRelationships(b *ModelBuild) {
   645  	for ii, structA := range b.Models {
   646  		for _, fieldA := range structA.Fields {
   647  			if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
   648  				fmt.Print()
   649  			}
   650  			if !isStruct(fieldA.Type) {
   651  				continue
   652  			}
   653  
   654  			// the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA"
   655  			// we only want the part after the last dot: "LoopA"
   656  			// this could lead to false positives, as we are only checking the name of the struct type, but these
   657  			// should be extremely rare, if it is even possible at all.
   658  			fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
   659  			fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
   660  
   661  			// find this struct type amongst the generated structs
   662  			for jj, structB := range b.Models {
   663  				if structB.Name != fieldAStructName {
   664  					continue
   665  				}
   666  
   667  				// check if structB contains a cyclical reference back to structA
   668  				var cyclicalReferenceFound bool
   669  				for _, fieldB := range structB.Fields {
   670  					if !isStruct(fieldB.Type) {
   671  						continue
   672  					}
   673  
   674  					fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
   675  					fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
   676  					if fieldBStructName == structA.Name {
   677  						cyclicalReferenceFound = true
   678  						fieldB.Type = types.NewPointer(fieldB.Type)
   679  						// keep looping in case this struct has additional fields of this type
   680  					}
   681  				}
   682  
   683  				// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
   684  				if cyclicalReferenceFound && ii != jj {
   685  					fieldA.Type = types.NewPointer(fieldA.Type)
   686  					break
   687  				}
   688  			}
   689  		}
   690  	}
   691  }
   692  
   693  func readModelTemplate(customModelTemplate string) string {
   694  	contentBytes, err := os.ReadFile(customModelTemplate)
   695  	if err != nil {
   696  		panic(err)
   697  	}
   698  	return string(contentBytes)
   699  }