github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 _ "embed" 5 "fmt" 6 "go/types" 7 "sort" 8 "strings" 9 "text/template" 10 11 "github.com/mstephano/gqlgen-schemagen/codegen/config" 12 "github.com/mstephano/gqlgen-schemagen/codegen/templates" 13 "github.com/mstephano/gqlgen-schemagen/plugin" 14 "github.com/vektah/gqlparser/v2/ast" 15 ) 16 17 //go:embed models.gotpl 18 var modelTemplate string 19 20 type ( 21 BuildMutateHook = func(b *ModelBuild) *ModelBuild 22 FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 23 ) 24 25 // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook. 26 func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 27 var err error 28 f, err = GoFieldHook(td, fd, f) 29 if err != nil { 30 return f, err 31 } 32 return GoTagFieldHook(td, fd, f) 33 } 34 35 // DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild. 36 func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild { 37 return b 38 } 39 40 type ModelBuild struct { 41 PackageName string 42 Interfaces []*Interface 43 Models []*Object 44 Enums []*Enum 45 Scalars []string 46 } 47 48 type Interface struct { 49 Description string 50 Name string 51 Fields []*Field 52 Implements []string 53 } 54 55 type Object struct { 56 Description string 57 Name string 58 Fields []*Field 59 Implements []string 60 } 61 62 type Field struct { 63 Description string 64 // Name is the field's name as it appears in the schema 65 Name string 66 // GoName is the field's name as it appears in the generated Go code 67 GoName string 68 Type types.Type 69 Tag string 70 } 71 72 type Enum struct { 73 Description string 74 Name string 75 Values []*EnumValue 76 } 77 78 type EnumValue struct { 79 Description string 80 Name string 81 } 82 83 func New() plugin.Plugin { 84 return &Plugin{ 85 MutateHook: DefaultBuildMutateHook, 86 FieldHook: DefaultFieldMutateHook, 87 } 88 } 89 90 type Plugin struct { 91 MutateHook BuildMutateHook 92 FieldHook FieldMutateHook 93 } 94 95 var _ plugin.ConfigMutator = &Plugin{} 96 97 func (m *Plugin) Name() string { 98 return "modelgen" 99 } 100 101 func (m *Plugin) MutateConfig(cfg *config.Config) error { 102 b := &ModelBuild{ 103 PackageName: cfg.Model.Package, 104 } 105 106 for _, schemaType := range cfg.Schema.Types { 107 if cfg.Models.UserDefined(schemaType.Name) { 108 continue 109 } 110 switch schemaType.Kind { 111 case ast.Interface, ast.Union: 112 var fields []*Field 113 var err error 114 if !cfg.OmitGetters { 115 fields, err = m.generateFields(cfg, schemaType) 116 if err != nil { 117 return err 118 } 119 } 120 121 it := &Interface{ 122 Description: schemaType.Description, 123 Name: schemaType.Name, 124 Implements: schemaType.Interfaces, 125 Fields: fields, 126 } 127 128 b.Interfaces = append(b.Interfaces, it) 129 case ast.Object, ast.InputObject: 130 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 131 continue 132 } 133 134 fields, err := m.generateFields(cfg, schemaType) 135 if err != nil { 136 return err 137 } 138 139 it := &Object{ 140 Description: schemaType.Description, 141 Name: schemaType.Name, 142 Fields: fields, 143 } 144 145 // If Interface A implements interface B, and Interface C also implements interface B 146 // then both A and C have methods of B. 147 // The reason for checking unique is to prevent the same method B from being generated twice. 148 uniqueMap := map[string]bool{} 149 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 150 if !uniqueMap[implementor.Name] { 151 it.Implements = append(it.Implements, implementor.Name) 152 uniqueMap[implementor.Name] = true 153 } 154 // for interface implements 155 for _, iface := range implementor.Interfaces { 156 if !uniqueMap[iface] { 157 it.Implements = append(it.Implements, iface) 158 uniqueMap[iface] = true 159 } 160 } 161 } 162 163 b.Models = append(b.Models, it) 164 case ast.Enum: 165 it := &Enum{ 166 Name: schemaType.Name, 167 Description: schemaType.Description, 168 } 169 170 for _, v := range schemaType.EnumValues { 171 it.Values = append(it.Values, &EnumValue{ 172 Name: v.Name, 173 Description: v.Description, 174 }) 175 } 176 177 b.Enums = append(b.Enums, it) 178 case ast.Scalar: 179 b.Scalars = append(b.Scalars, schemaType.Name) 180 } 181 } 182 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 183 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 184 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 185 186 // if we are not just turning all struct-type fields in generated structs into pointers, we need to at least 187 // check for cyclical relationships and recursive structs 188 if !cfg.StructFieldsAlwaysPointers { 189 findAndHandleCyclicalRelationships(b) 190 } 191 192 for _, it := range b.Enums { 193 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 194 } 195 for _, it := range b.Models { 196 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 197 } 198 for _, it := range b.Interfaces { 199 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 200 } 201 for _, it := range b.Scalars { 202 cfg.Models.Add(it, "github.com/mstephano/gqlgen-schemagen/graphql.String") 203 } 204 205 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 206 return nil 207 } 208 209 if m.MutateHook != nil { 210 b = m.MutateHook(b) 211 } 212 213 getInterfaceByName := func(name string) *Interface { 214 // Allow looking up interfaces, so template can generate getters for each field 215 for _, i := range b.Interfaces { 216 if i.Name == name { 217 return i 218 } 219 } 220 221 return nil 222 } 223 gettersGenerated := make(map[string]map[string]struct{}) 224 generateGetter := func(model *Object, field *Field) string { 225 if model == nil || field == nil { 226 return "" 227 } 228 229 // Let templates check if a given getter has been generated already 230 typeGetters, exists := gettersGenerated[model.Name] 231 if !exists { 232 typeGetters = make(map[string]struct{}) 233 gettersGenerated[model.Name] = typeGetters 234 } 235 236 _, exists = typeGetters[field.GoName] 237 typeGetters[field.GoName] = struct{}{} 238 if exists { 239 return "" 240 } 241 242 _, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer) 243 var structFieldTypeIsPointer bool 244 for _, f := range model.Fields { 245 if f.GoName == field.GoName { 246 _, structFieldTypeIsPointer = f.Type.(*types.Pointer) 247 break 248 } 249 } 250 goType := templates.CurrentImports.LookupType(field.Type) 251 if strings.HasPrefix(goType, "[]") { 252 getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType) 253 getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName) 254 getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName) 255 getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName) 256 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 257 getter += "&" 258 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 259 getter += "*" 260 } 261 getter += "concrete) }\n" 262 getter += "\treturn interfaceSlice\n" 263 getter += "}" 264 return getter 265 } else { 266 getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType) 267 268 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 269 getter += "&" 270 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 271 getter += "*" 272 } 273 274 getter += fmt.Sprintf("this.%s }", field.GoName) 275 return getter 276 } 277 } 278 funcMap := template.FuncMap{ 279 "getInterfaceByName": getInterfaceByName, 280 "generateGetter": generateGetter, 281 } 282 283 err := templates.Render(templates.Options{ 284 PackageName: cfg.Model.Package, 285 Filename: cfg.Model.Filename, 286 Data: b, 287 GeneratedHeader: true, 288 Packages: cfg.Packages, 289 Template: modelTemplate, 290 Funcs: funcMap, 291 }) 292 if err != nil { 293 return err 294 } 295 296 // We may have generated code in a package we already loaded, so we reload all packages 297 // to allow packages to be compared correctly 298 cfg.ReloadAllPackages() 299 300 return nil 301 } 302 303 func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) { 304 binder := cfg.NewBinder() 305 fields := make([]*Field, 0) 306 307 for _, field := range schemaType.Fields { 308 var typ types.Type 309 fieldDef := cfg.Schema.Types[field.Type.Name()] 310 311 if cfg.Models.UserDefined(field.Type.Name()) { 312 var err error 313 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 314 if err != nil { 315 return nil, err 316 } 317 } else { 318 switch fieldDef.Kind { 319 case ast.Scalar: 320 // no user defined model, referencing a default scalar 321 typ = types.NewNamed( 322 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 323 nil, 324 nil, 325 ) 326 327 case ast.Interface, ast.Union: 328 // no user defined model, referencing a generated interface type 329 typ = types.NewNamed( 330 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 331 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 332 nil, 333 ) 334 335 case ast.Enum: 336 // no user defined model, must reference a generated enum 337 typ = types.NewNamed( 338 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 339 nil, 340 nil, 341 ) 342 343 case ast.Object, ast.InputObject: 344 // no user defined model, must reference a generated struct 345 typ = types.NewNamed( 346 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 347 types.NewStruct(nil, nil), 348 nil, 349 ) 350 351 default: 352 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 353 } 354 } 355 356 name := templates.ToGo(field.Name) 357 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 358 name = nameOveride 359 } 360 361 typ = binder.CopyModifiersFromAst(field.Type, typ) 362 363 if cfg.StructFieldsAlwaysPointers { 364 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 365 typ = types.NewPointer(typ) 366 } 367 } 368 369 f := &Field{ 370 Name: field.Name, 371 GoName: name, 372 Type: typ, 373 Description: field.Description, 374 Tag: `json:"` + field.Name + `"`, 375 } 376 377 if m.FieldHook != nil { 378 mf, err := m.FieldHook(schemaType, field, f) 379 if err != nil { 380 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 381 } 382 f = mf 383 } 384 385 fields = append(fields, f) 386 } 387 388 return fields, nil 389 } 390 391 // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field 392 // name is used when no value argument is present. 393 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 394 args := make([]string, 0) 395 for _, goTag := range fd.Directives.ForNames("goTag") { 396 key := "" 397 value := fd.Name 398 399 if arg := goTag.Arguments.ForName("key"); arg != nil { 400 if k, err := arg.Value.Value(nil); err == nil { 401 key = k.(string) 402 } 403 } 404 405 if arg := goTag.Arguments.ForName("value"); arg != nil { 406 if v, err := arg.Value.Value(nil); err == nil { 407 value = v.(string) 408 } 409 } 410 411 args = append(args, key+":\""+value+"\"") 412 } 413 414 if len(args) > 0 { 415 f.Tag = f.Tag + " " + strings.Join(args, " ") 416 } 417 418 return f, nil 419 } 420 421 // GoFieldHook applies the goField directive to the generated Field f. 422 func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 423 args := make([]string, 0) 424 _ = args 425 for _, goField := range fd.Directives.ForNames("goField") { 426 if arg := goField.Arguments.ForName("name"); arg != nil { 427 if k, err := arg.Value.Value(nil); err == nil { 428 f.GoName = k.(string) 429 } 430 } 431 } 432 return f, nil 433 } 434 435 func isStruct(t types.Type) bool { 436 _, is := t.Underlying().(*types.Struct) 437 return is 438 } 439 440 // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them 441 // with pointers. These relationships will produce compilation errors if they are not pointers. 442 // Also handles recursive structs. 443 func findAndHandleCyclicalRelationships(b *ModelBuild) { 444 for ii, structA := range b.Models { 445 for _, fieldA := range structA.Fields { 446 if strings.Contains(fieldA.Type.String(), "NotCyclicalA") { 447 fmt.Print() 448 } 449 if !isStruct(fieldA.Type) { 450 continue 451 } 452 453 // the field Type string will be in the form "github.com/mstephano/gqlgen-schemagen/codegen/testserver/followschema.LoopA" 454 // we only want the part after the last dot: "LoopA" 455 // this could lead to false positives, as we are only checking the name of the struct type, but these 456 // should be extremely rare, if it is even possible at all. 457 fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".") 458 fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1] 459 460 // find this struct type amongst the generated structs 461 for jj, structB := range b.Models { 462 if structB.Name != fieldAStructName { 463 continue 464 } 465 466 // check if structB contains a cyclical reference back to structA 467 var cyclicalReferenceFound bool 468 for _, fieldB := range structB.Fields { 469 if !isStruct(fieldB.Type) { 470 continue 471 } 472 473 fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".") 474 fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1] 475 if fieldBStructName == structA.Name { 476 cyclicalReferenceFound = true 477 fieldB.Type = types.NewPointer(fieldB.Type) 478 // keep looping in case this struct has additional fields of this type 479 } 480 } 481 482 // if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once 483 if cyclicalReferenceFound && ii != jj { 484 fieldA.Type = types.NewPointer(fieldA.Type) 485 break 486 } 487 } 488 } 489 } 490 }