github.com/senomas/gqlgen@v0.17.11-0.20220626120754-9aee61b0716a/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 "fmt" 5 "go/types" 6 "sort" 7 "strings" 8 9 "github.com/99designs/gqlgen/codegen/config" 10 "github.com/99designs/gqlgen/codegen/templates" 11 "github.com/99designs/gqlgen/plugin" 12 "github.com/vektah/gqlparser/v2/ast" 13 ) 14 15 type BuildMutateHook = func(b *ModelBuild) *ModelBuild 16 17 type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 18 19 // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook. 20 func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 21 return GoTagFieldHook(td, fd, f) 22 } 23 24 func defaultBuildMutateHook(b *ModelBuild) *ModelBuild { 25 return b 26 } 27 28 type ModelBuild struct { 29 PackageName string 30 Interfaces []*Interface 31 Models []*Object 32 Enums []*Enum 33 Scalars []string 34 } 35 36 type Interface struct { 37 Description string 38 Name string 39 Implements []string 40 } 41 42 type Object struct { 43 Description string 44 Name string 45 Fields []*Field 46 Implements []string 47 } 48 49 type Field struct { 50 Description string 51 GoName string 52 Type types.Type 53 Tag string 54 Ref string 55 } 56 57 type Enum struct { 58 Description string 59 Name string 60 Values []*EnumValue 61 } 62 63 type EnumValue struct { 64 Description string 65 Name string 66 } 67 68 func New() plugin.Plugin { 69 return &Plugin{ 70 MutateHook: defaultBuildMutateHook, 71 FieldHook: defaultFieldMutateHook, 72 } 73 } 74 75 type Plugin struct { 76 MutateHook BuildMutateHook 77 FieldHook FieldMutateHook 78 } 79 80 var _ plugin.ConfigMutator = &Plugin{} 81 82 func (m *Plugin) Name() string { 83 return "modelgen" 84 } 85 86 func (m *Plugin) MutateConfig(cfg *config.Config) error { 87 binder := cfg.NewBinder() 88 89 b := &ModelBuild{ 90 PackageName: cfg.Model.Package, 91 } 92 93 for _, schemaType := range cfg.Schema.Types { 94 if cfg.Models.UserDefined(schemaType.Name) { 95 continue 96 } 97 switch schemaType.Kind { 98 case ast.Interface, ast.Union: 99 it := &Interface{ 100 Description: schemaType.Description, 101 Name: schemaType.Name, 102 Implements: schemaType.Interfaces, 103 } 104 105 b.Interfaces = append(b.Interfaces, it) 106 case ast.Object, ast.InputObject: 107 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 108 continue 109 } 110 it := &Object{ 111 Description: schemaType.Description, 112 Name: schemaType.Name, 113 } 114 115 // If Interface A implements interface B, and Interface C also implements interface B 116 // then both A and C have methods of B. 117 // The reason for checking unique is to prevent the same method B from being generated twice. 118 uniqueMap := map[string]bool{} 119 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 120 if !uniqueMap[implementor.Name] { 121 it.Implements = append(it.Implements, implementor.Name) 122 uniqueMap[implementor.Name] = true 123 } 124 // for interface implements 125 for _, iface := range implementor.Interfaces { 126 if !uniqueMap[iface] { 127 it.Implements = append(it.Implements, iface) 128 uniqueMap[iface] = true 129 } 130 } 131 } 132 133 for _, field := range schemaType.Fields { 134 var typ types.Type 135 fieldDef := cfg.Schema.Types[field.Type.Name()] 136 137 if cfg.Models.UserDefined(field.Type.Name()) { 138 var err error 139 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 140 if err != nil { 141 return err 142 } 143 } else { 144 switch fieldDef.Kind { 145 case ast.Scalar: 146 // no user defined model, referencing a default scalar 147 typ = types.NewNamed( 148 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 149 nil, 150 nil, 151 ) 152 153 case ast.Interface, ast.Union: 154 // no user defined model, referencing a generated interface type 155 typ = types.NewNamed( 156 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 157 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 158 nil, 159 ) 160 161 case ast.Enum: 162 // no user defined model, must reference a generated enum 163 typ = types.NewNamed( 164 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 165 nil, 166 nil, 167 ) 168 169 case ast.Object, ast.InputObject: 170 // no user defined model, must reference a generated struct 171 typ = types.NewNamed( 172 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 173 types.NewStruct(nil, nil), 174 nil, 175 ) 176 177 default: 178 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 179 } 180 } 181 182 name := field.Name 183 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 184 name = nameOveride 185 } 186 187 typ = binder.CopyModifiersFromAst(field.Type, typ) 188 189 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 190 typ = types.NewPointer(typ) 191 } 192 193 ref := "" 194 tag := `json:"` + field.Name + `"` 195 if directive := field.Directives.ForName("gorm"); directive != nil { 196 arg := directive.Arguments.ForName("tag") 197 if arg != nil { 198 tag = fmt.Sprintf(`%s gorm:"%s"`, tag, arg.Value.Raw) 199 } 200 arg = directive.Arguments.ForName("ref") 201 if arg != nil { 202 arg2 := directive.Arguments.ForName("refTag") 203 if arg2 != nil { 204 ref = fmt.Sprintf("%s `json:\"-\" gorm:\"%s\"`", arg.Value.Raw, arg2.Value.Raw) 205 } else { 206 ref = fmt.Sprintf("%s `json:\"-\"`", arg.Value.Raw) 207 } 208 } 209 } 210 211 f := &Field{ 212 GoName: name, 213 Type: typ, 214 Description: field.Description, 215 Tag: tag, 216 Ref: ref, 217 } 218 219 if m.FieldHook != nil { 220 mf, err := m.FieldHook(schemaType, field, f) 221 if err != nil { 222 return fmt.Errorf("generror: field %v.%v: %w", it.Name, field.Name, err) 223 } 224 f = mf 225 } 226 227 it.Fields = append(it.Fields, f) 228 } 229 230 b.Models = append(b.Models, it) 231 case ast.Enum: 232 it := &Enum{ 233 Name: schemaType.Name, 234 Description: schemaType.Description, 235 } 236 237 for _, v := range schemaType.EnumValues { 238 it.Values = append(it.Values, &EnumValue{ 239 Name: v.Name, 240 Description: v.Description, 241 }) 242 } 243 244 b.Enums = append(b.Enums, it) 245 case ast.Scalar: 246 b.Scalars = append(b.Scalars, schemaType.Name) 247 } 248 } 249 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 250 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 251 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 252 253 // if we are not just turning all struct-type fields in generated structs into pointers, we need to at least 254 // check for cyclical relationships and recursive structs 255 if !cfg.StructFieldsAlwaysPointers { 256 findAndHandleCyclicalRelationships(b) 257 } 258 259 for _, it := range b.Enums { 260 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 261 } 262 for _, it := range b.Models { 263 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 264 } 265 for _, it := range b.Interfaces { 266 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 267 } 268 for _, it := range b.Scalars { 269 cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String") 270 } 271 272 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 273 return nil 274 } 275 276 if m.MutateHook != nil { 277 b = m.MutateHook(b) 278 } 279 280 err := templates.Render(templates.Options{ 281 PackageName: cfg.Model.Package, 282 Filename: cfg.Model.Filename, 283 Data: b, 284 GeneratedHeader: true, 285 Packages: cfg.Packages, 286 }) 287 if err != nil { 288 return err 289 } 290 291 // We may have generated code in a package we already loaded, so we reload all packages 292 // to allow packages to be compared correctly 293 cfg.ReloadAllPackages() 294 295 return nil 296 } 297 298 // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field 299 // name is used when no value argument is present. 300 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 301 args := make([]string, 0) 302 for _, goTag := range fd.Directives.ForNames("goTag") { 303 key := "" 304 value := fd.Name 305 306 if arg := goTag.Arguments.ForName("key"); arg != nil { 307 if k, err := arg.Value.Value(nil); err == nil { 308 key = k.(string) 309 } 310 } 311 312 if arg := goTag.Arguments.ForName("value"); arg != nil { 313 if v, err := arg.Value.Value(nil); err == nil { 314 value = v.(string) 315 } 316 } 317 318 args = append(args, key+":\""+value+"\"") 319 } 320 321 if len(args) > 0 { 322 f.Tag = f.Tag + " " + strings.Join(args, " ") 323 } 324 325 return f, nil 326 } 327 328 func isStruct(t types.Type) bool { 329 _, is := t.Underlying().(*types.Struct) 330 return is 331 } 332 333 // findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them 334 // with pointers. These relationships will produce compilation errors if they are not pointers. 335 // Also handles recursive structs. 336 func findAndHandleCyclicalRelationships(b *ModelBuild) { 337 for ii, structA := range b.Models { 338 for _, fieldA := range structA.Fields { 339 if strings.Contains(fieldA.Type.String(), "NotCyclicalA") { 340 fmt.Print() 341 } 342 if !isStruct(fieldA.Type) { 343 continue 344 } 345 346 // the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA" 347 // we only want the part after the last dot: "LoopA" 348 // this could lead to false positives, as we are only checking the name of the struct type, but these 349 // should be extremely rare, if it is even possible at all. 350 fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".") 351 fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1] 352 353 // find this struct type amongst the generated structs 354 for jj, structB := range b.Models { 355 if structB.Name != fieldAStructName { 356 continue 357 } 358 // check if structB contains a cyclical reference back to structA 359 var cyclicalReferenceFound bool 360 for _, fieldB := range structB.Fields { 361 if !isStruct(fieldB.Type) { 362 continue 363 } 364 365 fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".") 366 fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1] 367 if fieldBStructName == structA.Name { 368 cyclicalReferenceFound = true 369 fieldB.Type = types.NewPointer(fieldB.Type) 370 // keep looping in case this struct has additional fields of this type 371 } 372 } 373 374 // if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once 375 if cyclicalReferenceFound && ii != jj { 376 fieldA.Type = types.NewPointer(fieldA.Type) 377 break 378 } 379 } 380 } 381 } 382 } 383 384 func (r *Plugin) InjectSourceEarly() *ast.Source { 385 return &ast.Source{ 386 Name: "plugin/directives.graphql", 387 Input: ` 388 directive @goField( 389 forceResolver: Boolean 390 name: String 391 ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION 392 393 directive @gorm(tag: String, ref: String, refTag: String) on FIELD_DEFINITION 394 395 scalar Time 396 `, 397 BuiltIn: false, 398 } 399 }