github.com/99designs/gqlgen@v0.17.45/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 _ "embed" 5 "fmt" 6 "go/types" 7 "os" 8 "sort" 9 "strings" 10 "text/template" 11 12 "github.com/vektah/gqlparser/v2/ast" 13 14 "github.com/99designs/gqlgen/codegen/config" 15 "github.com/99designs/gqlgen/codegen/templates" 16 "github.com/99designs/gqlgen/plugin" 17 ) 18 19 //go:embed models.gotpl 20 var modelTemplate string 21 22 type ( 23 BuildMutateHook = func(b *ModelBuild) *ModelBuild 24 FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 25 ) 26 27 // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook. 28 func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 29 var err error 30 f, err = GoFieldHook(td, fd, f) 31 if err != nil { 32 return f, err 33 } 34 return GoTagFieldHook(td, fd, f) 35 } 36 37 // DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild. 38 func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild { 39 return b 40 } 41 42 type ModelBuild struct { 43 PackageName string 44 Interfaces []*Interface 45 Models []*Object 46 Enums []*Enum 47 Scalars []string 48 } 49 50 type Interface struct { 51 Description string 52 Name string 53 Fields []*Field 54 Implements []string 55 OmitCheck bool 56 Models []*Object 57 } 58 59 type Object struct { 60 Description string 61 Name string 62 Fields []*Field 63 Implements []string 64 } 65 66 type Field struct { 67 Description string 68 // Name is the field's name as it appears in the schema 69 Name string 70 // GoName is the field's name as it appears in the generated Go code 71 GoName string 72 Type types.Type 73 Tag string 74 IsResolver bool 75 Omittable bool 76 } 77 78 type Enum struct { 79 Description string 80 Name string 81 Values []*EnumValue 82 } 83 84 type EnumValue struct { 85 Description string 86 Name string 87 } 88 89 func New() plugin.Plugin { 90 return &Plugin{ 91 MutateHook: DefaultBuildMutateHook, 92 FieldHook: DefaultFieldMutateHook, 93 } 94 } 95 96 type Plugin struct { 97 MutateHook BuildMutateHook 98 FieldHook FieldMutateHook 99 } 100 101 var _ plugin.ConfigMutator = &Plugin{} 102 103 func (m *Plugin) Name() string { 104 return "modelgen" 105 } 106 107 func (m *Plugin) MutateConfig(cfg *config.Config) error { 108 b := &ModelBuild{ 109 PackageName: cfg.Model.Package, 110 } 111 112 for _, schemaType := range cfg.Schema.Types { 113 if cfg.Models.UserDefined(schemaType.Name) { 114 continue 115 } 116 switch schemaType.Kind { 117 case ast.Interface, ast.Union: 118 var fields []*Field 119 var err error 120 if !cfg.OmitGetters { 121 fields, err = m.generateFields(cfg, schemaType) 122 if err != nil { 123 return err 124 } 125 } 126 127 it := &Interface{ 128 Description: schemaType.Description, 129 Name: schemaType.Name, 130 Implements: schemaType.Interfaces, 131 Fields: fields, 132 OmitCheck: cfg.OmitInterfaceChecks, 133 } 134 135 // if the interface has a key directive as an entity interface, allow it to implement _Entity 136 if schemaType.Directives.ForName("key") != nil { 137 it.Implements = append(it.Implements, "_Entity") 138 } 139 140 b.Interfaces = append(b.Interfaces, it) 141 case ast.Object, ast.InputObject: 142 if cfg.IsRoot(schemaType) { 143 if !cfg.OmitRootModels { 144 b.Models = append(b.Models, &Object{ 145 Description: schemaType.Description, 146 Name: schemaType.Name, 147 }) 148 } 149 continue 150 } 151 152 fields, err := m.generateFields(cfg, schemaType) 153 if err != nil { 154 return err 155 } 156 157 it := &Object{ 158 Description: schemaType.Description, 159 Name: schemaType.Name, 160 Fields: fields, 161 } 162 163 // If Interface A implements interface B, and Interface C also implements interface B 164 // then both A and C have methods of B. 165 // The reason for checking unique is to prevent the same method B from being generated twice. 166 uniqueMap := map[string]bool{} 167 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 168 if !uniqueMap[implementor.Name] { 169 it.Implements = append(it.Implements, implementor.Name) 170 uniqueMap[implementor.Name] = true 171 } 172 // for interface implements 173 for _, iface := range implementor.Interfaces { 174 if !uniqueMap[iface] { 175 it.Implements = append(it.Implements, iface) 176 uniqueMap[iface] = true 177 } 178 } 179 180 } 181 182 b.Models = append(b.Models, it) 183 case ast.Enum: 184 it := &Enum{ 185 Name: schemaType.Name, 186 Description: schemaType.Description, 187 } 188 189 for _, v := range schemaType.EnumValues { 190 it.Values = append(it.Values, &EnumValue{ 191 Name: v.Name, 192 Description: v.Description, 193 }) 194 } 195 196 b.Enums = append(b.Enums, it) 197 case ast.Scalar: 198 b.Scalars = append(b.Scalars, schemaType.Name) 199 } 200 } 201 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 202 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 203 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 204 205 // if we are not just turning all struct-type fields in generated structs into pointers, we need to at least 206 // check for cyclical relationships and recursive structs 207 if !cfg.StructFieldsAlwaysPointers { 208 findAndHandleCyclicalRelationships(b) 209 } 210 211 for _, it := range b.Enums { 212 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 213 } 214 for _, it := range b.Models { 215 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 216 } 217 for _, it := range b.Interfaces { 218 // On a given interface we want to keep a reference to all the models that implement it 219 for _, model := range b.Models { 220 for _, impl := range model.Implements { 221 if impl == it.Name { 222 // check if this isn't an implementation of an entity interface 223 if impl != "_Entity" { 224 // If this model has an implementation, add it to the Interface's Models 225 it.Models = append(it.Models, model) 226 } 227 } 228 } 229 } 230 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 231 } 232 for _, it := range b.Scalars { 233 cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String") 234 } 235 236 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 237 return nil 238 } 239 240 if m.MutateHook != nil { 241 b = m.MutateHook(b) 242 } 243 244 getInterfaceByName := func(name string) *Interface { 245 // Allow looking up interfaces, so template can generate getters for each field 246 for _, i := range b.Interfaces { 247 if i.Name == name { 248 return i 249 } 250 } 251 252 return nil 253 } 254 gettersGenerated := make(map[string]map[string]struct{}) 255 generateGetter := func(model *Object, field *Field) string { 256 if model == nil || field == nil { 257 return "" 258 } 259 260 // Let templates check if a given getter has been generated already 261 typeGetters, exists := gettersGenerated[model.Name] 262 if !exists { 263 typeGetters = make(map[string]struct{}) 264 gettersGenerated[model.Name] = typeGetters 265 } 266 267 _, exists = typeGetters[field.GoName] 268 typeGetters[field.GoName] = struct{}{} 269 if exists { 270 return "" 271 } 272 273 _, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer) 274 var structFieldTypeIsPointer bool 275 for _, f := range model.Fields { 276 if f.GoName == field.GoName { 277 _, structFieldTypeIsPointer = f.Type.(*types.Pointer) 278 break 279 } 280 } 281 goType := templates.CurrentImports.LookupType(field.Type) 282 if strings.HasPrefix(goType, "[]") { 283 getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType) 284 getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName) 285 getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName) 286 getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName) 287 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 288 getter += "&" 289 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 290 getter += "*" 291 } 292 getter += "concrete) }\n" 293 getter += "\treturn interfaceSlice\n" 294 getter += "}" 295 return getter 296 } else { 297 getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType) 298 299 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 300 getter += "&" 301 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 302 getter += "*" 303 } 304 305 getter += fmt.Sprintf("this.%s }", field.GoName) 306 return getter 307 } 308 } 309 funcMap := template.FuncMap{ 310 "getInterfaceByName": getInterfaceByName, 311 "generateGetter": generateGetter, 312 } 313 newModelTemplate := modelTemplate 314 if cfg.Model.ModelTemplate != "" { 315 newModelTemplate = readModelTemplate(cfg.Model.ModelTemplate) 316 } 317 318 err := templates.Render(templates.Options{ 319 PackageName: cfg.Model.Package, 320 Filename: cfg.Model.Filename, 321 Data: b, 322 GeneratedHeader: true, 323 Packages: cfg.Packages, 324 Template: newModelTemplate, 325 Funcs: funcMap, 326 }) 327 if err != nil { 328 return err 329 } 330 331 // We may have generated code in a package we already loaded, so we reload all packages 332 // to allow packages to be compared correctly 333 cfg.ReloadAllPackages() 334 335 return nil 336 } 337 338 func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) { 339 binder := cfg.NewBinder() 340 fields := make([]*Field, 0) 341 342 var omittableType types.Type 343 344 for _, field := range schemaType.Fields { 345 var typ types.Type 346 fieldDef := cfg.Schema.Types[field.Type.Name()] 347 348 if cfg.Models.UserDefined(field.Type.Name()) { 349 var err error 350 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 351 if err != nil { 352 return nil, err 353 } 354 } else { 355 switch fieldDef.Kind { 356 case ast.Scalar: 357 // no user defined model, referencing a default scalar 358 typ = types.NewNamed( 359 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 360 nil, 361 nil, 362 ) 363 364 case ast.Interface, ast.Union: 365 // no user defined model, referencing a generated interface type 366 typ = types.NewNamed( 367 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 368 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 369 nil, 370 ) 371 372 case ast.Enum: 373 // no user defined model, must reference a generated enum 374 typ = types.NewNamed( 375 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 376 nil, 377 nil, 378 ) 379 380 case ast.Object, ast.InputObject: 381 // no user defined model, must reference a generated struct 382 typ = types.NewNamed( 383 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 384 types.NewStruct(nil, nil), 385 nil, 386 ) 387 388 default: 389 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 390 } 391 } 392 393 name := templates.ToGo(field.Name) 394 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 395 name = nameOveride 396 } 397 398 typ = binder.CopyModifiersFromAst(field.Type, typ) 399 400 if cfg.StructFieldsAlwaysPointers { 401 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 402 typ = types.NewPointer(typ) 403 } 404 } 405 406 f := &Field{ 407 Name: field.Name, 408 GoName: name, 409 Type: typ, 410 Description: field.Description, 411 Tag: getStructTagFromField(cfg, field), 412 Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull, 413 } 414 415 if m.FieldHook != nil { 416 mf, err := m.FieldHook(schemaType, field, f) 417 if err != nil { 418 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 419 } 420 f = mf 421 } 422 423 if f.IsResolver && cfg.OmitResolverFields { 424 continue 425 } 426 427 if f.Omittable { 428 if schemaType.Kind != ast.InputObject || field.Type.NonNull { 429 return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name) 430 } 431 432 var err error 433 434 if omittableType == nil { 435 omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable") 436 if err != nil { 437 return nil, err 438 } 439 } 440 441 f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type}) 442 if err != nil { 443 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 444 } 445 } 446 447 fields = append(fields, f) 448 } 449 450 // appending extra fields at the end of the fields list. 451 modelcfg := cfg.Models[schemaType.Name] 452 if len(modelcfg.ExtraFields) > 0 { 453 ff := make([]*Field, 0, len(modelcfg.ExtraFields)) 454 for fname, fspec := range modelcfg.ExtraFields { 455 ftype := buildType(fspec.Type) 456 457 tag := `json:"-"` 458 if fspec.OverrideTags != "" { 459 tag = fspec.OverrideTags 460 } 461 462 ff = append(ff, 463 &Field{ 464 Name: fname, 465 GoName: fname, 466 Type: ftype, 467 Description: fspec.Description, 468 Tag: tag, 469 }) 470 } 471 472 sort.Slice(ff, func(i, j int) bool { 473 return ff[i].Name < ff[j].Name 474 }) 475 476 fields = append(fields, ff...) 477 } 478 479 return fields, nil 480 } 481 482 func getStructTagFromField(cfg *config.Config, field *ast.FieldDefinition) string { 483 if !field.Type.NonNull && (cfg.EnableModelJsonOmitemptyTag == nil || *cfg.EnableModelJsonOmitemptyTag) { 484 return `json:"` + field.Name + `,omitempty"` 485 } 486 return `json:"` + field.Name + `"` 487 } 488 489 // GoTagFieldHook prepends the goTag directive to the generated Field f. 490 // When applying the Tag to the field, the field 491 // name is used if no value argument is present. 492 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 493 args := make([]string, 0) 494 for _, goTag := range fd.Directives.ForNames("goTag") { 495 key := "" 496 value := fd.Name 497 498 if arg := goTag.Arguments.ForName("key"); arg != nil { 499 if k, err := arg.Value.Value(nil); err == nil { 500 key = k.(string) 501 } 502 } 503 504 if arg := goTag.Arguments.ForName("value"); arg != nil { 505 if v, err := arg.Value.Value(nil); err == nil { 506 value = v.(string) 507 } 508 } 509 510 args = append(args, key+":\""+value+"\"") 511 } 512 513 if len(args) > 0 { 514 f.Tag = removeDuplicateTags(f.Tag + " " + strings.Join(args, " ")) 515 } 516 517 return f, nil 518 } 519 520 // splitTagsBySpace split tags by space, except when space is inside quotes 521 func splitTagsBySpace(tagsString string) []string { 522 var tags []string 523 var currentTag string 524 inQuotes := false 525 526 for _, c := range tagsString { 527 if c == '"' { 528 inQuotes = !inQuotes 529 } 530 if c == ' ' && !inQuotes { 531 tags = append(tags, currentTag) 532 currentTag = "" 533 } else { 534 currentTag += string(c) 535 } 536 } 537 tags = append(tags, currentTag) 538 539 return tags 540 } 541 542 // containsInvalidSpace checks if the tagsString contains invalid space 543 func containsInvalidSpace(valuesString string) bool { 544 // get rid of quotes 545 valuesString = strings.ReplaceAll(valuesString, "\"", "") 546 if strings.Contains(valuesString, ",") { 547 // split by comma, 548 values := strings.Split(valuesString, ",") 549 for _, value := range values { 550 if strings.TrimSpace(value) != value { 551 return true 552 } 553 } 554 return false 555 } 556 if strings.Contains(valuesString, ";") { 557 // split by semicolon, which is common in gorm 558 values := strings.Split(valuesString, ";") 559 for _, value := range values { 560 if strings.TrimSpace(value) != value { 561 return true 562 } 563 } 564 return false 565 } 566 // single value 567 if strings.TrimSpace(valuesString) != valuesString { 568 return true 569 } 570 return false 571 } 572 573 func removeDuplicateTags(t string) string { 574 processed := make(map[string]bool) 575 tt := splitTagsBySpace(t) 576 returnTags := "" 577 578 // iterate backwards through tags so appended goTag directives are prioritized 579 for i := len(tt) - 1; i >= 0; i-- { 580 ti := tt[i] 581 // check if ti contains ":", and not contains any empty space. if not, tag is in wrong format 582 // correct example: json:"name" 583 if !strings.Contains(ti, ":") { 584 panic(fmt.Errorf("wrong format of tags: %s. goTag directive should be in format: @goTag(key: \"something\", value:\"value\"), ", t)) 585 } 586 587 kv := strings.Split(ti, ":") 588 if len(kv) == 0 || processed[kv[0]] { 589 continue 590 } 591 592 key := kv[0] 593 value := strings.Join(kv[1:], ":") 594 processed[key] = true 595 if len(returnTags) > 0 { 596 returnTags = " " + returnTags 597 } 598 599 isContained := containsInvalidSpace(value) 600 if isContained { 601 panic(fmt.Errorf("tag value should not contain any leading or trailing spaces: %s", value)) 602 } 603 604 returnTags = key + ":" + value + returnTags 605 } 606 607 return returnTags 608 } 609 610 // GoFieldHook applies the goField directive to the generated Field f. 611 func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 612 args := make([]string, 0) 613 _ = args 614 for _, goField := range fd.Directives.ForNames("goField") { 615 if arg := goField.Arguments.ForName("name"); arg != nil { 616 if k, err := arg.Value.Value(nil); err == nil { 617 f.GoName = k.(string) 618 } 619 } 620 621 if arg := goField.Arguments.ForName("forceResolver"); arg != nil { 622 if k, err := arg.Value.Value(nil); err == nil { 623 f.IsResolver = k.(bool) 624 } 625 } 626 627 if arg := goField.Arguments.ForName("omittable"); arg != nil { 628 if k, err := arg.Value.Value(nil); err == nil { 629 f.Omittable = k.(bool) 630 } 631 } 632 } 633 return f, nil 634 } 635 636 func isStruct(t types.Type) bool { 637 _, is := t.Underlying().(*types.Struct) 638 return is 639 } 640 641 // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them 642 // with pointers. These relationships will produce compilation errors if they are not pointers. 643 // Also handles recursive structs. 644 func findAndHandleCyclicalRelationships(b *ModelBuild) { 645 for ii, structA := range b.Models { 646 for _, fieldA := range structA.Fields { 647 if strings.Contains(fieldA.Type.String(), "NotCyclicalA") { 648 fmt.Print() 649 } 650 if !isStruct(fieldA.Type) { 651 continue 652 } 653 654 // the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA" 655 // we only want the part after the last dot: "LoopA" 656 // this could lead to false positives, as we are only checking the name of the struct type, but these 657 // should be extremely rare, if it is even possible at all. 658 fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".") 659 fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1] 660 661 // find this struct type amongst the generated structs 662 for jj, structB := range b.Models { 663 if structB.Name != fieldAStructName { 664 continue 665 } 666 667 // check if structB contains a cyclical reference back to structA 668 var cyclicalReferenceFound bool 669 for _, fieldB := range structB.Fields { 670 if !isStruct(fieldB.Type) { 671 continue 672 } 673 674 fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".") 675 fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1] 676 if fieldBStructName == structA.Name { 677 cyclicalReferenceFound = true 678 fieldB.Type = types.NewPointer(fieldB.Type) 679 // keep looping in case this struct has additional fields of this type 680 } 681 } 682 683 // if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once 684 if cyclicalReferenceFound && ii != jj { 685 fieldA.Type = types.NewPointer(fieldA.Type) 686 break 687 } 688 } 689 } 690 } 691 } 692 693 func readModelTemplate(customModelTemplate string) string { 694 contentBytes, err := os.ReadFile(customModelTemplate) 695 if err != nil { 696 panic(err) 697 } 698 return string(contentBytes) 699 }