github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/plugin/modelgen/models.go (about) 1 package modelgen 2 3 import ( 4 "fmt" 5 "go/types" 6 "sort" 7 8 "github.com/99designs/gqlgen/codegen/config" 9 "github.com/99designs/gqlgen/codegen/templates" 10 "github.com/99designs/gqlgen/plugin" 11 "github.com/vektah/gqlparser/v2/ast" 12 ) 13 14 type BuildMutateHook = func(b *ModelBuild) *ModelBuild 15 16 func defaultBuildMutateHook(b *ModelBuild) *ModelBuild { 17 return b 18 } 19 20 type ModelBuild struct { 21 PackageName string 22 Interfaces []*Interface 23 Models []*Object 24 Enums []*Enum 25 Scalars []string 26 } 27 28 type Interface struct { 29 Description string 30 Name string 31 } 32 33 type Object struct { 34 Description string 35 Name string 36 Fields []*Field 37 Implements []string 38 } 39 40 type Field struct { 41 Description string 42 Name string 43 Type types.Type 44 Tag string 45 } 46 47 type Enum struct { 48 Description string 49 Name string 50 Values []*EnumValue 51 } 52 53 type EnumValue struct { 54 Description string 55 Name string 56 } 57 58 func New() plugin.Plugin { 59 return &Plugin{ 60 MutateHook: defaultBuildMutateHook, 61 } 62 } 63 64 type Plugin struct { 65 MutateHook BuildMutateHook 66 } 67 68 var _ plugin.ConfigMutator = &Plugin{} 69 70 func (m *Plugin) Name() string { 71 return "modelgen" 72 } 73 74 func (m *Plugin) MutateConfig(cfg *config.Config) error { 75 binder := cfg.NewBinder() 76 77 b := &ModelBuild{ 78 PackageName: cfg.Model.Package, 79 } 80 81 for _, schemaType := range cfg.Schema.Types { 82 if cfg.Models.UserDefined(schemaType.Name) { 83 continue 84 } 85 switch schemaType.Kind { 86 case ast.Interface, ast.Union: 87 it := &Interface{ 88 Description: schemaType.Description, 89 Name: schemaType.Name, 90 } 91 92 b.Interfaces = append(b.Interfaces, it) 93 case ast.Object, ast.InputObject: 94 if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { 95 continue 96 } 97 it := &Object{ 98 Description: schemaType.Description, 99 Name: schemaType.Name, 100 } 101 for _, implementor := range cfg.Schema.GetImplements(schemaType) { 102 it.Implements = append(it.Implements, implementor.Name) 103 } 104 105 for _, field := range schemaType.Fields { 106 var typ types.Type 107 fieldDef := cfg.Schema.Types[field.Type.Name()] 108 109 if cfg.Models.UserDefined(field.Type.Name()) { 110 var err error 111 typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) 112 if err != nil { 113 return err 114 } 115 } else { 116 switch fieldDef.Kind { 117 case ast.Scalar: 118 // no user defined model, referencing a default scalar 119 typ = types.NewNamed( 120 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), 121 nil, 122 nil, 123 ) 124 125 case ast.Interface, ast.Union: 126 // no user defined model, referencing a generated interface type 127 typ = types.NewNamed( 128 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 129 types.NewInterfaceType([]*types.Func{}, []types.Type{}), 130 nil, 131 ) 132 133 case ast.Enum: 134 // no user defined model, must reference a generated enum 135 typ = types.NewNamed( 136 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 137 nil, 138 nil, 139 ) 140 141 case ast.Object, ast.InputObject: 142 // no user defined model, must reference a generated struct 143 typ = types.NewNamed( 144 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), 145 types.NewStruct(nil, nil), 146 nil, 147 ) 148 149 default: 150 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) 151 } 152 } 153 154 name := field.Name 155 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" { 156 name = nameOveride 157 } 158 159 typ = binder.CopyModifiersFromAst(field.Type, typ) 160 161 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { 162 typ = types.NewPointer(typ) 163 } 164 165 it.Fields = append(it.Fields, &Field{ 166 Name: name, 167 Type: typ, 168 Description: field.Description, 169 Tag: `json:"` + field.Name + `"`, 170 }) 171 } 172 173 b.Models = append(b.Models, it) 174 case ast.Enum: 175 it := &Enum{ 176 Name: schemaType.Name, 177 Description: schemaType.Description, 178 } 179 180 for _, v := range schemaType.EnumValues { 181 it.Values = append(it.Values, &EnumValue{ 182 Name: v.Name, 183 Description: v.Description, 184 }) 185 } 186 187 b.Enums = append(b.Enums, it) 188 case ast.Scalar: 189 b.Scalars = append(b.Scalars, schemaType.Name) 190 } 191 } 192 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name }) 193 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name }) 194 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name }) 195 196 for _, it := range b.Enums { 197 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 198 } 199 for _, it := range b.Models { 200 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 201 } 202 for _, it := range b.Interfaces { 203 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name)) 204 } 205 for _, it := range b.Scalars { 206 cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String") 207 } 208 209 if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 { 210 return nil 211 } 212 213 if m.MutateHook != nil { 214 b = m.MutateHook(b) 215 } 216 217 return templates.Render(templates.Options{ 218 PackageName: cfg.Model.Package, 219 Filename: cfg.Model.Filename, 220 Data: b, 221 GeneratedHeader: true, 222 Packages: cfg.Packages, 223 }) 224 } 225 226 func isStruct(t types.Type) bool { 227 _, is := t.Underlying().(*types.Struct) 228 return is 229 }