github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 _ "embed" 5 "fmt" 6 "go/types" 7 "sort" 8 "strings" 9 "text/template" 10 11 "github.com/spread-ai/gqlgen/codegen/config" 12 "github.com/spread-ai/gqlgen/codegen/templates" 13 "github.com/spread-ai/gqlgen/plugin" 14 "github.com/vektah/gqlparser/v2/ast" 15 ) 16 17 //go:embed models.gotpl 18 var modelTemplate string 19 20 type BuildMutateHook = func(b *ModelBuild) *ModelBuild 21 22 type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 23 24 // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook. 25 func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 26 return GoTagFieldHook(td, fd, f) 27 } 28 29 func defaultBuildMutateHook(b *ModelBuild) *ModelBuild { 30 return b 31 } 32 33 type ModelBuild struct { 34 PackageName string 35 Interfaces []*Interface 36 Models []*Object 37 Enums []*Enum 38 Scalars []string 39 } 40 41 type Interface struct { 42 Description string 43 Name string 44 Fields []*Field 45 Implements []string 46 } 47 48 type Object struct { 49 Description string 50 Name string 51 Fields []*Field 52 Implements []string 53 } 54 55 type Field struct { 56 Description string 57 // Name is the field's name as it appears in the schema 58 Name string 59 // GoName is the field's name as it appears in the generated Go code 60 GoName string 61 Type types.Type 62 Tag string 63 } 64 65 type Enum struct { 66 Description string 67 Name string 68 Values []*EnumValue 69 } 70 71 type EnumValue struct { 72 Description string 73 Name string 74 } 75 76 func New() plugin.Plugin { 77 return &Plugin{ 78 MutateHook: defaultBuildMutateHook, 79 FieldHook: defaultFieldMutateHook, 80 } 81 } 82 83 type Plugin struct { 84 MutateHook BuildMutateHook 85 FieldHook FieldMutateHook 86 } 87 88 var _ plugin.ConfigMutator = &Plugin{} 89 90 func (m *Plugin) Name() string { 91 return "modelgen" 92 } 93 94 func (m *Plugin) MutateConfig(cfg *config.Config) error { 95 96 b := &ModelBuild{ 97 PackageName: cfg.Model.Package, 98 } 99 100 for _, schemaType := range cfg.Schema.Types { 101 if cfg.Models.UserDefined(schemaType.Name) { 102 continue 103 } 104 switch schemaType.Kind { 105 case ast.Interface, ast.Union: 106 var fields []*Field 107 var err error 108 if !cfg.OmitGetters { 109 fields, err = m.generateFields(cfg, schemaType) 110 if err != nil { 111 return err 112 } 113 } 114 115 it := &Interface{ 116 Description: schemaType.Description, 117 Name: schemaType.Name, 118 Implements: schemaType.Interfaces, 119 Fields: fields, 120 } 121 122 b.Interfaces = append(b.Interfaces, it) 123 case ast.Object, ast.InputObject: 124 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 125 continue 126 } 127 128 fields, err := m.generateFields(cfg, schemaType) 129 if err != nil { 130 return err 131 } 132 133 it := &Object{ 134 Description: schemaType.Description, 135 Name: schemaType.Name, 136 Fields: fields, 137 } 138 139 // If Interface A implements interface B, and Interface C also implements interface B 140 // then both A and C have methods of B. 141 // The reason for checking unique is to prevent the same method B from being generated twice. 142 uniqueMap := map[string]bool{} 143 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 144 if !uniqueMap[implementor.Name] { 145 it.Implements = append(it.Implements, implementor.Name) 146 uniqueMap[implementor.Name] = true 147 } 148 // for interface implements 149 for _, iface := range implementor.Interfaces { 150 if !uniqueMap[iface] { 151 it.Implements = append(it.Implements, iface) 152 uniqueMap[iface] = true 153 } 154 } 155 } 156 157 b.Models = append(b.Models, it) 158 case ast.Enum: 159 it := &Enum{ 160 Name: schemaType.Name, 161 Description: schemaType.Description, 162 } 163 164 for _, v := range schemaType.EnumValues { 165 it.Values = append(it.Values, &EnumValue{ 166 Name: v.Name, 167 Description: v.Description, 168 }) 169 } 170 171 b.Enums = append(b.Enums, it) 172 case ast.Scalar: 173 b.Scalars = append(b.Scalars, schemaType.Name) 174 } 175 } 176 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 177 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 178 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 179 180 // if we are not just turning all struct-type fields in generated structs into pointers, we need to at least 181 // check for cyclical relationships and recursive structs 182 if !cfg.StructFieldsAlwaysPointers { 183 findAndHandleCyclicalRelationships(b) 184 } 185 186 for _, it := range b.Enums { 187 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 188 } 189 for _, it := range b.Models { 190 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 191 } 192 for _, it := range b.Interfaces { 193 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 194 } 195 for _, it := range b.Scalars { 196 cfg.Models.Add(it, "github.com/spread-ai/gqlgen/graphql.String") 197 } 198 199 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 200 return nil 201 } 202 203 if m.MutateHook != nil { 204 b = m.MutateHook(b) 205 } 206 207 getInterfaceByName := func(name string) *Interface { 208 // Allow looking up interfaces, so template can generate getters for each field 209 for _, i := range b.Interfaces { 210 if i.Name == name { 211 return i 212 } 213 } 214 215 return nil 216 } 217 gettersGenerated := make(map[string]map[string]struct{}) 218 generateGetter := func(model *Object, field *Field) string { 219 if model == nil || field == nil { 220 return "" 221 } 222 223 // Let templates check if a given getter has been generated already 224 typeGetters, exists := gettersGenerated[model.Name] 225 if !exists { 226 typeGetters = make(map[string]struct{}) 227 gettersGenerated[model.Name] = typeGetters 228 } 229 230 _, exists = typeGetters[field.GoName] 231 typeGetters[field.GoName] = struct{}{} 232 if exists { 233 return "" 234 } 235 236 _, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer) 237 var structFieldTypeIsPointer bool 238 for _, f := range model.Fields { 239 if f.GoName == field.GoName { 240 _, structFieldTypeIsPointer = f.Type.(*types.Pointer) 241 break 242 } 243 } 244 goType := templates.CurrentImports.LookupType(field.Type) 245 if strings.HasPrefix(goType, "[]") { 246 getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType) 247 getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName) 248 getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName) 249 getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName) 250 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 251 getter += "&" 252 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 253 getter += "*" 254 } 255 getter += "concrete) }\n" 256 getter += "\treturn interfaceSlice\n" 257 getter += "}" 258 return getter 259 } else { 260 getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType) 261 262 if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer { 263 getter += "&" 264 } else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer { 265 getter += "*" 266 } 267 268 getter += fmt.Sprintf("this.%s }", field.GoName) 269 return getter 270 } 271 } 272 funcMap := template.FuncMap{ 273 "getInterfaceByName": getInterfaceByName, 274 "generateGetter": generateGetter, 275 } 276 277 err := templates.Render(templates.Options{ 278 PackageName: cfg.Model.Package, 279 Filename: cfg.Model.Filename, 280 Data: b, 281 GeneratedHeader: true, 282 Packages: cfg.Packages, 283 Template: modelTemplate, 284 Funcs: funcMap, 285 }) 286 if err != nil { 287 return err 288 } 289 290 // We may have generated code in a package we already loaded, so we reload all packages 291 // to allow packages to be compared correctly 292 cfg.ReloadAllPackages() 293 294 return nil 295 } 296 297 func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) { 298 binder := cfg.NewBinder() 299 fields := make([]*Field, 0) 300 301 for _, field := range schemaType.Fields { 302 var typ types.Type 303 fieldDef := cfg.Schema.Types[field.Type.Name()] 304 305 if cfg.Models.UserDefined(field.Type.Name()) { 306 var err error 307 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 308 if err != nil { 309 return nil, err 310 } 311 } else { 312 switch fieldDef.Kind { 313 case ast.Scalar: 314 // no user defined model, referencing a default scalar 315 typ = types.NewNamed( 316 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 317 nil, 318 nil, 319 ) 320 321 case ast.Interface, ast.Union: 322 // no user defined model, referencing a generated interface type 323 typ = types.NewNamed( 324 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 325 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 326 nil, 327 ) 328 329 case ast.Enum: 330 // no user defined model, must reference a generated enum 331 typ = types.NewNamed( 332 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 333 nil, 334 nil, 335 ) 336 337 case ast.Object, ast.InputObject: 338 // no user defined model, must reference a generated struct 339 typ = types.NewNamed( 340 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 341 types.NewStruct(nil, nil), 342 nil, 343 ) 344 345 default: 346 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 347 } 348 } 349 350 name := templates.ToGo(field.Name) 351 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 352 name = nameOveride 353 } 354 355 typ = binder.CopyModifiersFromAst(field.Type, typ) 356 357 if cfg.StructFieldsAlwaysPointers { 358 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 359 typ = types.NewPointer(typ) 360 } 361 } 362 363 f := &Field{ 364 Name: field.Name, 365 GoName: name, 366 Type: typ, 367 Description: field.Description, 368 Tag: `json:"` + field.Name + `"`, 369 } 370 371 if m.FieldHook != nil { 372 mf, err := m.FieldHook(schemaType, field, f) 373 if err != nil { 374 return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) 375 } 376 f = mf 377 } 378 379 fields = append(fields, f) 380 } 381 382 return fields, nil 383 } 384 385 // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field 386 // name is used when no value argument is present. 387 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 388 args := make([]string, 0) 389 for _, goTag := range fd.Directives.ForNames("goTag") { 390 key := "" 391 value := fd.Name 392 393 if arg := goTag.Arguments.ForName("key"); arg != nil { 394 if k, err := arg.Value.Value(nil); err == nil { 395 key = k.(string) 396 } 397 } 398 399 if arg := goTag.Arguments.ForName("value"); arg != nil { 400 if v, err := arg.Value.Value(nil); err == nil { 401 value = v.(string) 402 } 403 } 404 405 if key == "json" { 406 if value == "omitempty" { 407 f.Tag = strings.ReplaceAll(f.Tag, `json:"`+f.Name+`"`, `json:"`+f.Name+`,omitempty"`) 408 } else { 409 f.Tag = strings.ReplaceAll(f.Tag, `json:"`+f.Name+`"`, "") 410 args = append(args, key+":\""+value+"\"") 411 } 412 } else { 413 args = append(args, key+":\""+value+"\"") 414 } 415 } 416 417 if len(args) > 0 { 418 f.Tag = f.Tag + " " + strings.Join(args, " ") 419 } 420 421 return f, nil 422 } 423 424 func isStruct(t types.Type) bool { 425 _, is := t.Underlying().(*types.Struct) 426 return is 427 } 428 429 // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them 430 // with pointers. These relationships will produce compilation errors if they are not pointers. 431 // Also handles recursive structs. 432 func findAndHandleCyclicalRelationships(b *ModelBuild) { 433 for ii, structA := range b.Models { 434 for _, fieldA := range structA.Fields { 435 if strings.Contains(fieldA.Type.String(), "NotCyclicalA") { 436 fmt.Print() 437 } 438 if !isStruct(fieldA.Type) { 439 continue 440 } 441 442 // the field Type string will be in the form "github.com/spread-ai/gqlgen/codegen/testserver/followschema.LoopA" 443 // we only want the part after the last dot: "LoopA" 444 // this could lead to false positives, as we are only checking the name of the struct type, but these 445 // should be extremely rare, if it is even possible at all. 446 fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".") 447 fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1] 448 449 // find this struct type amongst the generated structs 450 for jj, structB := range b.Models { 451 if structB.Name != fieldAStructName { 452 continue 453 } 454 455 // check if structB contains a cyclical reference back to structA 456 var cyclicalReferenceFound bool 457 for _, fieldB := range structB.Fields { 458 if !isStruct(fieldB.Type) { 459 continue 460 } 461 462 fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".") 463 fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1] 464 if fieldBStructName == structA.Name { 465 cyclicalReferenceFound = true 466 fieldB.Type = types.NewPointer(fieldB.Type) 467 // keep looping in case this struct has additional fields of this type 468 } 469 } 470 471 // if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once 472 if cyclicalReferenceFound && ii != jj { 473 fieldA.Type = types.NewPointer(fieldA.Type) 474 break 475 } 476 } 477 } 478 } 479 }