github.com/operandinc/gqlgen@v0.16.1/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 "fmt" 5 "go/types" 6 "sort" 7 "strings" 8 9 "github.com/operandinc/gqlgen/codegen/config" 10 "github.com/operandinc/gqlgen/codegen/templates" 11 "github.com/operandinc/gqlgen/plugin" 12 "github.com/vektah/gqlparser/v2/ast" 13 ) 14 15 type BuildMutateHook = func(b *ModelBuild) *ModelBuild 16 17 type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) 18 19 // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook. 20 func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 21 return GoTagFieldHook(td, fd, f) 22 } 23 24 func defaultBuildMutateHook(b *ModelBuild) *ModelBuild { 25 return b 26 } 27 28 type ModelBuild struct { 29 PackageName string 30 Interfaces []*Interface 31 Models []*Object 32 Enums []*Enum 33 Scalars []string 34 } 35 36 type Interface struct { 37 Description string 38 Name string 39 } 40 41 type Object struct { 42 Description string 43 Name string 44 Fields []*Field 45 Implements []string 46 } 47 48 type Field struct { 49 Description string 50 Name string 51 Type types.Type 52 Tag string 53 } 54 55 type Enum struct { 56 Description string 57 Name string 58 Values []*EnumValue 59 } 60 61 type EnumValue struct { 62 Description string 63 Name string 64 } 65 66 func New() plugin.Plugin { 67 return &Plugin{ 68 MutateHook: defaultBuildMutateHook, 69 FieldHook: defaultFieldMutateHook, 70 } 71 } 72 73 type Plugin struct { 74 MutateHook BuildMutateHook 75 FieldHook FieldMutateHook 76 } 77 78 var _ plugin.ConfigMutator = &Plugin{} 79 80 func (m *Plugin) Name() string { 81 return "modelgen" 82 } 83 84 func (m *Plugin) MutateConfig(cfg *config.Config) error { 85 binder := cfg.NewBinder() 86 87 b := &ModelBuild{ 88 PackageName: cfg.Model.Package, 89 } 90 91 for _, schemaType := range cfg.Schema.Types { 92 if cfg.Models.UserDefined(schemaType.Name) { 93 continue 94 } 95 switch schemaType.Kind { 96 case ast.Interface, ast.Union: 97 it := &Interface{ 98 Description: schemaType.Description, 99 Name: schemaType.Name, 100 } 101 102 b.Interfaces = append(b.Interfaces, it) 103 case ast.Object, ast.InputObject: 104 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 105 continue 106 } 107 it := &Object{ 108 Description: schemaType.Description, 109 Name: schemaType.Name, 110 } 111 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 112 it.Implements = append(it.Implements, implementor.Name) 113 } 114 115 for _, field := range schemaType.Fields { 116 var typ types.Type 117 fieldDef := cfg.Schema.Types[field.Type.Name()] 118 119 if cfg.Models.UserDefined(field.Type.Name()) { 120 var err error 121 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 122 if err != nil { 123 return err 124 } 125 } else { 126 switch fieldDef.Kind { 127 case ast.Scalar: 128 // no user defined model, referencing a default scalar 129 typ = types.NewNamed( 130 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 131 nil, 132 nil, 133 ) 134 135 case ast.Interface, ast.Union: 136 // no user defined model, referencing a generated interface type 137 typ = types.NewNamed( 138 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 139 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 140 nil, 141 ) 142 143 case ast.Enum: 144 // no user defined model, must reference a generated enum 145 typ = types.NewNamed( 146 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 147 nil, 148 nil, 149 ) 150 151 case ast.Object, ast.InputObject: 152 // no user defined model, must reference a generated struct 153 typ = types.NewNamed( 154 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 155 types.NewStruct(nil, nil), 156 nil, 157 ) 158 159 default: 160 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 161 } 162 } 163 164 name := field.Name 165 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 166 name = nameOveride 167 } 168 169 typ = binder.CopyModifiersFromAst(field.Type, typ) 170 171 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 172 typ = types.NewPointer(typ) 173 } 174 175 f := &Field{ 176 Name: name, 177 Type: typ, 178 Description: field.Description, 179 Tag: `json:"` + field.Name + `"`, 180 } 181 182 if m.FieldHook != nil { 183 mf, err := m.FieldHook(schemaType, field, f) 184 if err != nil { 185 return fmt.Errorf("generror: field %v.%v: %w", it.Name, field.Name, err) 186 } 187 f = mf 188 } 189 190 it.Fields = append(it.Fields, f) 191 } 192 193 b.Models = append(b.Models, it) 194 case ast.Enum: 195 it := &Enum{ 196 Name: schemaType.Name, 197 Description: schemaType.Description, 198 } 199 200 for _, v := range schemaType.EnumValues { 201 it.Values = append(it.Values, &EnumValue{ 202 Name: v.Name, 203 Description: v.Description, 204 }) 205 } 206 207 b.Enums = append(b.Enums, it) 208 case ast.Scalar: 209 b.Scalars = append(b.Scalars, schemaType.Name) 210 } 211 } 212 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 213 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 214 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 215 216 for _, it := range b.Enums { 217 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 218 } 219 for _, it := range b.Models { 220 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 221 } 222 for _, it := range b.Interfaces { 223 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 224 } 225 for _, it := range b.Scalars { 226 cfg.Models.Add(it, "github.com/operandinc/gqlgen/graphql.String") 227 } 228 229 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 230 return nil 231 } 232 233 if m.MutateHook != nil { 234 b = m.MutateHook(b) 235 } 236 237 err := templates.Render(templates.Options{ 238 PackageName: cfg.Model.Package, 239 Filename: cfg.Model.Filename, 240 Data: b, 241 GeneratedHeader: true, 242 Packages: cfg.Packages, 243 }) 244 if err != nil { 245 return err 246 } 247 248 // We may have generated code in a package we already loaded, so we reload all packages 249 // to allow packages to be compared correctly 250 cfg.ReloadAllPackages() 251 252 return nil 253 } 254 255 // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field 256 // name is used when no value argument is present. 257 func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { 258 args := make([]string, 0) 259 for _, goTag := range fd.Directives.ForNames("goTag") { 260 key := "" 261 value := fd.Name 262 263 if arg := goTag.Arguments.ForName("key"); arg != nil { 264 if k, err := arg.Value.Value(nil); err == nil { 265 key = k.(string) 266 } 267 } 268 269 if arg := goTag.Arguments.ForName("value"); arg != nil { 270 if v, err := arg.Value.Value(nil); err == nil { 271 value = v.(string) 272 } 273 } 274 275 args = append(args, key+":\""+value+"\"") 276 } 277 278 if len(args) > 0 { 279 f.Tag = f.Tag + " " + strings.Join(args, " ") 280 } 281 282 return f, nil 283 } 284 285 func isStruct(t types.Type) bool { 286 _, is := t.Underlying().(*types.Struct) 287 return is 288 }