github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 _ "embed" 5 "fmt" 6 "go/types" 7 "os" 8 "sort" 9 "strings" 10 "text/template" 11 12 "github.com/geneva/gqlgen/codegen/config" 13 "github.com/geneva/gqlgen/codegen/templates" 14 "github.com/geneva/gqlgen/plugin" 15 "github.com/vektah/gqlparser/v2/ast" 16 ) 17 18 //go:embed models.gotpl 19 var modelTemplate string 20 21 type ( 22 BuildMutateHook = func(b *ModelBuild) *ModelBuild 23 FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 24 ) 25 26 // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook. 27 func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 28 var err error 29 f, err = GoFieldHook(td, fd, f) 30 if err != nil { 31 return f, err 32 } 33 return GoTagFieldHook(td, fd, f) 34 } 35 36 // DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild. 37 func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild { 38 return b 39 } 40 41 type ModelBuild struct { 42 PackageName string 43 Interfaces []*Interface 44 Models []*Object 45 Enums []*Enum 46 Scalars []string 47 } 48 49 type Interface struct { 50 Description string 51 Name string 52 Fields []*Field 53 Implements []string 54 OmitCheck bool 55 } 56 57 type Object struct { 58 Description string 59 Name string 60 Fields []*Field 61 Implements []string 62 } 63 64 type Field struct { 65 Description string 66 // Name is the field's name as it appears in the schema 67 Name string 68 // GoName is the field's name as it appears in the generated Go code 69 GoName string 70 Type types.Type 71 Tag string 72 Omittable bool 73 } 74 75 type Enum struct { 76 Description string 77 Name string 78 Values []*EnumValue 79 } 80 81 type EnumValue struct { 82 Description string 83 Name string 84 } 85 86 func New() plugin.Plugin { 87 return &Plugin{ 88 MutateHook: DefaultBuildMutateHook, 89 FieldHook: DefaultFieldMutateHook, 90 } 91 } 92 93 type Plugin struct { 94 MutateHook BuildMutateHook 95 FieldHook FieldMutateHook 96 } 97 98 var _ plugin.ConfigMutator = &Plugin{} 99 100 func (m *Plugin) Name() string { 101 return "modelgen" 102 } 103 104 func (m *Plugin) MutateConfig(cfg *config.Config) error { 105 b := &ModelBuild{ 106 PackageName: cfg.Model.Package, 107 } 108 109 for _, schemaType := range cfg.Schema.Types { 110 if cfg.Models.UserDefined(schemaType.Name) { 111 continue 112 } 113 switch schemaType.Kind { 114 case ast.Interface, ast.Union: 115 var fields []*Field 116 var err error 117 if !cfg.OmitGetters { 118 fields, err = m.generateFields(cfg, schemaType) 119 if err != nil { 120 return err 121 } 122 } 123 124 it := &Interface{ 125 Description: schemaType.Description, 126 Name: schemaType.Name, 127 Implements: schemaType.Interfaces, 128 Fields: fields, 129 OmitCheck: cfg.OmitInterfaceChecks, 130 } 131 132 b.Interfaces = append(b.Interfaces, it) 133 case ast.Object, ast.InputObject: 134 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 135 continue 136 } 137 138 fields, err := m.generateFields(cfg, schemaType) 139 if err != nil { 140 return err 141 } 142 143 it := &Object{ 144 Description: schemaType.Description, 145 Name: schemaType.Name, 146 Fields: fields, 147 } 148 149 // If Interface A implements interface B, and Interface C also implements interface B 150 // then both A and C have methods of B. 151 // The reason for checking unique is to prevent the same method B from being generated twice. 152 uniqueMap := map[string]bool{} 153 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 154 if !uniqueMap[implementor.Name] { 155 it.Implements = append(it.Implements, implementor.Name) 156 uniqueMap[implementor.Name] = true 157 } 158 // for interface implements 159 for _, iface := range implementor.Interfaces { 160 if !uniqueMap[iface] { 161 it.Implements = append(it.Implements, iface) 162 uniqueMap[iface] = true 163 } 164 } 165 } 166 167 b.Models = append(b.Models, it) 168 case ast.Enum: 169 it := &Enum{ 170 Name: schemaType.Name, 171 Description: schemaType.Description, 172 } 173 174 for _, v := range schemaType.EnumValues { 175 it.Values = append(it.Values, &EnumValue{ 176 Name: v.Name, 177 Description: v.Description, 178 }) 179 } 180 181 b.Enums = append(b.Enums, it) 182 case ast.Scalar: 183 b.Scalars = append(b.Scalars, schemaType.Name) 184 } 185 } 186 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 187 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 188 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 189 190 // if we are not just turning all struct-type fields in generated structs into pointers, we need to at least 191 // check for cyclical relationships and recursive structs 192 if !cfg.StructFieldsAlwaysPointers { 193 findAndHandleCyclicalRelationships(b) 194 } 195 196 for _, it := range b.Enums { 197 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 198 } 199 for _, it := range b.Models { 200 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 201 } 202 for _, it := range b.Interfaces { 203 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 204 } 205 for _, it := range b.Scalars { 206 cfg.Models.Add(it, "github.com/geneva/gqlgen/graphql.String") 207 } 208 209 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 210 return nil 211 } 212 213 if m.MutateHook != nil { 214 b = m.MutateHook(b) 215 } 216 217 getInterfaceByName := func(name string) *Interface { 218 // Allow looking up interfaces, so template can generate getters for each field 219 for _, i := range b.Interfaces { 220 if i.Name == name { 221 return i 222 } 223 } 224 225 return nil 226 } 227 gettersGenerated := make(map[string]map[string]struct{}) 228 generateGetter := func(model *Object, field *Field) string { 229 if model == nil || field == nil { 230 return "" 231 } 232 233 // Let templates check if a given getter has been generated already 234 typeGetters, exists := gettersGenerated[model.Name] 235 if !exists { 236 typeGetters = make(map[string]struct{}) 237 gettersGenerated[model.Name] = typeGetters 238 } 239 240 _, exists = typeGetters[field.GoName] 241 typeGetters[field.GoName] = struct{}{} 242 if exists { 243 return "" 244 } 245 246 _, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer) 247 var structFieldTypeIsPointer bool 248 for _, f := range model.Fields { 249 if f.GoName == field.GoName { 250 _, structFieldTypeIsPointer = f.Type.(*types.Pointer) 251 break 252 } 253 } 254 goType := templates.CurrentImports.LookupType(field.Type) 255 if strings.HasPrefix(goType, "[]") { 256 getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType) 257 getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName) 258 getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName) 259 getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName) 260 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 261 getter += "&" 262 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 263 getter += "*" 264 } 265 getter += "concrete) }\n" 266 getter += "\treturn interfaceSlice\n" 267 getter += "}" 268 return getter 269 } else { 270 getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType) 271 272 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 273 getter += "&" 274 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 275 getter += "*" 276 } 277 278 getter += fmt.Sprintf("this.%s }", field.GoName) 279 return getter 280 } 281 } 282 funcMap := template.FuncMap{ 283 "getInterfaceByName": getInterfaceByName, 284 "generateGetter": generateGetter, 285 } 286 newModelTemplate := modelTemplate 287 if cfg.Model.ModelTemplate != "" { 288 newModelTemplate = readModelTemplate(cfg.Model.ModelTemplate) 289 } 290 291 err := templates.Render(templates.Options{ 292 PackageName: cfg.Model.Package, 293 Filename: cfg.Model.Filename, 294 Data: b, 295 GeneratedHeader: true, 296 Packages: cfg.Packages, 297 Template: newModelTemplate, 298 Funcs: funcMap, 299 }) 300 if err != nil { 301 return err 302 } 303 304 // We may have generated code in a package we already loaded, so we reload all packages 305 // to allow packages to be compared correctly 306 cfg.ReloadAllPackages() 307 308 return nil 309 } 310 311 func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) { 312 binder := cfg.NewBinder() 313 fields := make([]*Field, 0) 314 315 var omittableType types.Type 316 317 for _, field := range schemaType.Fields { 318 var typ types.Type 319 fieldDef := cfg.Schema.Types[field.Type.Name()] 320 321 if cfg.Models.UserDefined(field.Type.Name()) { 322 var err error 323 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 324 if err != nil { 325 return nil, err 326 } 327 } else { 328 switch fieldDef.Kind { 329 case ast.Scalar: 330 // no user defined model, referencing a default scalar 331 typ = types.NewNamed( 332 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 333 nil, 334 nil, 335 ) 336 337 case ast.Interface, ast.Union: 338 // no user defined model, referencing a generated interface type 339 typ = types.NewNamed( 340 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 341 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 342 nil, 343 ) 344 345 case ast.Enum: 346 // no user defined model, must reference a generated enum 347 typ = types.NewNamed( 348 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 349 nil, 350 nil, 351 ) 352 353 case ast.Object, ast.InputObject: 354 // no user defined model, must reference a generated struct 355 typ = types.NewNamed( 356 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 357 types.NewStruct(nil, nil), 358 nil, 359 ) 360 361 default: 362 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 363 } 364 } 365 366 name := templates.ToGo(field.Name) 367 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 368 name = nameOveride 369 } 370 371 typ = binder.CopyModifiersFromAst(field.Type, typ) 372 373 if cfg.StructFieldsAlwaysPointers { 374 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 375 typ = types.NewPointer(typ) 376 } 377 } 378 379 f := &Field{ 380 Name: field.Name, 381 GoName: name, 382 Type: typ, 383 Description: field.Description, 384 Tag: getStructTagFromField(cfg, field), 385 Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull, 386 } 387 388 if m.FieldHook != nil { 389 mf, err := m.FieldHook(schemaType, field, f) 390 if err != nil { 391 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 392 } 393 f = mf 394 } 395 396 if f.Omittable { 397 if schemaType.Kind != ast.InputObject || field.Type.NonNull { 398 return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name) 399 } 400 401 var err error 402 403 if omittableType == nil { 404 omittableType, err = binder.FindTypeFromName("github.com/geneva/gqlgen/graphql.Omittable") 405 if err != nil { 406 return nil, err 407 } 408 } 409 410 f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type}) 411 if err != nil { 412 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 413 } 414 } 415 416 fields = append(fields, f) 417 } 418 419 // appending extra fields at the end of the fields list. 420 modelcfg := cfg.Models[schemaType.Name] 421 if len(modelcfg.ExtraFields) > 0 { 422 ff := make([]*Field, 0, len(modelcfg.ExtraFields)) 423 for fname, fspec := range modelcfg.ExtraFields { 424 ftype := buildType(fspec.Type) 425 426 tag := `json:"-"` 427 if fspec.OverrideTags != "" { 428 tag = fspec.OverrideTags 429 } 430 431 ff = append(ff, 432 &Field{ 433 Name: fname, 434 GoName: fname, 435 Type: ftype, 436 Description: fspec.Description, 437 Tag: tag, 438 }) 439 } 440 441 sort.Slice(ff, func(i, j int) bool { 442 return ff[i].Name < ff[j].Name 443 }) 444 445 fields = append(fields, ff...) 446 } 447 448 return fields, nil 449 } 450 451 func getStructTagFromField(cfg *config.Config, field *ast.FieldDefinition) string { 452 if !field.Type.NonNull && (cfg.EnableModelJsonOmitemptyTag == nil || *cfg.EnableModelJsonOmitemptyTag) { 453 return `json:"` + field.Name + `,omitempty"` 454 } 455 return `json:"` + field.Name + `"` 456 } 457 458 // GoTagFieldHook prepends the goTag directive to the generated Field f. 459 // When applying the Tag to the field, the field 460 // name is used if no value argument is present. 461 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 462 args := make([]string, 0) 463 for _, goTag := range fd.Directives.ForNames("goTag") { 464 key := "" 465 value := fd.Name 466 467 if arg := goTag.Arguments.ForName("key"); arg != nil { 468 if k, err := arg.Value.Value(nil); err == nil { 469 key = k.(string) 470 } 471 } 472 473 if arg := goTag.Arguments.ForName("value"); arg != nil { 474 if v, err := arg.Value.Value(nil); err == nil { 475 value = v.(string) 476 } 477 } 478 479 args = append(args, key+":\""+value+"\"") 480 } 481 482 if len(args) > 0 { 483 f.Tag = removeDuplicateTags(f.Tag + " " + strings.Join(args, " ")) 484 } 485 486 return f, nil 487 } 488 489 // splitTagsBySpace split tags by space, except when space is inside quotes 490 func splitTagsBySpace(tagsString string) []string { 491 var tags []string 492 var currentTag string 493 inQuotes := false 494 495 for _, c := range tagsString { 496 if c == '"' { 497 inQuotes = !inQuotes 498 } 499 if c == ' ' && !inQuotes { 500 tags = append(tags, currentTag) 501 currentTag = "" 502 } else { 503 currentTag += string(c) 504 } 505 } 506 tags = append(tags, currentTag) 507 508 return tags 509 } 510 511 // containsInvalidSpace checks if the tagsString contains invalid space 512 func containsInvalidSpace(valuesString string) bool { 513 // get rid of quotes 514 valuesString = strings.ReplaceAll(valuesString, "\"", "") 515 if strings.Contains(valuesString, ",") { 516 // split by comma, 517 values := strings.Split(valuesString, ",") 518 for _, value := range values { 519 if strings.TrimSpace(value) != value { 520 return true 521 } 522 } 523 return false 524 } 525 if strings.Contains(valuesString, ";") { 526 // split by semicolon, which is common in gorm 527 values := strings.Split(valuesString, ";") 528 for _, value := range values { 529 if strings.TrimSpace(value) != value { 530 return true 531 } 532 } 533 return false 534 } 535 // single value 536 if strings.TrimSpace(valuesString) != valuesString { 537 return true 538 } 539 return false 540 } 541 542 func removeDuplicateTags(t string) string { 543 processed := make(map[string]bool) 544 tt := splitTagsBySpace(t) 545 returnTags := "" 546 547 // iterate backwards through tags so appended goTag directives are prioritized 548 for i := len(tt) - 1; i >= 0; i-- { 549 ti := tt[i] 550 // check if ti contains ":", and not contains any empty space. if not, tag is in wrong format 551 // correct example: json:"name" 552 if !strings.Contains(ti, ":") { 553 panic(fmt.Errorf("wrong format of tags: %s. goTag directive should be in format: @goTag(key: \"something\", value:\"value\"), ", t)) 554 } 555 556 kv := strings.Split(ti, ":") 557 if len(kv) == 0 || processed[kv[0]] { 558 continue 559 } 560 561 processed[kv[0]] = true 562 if len(returnTags) > 0 { 563 returnTags = " " + returnTags 564 } 565 566 isContained := containsInvalidSpace(kv[1]) 567 if isContained { 568 panic(fmt.Errorf("tag value should not contain any leading or trailing spaces: %s", kv[1])) 569 } 570 571 returnTags = kv[0] + ":" + kv[1] + returnTags 572 } 573 574 return returnTags 575 } 576 577 // GoFieldHook applies the goField directive to the generated Field f. 578 func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 579 args := make([]string, 0) 580 _ = args 581 for _, goField := range fd.Directives.ForNames("goField") { 582 if arg := goField.Arguments.ForName("name"); arg != nil { 583 if k, err := arg.Value.Value(nil); err == nil { 584 f.GoName = k.(string) 585 } 586 } 587 588 if arg := goField.Arguments.ForName("omittable"); arg != nil { 589 if k, err := arg.Value.Value(nil); err == nil { 590 f.Omittable = k.(bool) 591 } 592 } 593 } 594 return f, nil 595 } 596 597 func isStruct(t types.Type) bool { 598 _, is := t.Underlying().(*types.Struct) 599 return is 600 } 601 602 // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them 603 // with pointers. These relationships will produce compilation errors if they are not pointers. 604 // Also handles recursive structs. 605 func findAndHandleCyclicalRelationships(b *ModelBuild) { 606 for ii, structA := range b.Models { 607 for _, fieldA := range structA.Fields { 608 if strings.Contains(fieldA.Type.String(), "NotCyclicalA") { 609 fmt.Print() 610 } 611 if !isStruct(fieldA.Type) { 612 continue 613 } 614 615 // the field Type string will be in the form "github.com/geneva/gqlgen/codegen/testserver/followschema.LoopA" 616 // we only want the part after the last dot: "LoopA" 617 // this could lead to false positives, as we are only checking the name of the struct type, but these 618 // should be extremely rare, if it is even possible at all. 619 fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".") 620 fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1] 621 622 // find this struct type amongst the generated structs 623 for jj, structB := range b.Models { 624 if structB.Name != fieldAStructName { 625 continue 626 } 627 628 // check if structB contains a cyclical reference back to structA 629 var cyclicalReferenceFound bool 630 for _, fieldB := range structB.Fields { 631 if !isStruct(fieldB.Type) { 632 continue 633 } 634 635 fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".") 636 fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1] 637 if fieldBStructName == structA.Name { 638 cyclicalReferenceFound = true 639 fieldB.Type = types.NewPointer(fieldB.Type) 640 // keep looping in case this struct has additional fields of this type 641 } 642 } 643 644 // if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once 645 if cyclicalReferenceFound && ii != jj { 646 fieldA.Type = types.NewPointer(fieldA.Type) 647 break 648 } 649 } 650 } 651 } 652 } 653 654 func readModelTemplate(customModelTemplate string) string { 655 contentBytes, err := os.ReadFile(customModelTemplate) 656 if err != nil { 657 panic(err) 658 } 659 return string(contentBytes) 660 }