github.com/senomas/gqlgen@v0.17.11-0.20220626120754-9aee61b0716a/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/99designs/gqlgen/codegen/config"
    10  	"github.com/99designs/gqlgen/codegen/templates"
    11  	"github.com/99designs/gqlgen/plugin"
    12  	"github.com/vektah/gqlparser/v2/ast"
    13  )
    14  
    15  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    16  
    17  type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
    18  
    19  // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook.
    20  func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
    21  	return GoTagFieldHook(td, fd, f)
    22  }
    23  
    24  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    25  	return b
    26  }
    27  
    28  type ModelBuild struct {
    29  	PackageName string
    30  	Interfaces  []*Interface
    31  	Models      []*Object
    32  	Enums       []*Enum
    33  	Scalars     []string
    34  }
    35  
    36  type Interface struct {
    37  	Description string
    38  	Name        string
    39  	Implements  []string
    40  }
    41  
    42  type Object struct {
    43  	Description string
    44  	Name        string
    45  	Fields      []*Field
    46  	Implements  []string
    47  }
    48  
    49  type Field struct {
    50  	Description string
    51  	GoName      string
    52  	Type        types.Type
    53  	Tag         string
    54  	Ref         string
    55  }
    56  
    57  type Enum struct {
    58  	Description string
    59  	Name        string
    60  	Values      []*EnumValue
    61  }
    62  
    63  type EnumValue struct {
    64  	Description string
    65  	Name        string
    66  }
    67  
    68  func New() plugin.Plugin {
    69  	return &Plugin{
    70  		MutateHook: defaultBuildMutateHook,
    71  		FieldHook:  defaultFieldMutateHook,
    72  	}
    73  }
    74  
    75  type Plugin struct {
    76  	MutateHook BuildMutateHook
    77  	FieldHook  FieldMutateHook
    78  }
    79  
    80  var _ plugin.ConfigMutator = &Plugin{}
    81  
    82  func (m *Plugin) Name() string {
    83  	return "modelgen"
    84  }
    85  
    86  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    87  	binder := cfg.NewBinder()
    88  
    89  	b := &ModelBuild{
    90  		PackageName: cfg.Model.Package,
    91  	}
    92  
    93  	for _, schemaType := range cfg.Schema.Types {
    94  		if cfg.Models.UserDefined(schemaType.Name) {
    95  			continue
    96  		}
    97  		switch schemaType.Kind {
    98  		case ast.Interface, ast.Union:
    99  			it := &Interface{
   100  				Description: schemaType.Description,
   101  				Name:        schemaType.Name,
   102  				Implements:  schemaType.Interfaces,
   103  			}
   104  
   105  			b.Interfaces = append(b.Interfaces, it)
   106  		case ast.Object, ast.InputObject:
   107  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
   108  				continue
   109  			}
   110  			it := &Object{
   111  				Description: schemaType.Description,
   112  				Name:        schemaType.Name,
   113  			}
   114  
   115  			// If Interface A implements interface B, and Interface C also implements interface B
   116  			// then both A and C have methods of B.
   117  			// The reason for checking unique is to prevent the same method B from being generated twice.
   118  			uniqueMap := map[string]bool{}
   119  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   120  				if !uniqueMap[implementor.Name] {
   121  					it.Implements = append(it.Implements, implementor.Name)
   122  					uniqueMap[implementor.Name] = true
   123  				}
   124  				// for interface implements
   125  				for _, iface := range implementor.Interfaces {
   126  					if !uniqueMap[iface] {
   127  						it.Implements = append(it.Implements, iface)
   128  						uniqueMap[iface] = true
   129  					}
   130  				}
   131  			}
   132  
   133  			for _, field := range schemaType.Fields {
   134  				var typ types.Type
   135  				fieldDef := cfg.Schema.Types[field.Type.Name()]
   136  
   137  				if cfg.Models.UserDefined(field.Type.Name()) {
   138  					var err error
   139  					typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   140  					if err != nil {
   141  						return err
   142  					}
   143  				} else {
   144  					switch fieldDef.Kind {
   145  					case ast.Scalar:
   146  						// no user defined model, referencing a default scalar
   147  						typ = types.NewNamed(
   148  							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   149  							nil,
   150  							nil,
   151  						)
   152  
   153  					case ast.Interface, ast.Union:
   154  						// no user defined model, referencing a generated interface type
   155  						typ = types.NewNamed(
   156  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   157  							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   158  							nil,
   159  						)
   160  
   161  					case ast.Enum:
   162  						// no user defined model, must reference a generated enum
   163  						typ = types.NewNamed(
   164  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   165  							nil,
   166  							nil,
   167  						)
   168  
   169  					case ast.Object, ast.InputObject:
   170  						// no user defined model, must reference a generated struct
   171  						typ = types.NewNamed(
   172  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   173  							types.NewStruct(nil, nil),
   174  							nil,
   175  						)
   176  
   177  					default:
   178  						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   179  					}
   180  				}
   181  
   182  				name := field.Name
   183  				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   184  					name = nameOveride
   185  				}
   186  
   187  				typ = binder.CopyModifiersFromAst(field.Type, typ)
   188  
   189  				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   190  					typ = types.NewPointer(typ)
   191  				}
   192  
   193  				ref := ""
   194  				tag := `json:"` + field.Name + `"`
   195  				if directive := field.Directives.ForName("gorm"); directive != nil {
   196  					arg := directive.Arguments.ForName("tag")
   197  					if arg != nil {
   198  						tag = fmt.Sprintf(`%s gorm:"%s"`, tag, arg.Value.Raw)
   199  					}
   200  					arg = directive.Arguments.ForName("ref")
   201  					if arg != nil {
   202  						arg2 := directive.Arguments.ForName("refTag")
   203  						if arg2 != nil {
   204  							ref = fmt.Sprintf("%s `json:\"-\" gorm:\"%s\"`", arg.Value.Raw, arg2.Value.Raw)
   205  						} else {
   206  							ref = fmt.Sprintf("%s `json:\"-\"`", arg.Value.Raw)
   207  						}
   208  					}
   209  				}
   210  
   211  				f := &Field{
   212  					GoName:      name,
   213  					Type:        typ,
   214  					Description: field.Description,
   215  					Tag:         tag,
   216  					Ref:         ref,
   217  				}
   218  
   219  				if m.FieldHook != nil {
   220  					mf, err := m.FieldHook(schemaType, field, f)
   221  					if err != nil {
   222  						return fmt.Errorf("generror: field %v.%v: %w", it.Name, field.Name, err)
   223  					}
   224  					f = mf
   225  				}
   226  
   227  				it.Fields = append(it.Fields, f)
   228  			}
   229  
   230  			b.Models = append(b.Models, it)
   231  		case ast.Enum:
   232  			it := &Enum{
   233  				Name:        schemaType.Name,
   234  				Description: schemaType.Description,
   235  			}
   236  
   237  			for _, v := range schemaType.EnumValues {
   238  				it.Values = append(it.Values, &EnumValue{
   239  					Name:        v.Name,
   240  					Description: v.Description,
   241  				})
   242  			}
   243  
   244  			b.Enums = append(b.Enums, it)
   245  		case ast.Scalar:
   246  			b.Scalars = append(b.Scalars, schemaType.Name)
   247  		}
   248  	}
   249  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   250  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   251  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   252  
   253  	// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
   254  	// check for cyclical relationships and recursive structs
   255  	if !cfg.StructFieldsAlwaysPointers {
   256  		findAndHandleCyclicalRelationships(b)
   257  	}
   258  
   259  	for _, it := range b.Enums {
   260  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   261  	}
   262  	for _, it := range b.Models {
   263  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   264  	}
   265  	for _, it := range b.Interfaces {
   266  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   267  	}
   268  	for _, it := range b.Scalars {
   269  		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
   270  	}
   271  
   272  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   273  		return nil
   274  	}
   275  
   276  	if m.MutateHook != nil {
   277  		b = m.MutateHook(b)
   278  	}
   279  
   280  	err := templates.Render(templates.Options{
   281  		PackageName:     cfg.Model.Package,
   282  		Filename:        cfg.Model.Filename,
   283  		Data:            b,
   284  		GeneratedHeader: true,
   285  		Packages:        cfg.Packages,
   286  	})
   287  	if err != nil {
   288  		return err
   289  	}
   290  
   291  	// We may have generated code in a package we already loaded, so we reload all packages
   292  	// to allow packages to be compared correctly
   293  	cfg.ReloadAllPackages()
   294  
   295  	return nil
   296  }
   297  
   298  // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
   299  // name is used when no value argument is present.
   300  func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   301  	args := make([]string, 0)
   302  	for _, goTag := range fd.Directives.ForNames("goTag") {
   303  		key := ""
   304  		value := fd.Name
   305  
   306  		if arg := goTag.Arguments.ForName("key"); arg != nil {
   307  			if k, err := arg.Value.Value(nil); err == nil {
   308  				key = k.(string)
   309  			}
   310  		}
   311  
   312  		if arg := goTag.Arguments.ForName("value"); arg != nil {
   313  			if v, err := arg.Value.Value(nil); err == nil {
   314  				value = v.(string)
   315  			}
   316  		}
   317  
   318  		args = append(args, key+":\""+value+"\"")
   319  	}
   320  
   321  	if len(args) > 0 {
   322  		f.Tag = f.Tag + " " + strings.Join(args, " ")
   323  	}
   324  
   325  	return f, nil
   326  }
   327  
   328  func isStruct(t types.Type) bool {
   329  	_, is := t.Underlying().(*types.Struct)
   330  	return is
   331  }
   332  
   333  // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
   334  // with pointers. These relationships will produce compilation errors if they are not pointers.
   335  // Also handles recursive structs.
   336  func findAndHandleCyclicalRelationships(b *ModelBuild) {
   337  	for ii, structA := range b.Models {
   338  		for _, fieldA := range structA.Fields {
   339  			if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
   340  				fmt.Print()
   341  			}
   342  			if !isStruct(fieldA.Type) {
   343  				continue
   344  			}
   345  
   346  			// the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA"
   347  			// we only want the part after the last dot: "LoopA"
   348  			// this could lead to false positives, as we are only checking the name of the struct type, but these
   349  			// should be extremely rare, if it is even possible at all.
   350  			fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
   351  			fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
   352  
   353  			// find this struct type amongst the generated structs
   354  			for jj, structB := range b.Models {
   355  				if structB.Name != fieldAStructName {
   356  					continue
   357  				}
   358  				// check if structB contains a cyclical reference back to structA
   359  				var cyclicalReferenceFound bool
   360  				for _, fieldB := range structB.Fields {
   361  					if !isStruct(fieldB.Type) {
   362  						continue
   363  					}
   364  
   365  					fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
   366  					fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
   367  					if fieldBStructName == structA.Name {
   368  						cyclicalReferenceFound = true
   369  						fieldB.Type = types.NewPointer(fieldB.Type)
   370  						// keep looping in case this struct has additional fields of this type
   371  					}
   372  				}
   373  
   374  				// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
   375  				if cyclicalReferenceFound && ii != jj {
   376  					fieldA.Type = types.NewPointer(fieldA.Type)
   377  					break
   378  				}
   379  			}
   380  		}
   381  	}
   382  }
   383  
   384  func (r *Plugin) InjectSourceEarly() *ast.Source {
   385  	return &ast.Source{
   386  		Name: "plugin/directives.graphql",
   387  		Input: `
   388  			directive @goField(
   389  				forceResolver: Boolean
   390  				name: String
   391  			) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
   392  
   393        directive @gorm(tag: String, ref: String, refTag: String) on FIELD_DEFINITION
   394  
   395  			scalar Time
   396  		`,
   397  		BuiltIn: false,
   398  	}
   399  }