github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/plugin/schemagen/schema.go (about) 1 package schemagen 2 3 import ( 4 "bytes" 5 // Embedded file 6 _ "embed" 7 "fmt" 8 "go/constant" 9 "go/types" 10 "html/template" 11 "log" 12 "os" 13 "reflect" 14 "regexp" 15 "sort" 16 "strings" 17 18 "github.com/mstephano/gqlgen-schemagen/codegen" 19 "github.com/mstephano/gqlgen-schemagen/codegen/config" 20 "github.com/mstephano/gqlgen-schemagen/internal/code" 21 "github.com/mstephano/gqlgen-schemagen/plugin" 22 "github.com/vektah/gqlparser/v2/ast" 23 "github.com/vektah/gqlparser/v2/parser" 24 "golang.org/x/tools/go/packages" 25 ) 26 27 //go:embed schema.gotpl 28 var fileTemplate string 29 30 const basePath = "graph/schema/" 31 32 type addingModelType int 33 34 const ( 35 composition addingModelType = iota 36 reference 37 ) 38 39 type generator struct { 40 config *config.Config 41 filePath string 42 mapDataTypes map[string]string 43 mapDataTypeExclusions map[string]string 44 modelObjects map[string]*modelObject 45 modelEnums map[string]*enumObject 46 excludedModelObjects map[string]*types.Object 47 packages []*packages.Package 48 } 49 50 type modelObject struct { 51 Model types.Object 52 Fields map[string]*fieldObject 53 Compositions map[string]struct{} 54 References map[string]struct{} 55 } 56 57 type enumObject struct { 58 Model types.Object 59 Values []string 60 } 61 62 type fieldObject struct { 63 Field *types.Var 64 TagValue string 65 TypeName string 66 Required bool 67 } 68 69 type templateData struct { 70 SortedEnumsSlice []*enumObject 71 SortedObjectsSlice []*modelObject 72 } 73 74 // New creates a new schemagen plugin 75 func New(cfg *config.Config, fileName string, mapDataTypes map[string]string, mapDataTypeExclusions map[string]string) plugin.Plugin { 76 return &generator{ 77 config: cfg, 78 filePath: getFilePath(fileName), 79 mapDataTypes: mapDataTypes, 80 mapDataTypeExclusions: mapDataTypeExclusions, 81 modelObjects: make(map[string]*modelObject, 0), 82 modelEnums: make(map[string]*enumObject, 0), 83 excludedModelObjects: make(map[string]*types.Object, 0), 84 packages: make([]*packages.Package, 0), 85 } 86 } 87 88 func (g *generator) Name() string { 89 return "schemagen" 90 } 91 92 func (g *generator) InjectSourceEarly() *ast.Source { 93 g.deleteSource() 94 95 pkgs := &code.Packages{} 96 g.packages = pkgs.LoadAll(g.config.AutoBind...) 97 98 // First level models from schemaTypes 99 astSchemaDoc, err := parser.ParseSchemas(g.config.Sources...) 100 if err != nil { 101 panic(err) 102 } 103 104 schemaTypeNames := make([]string, 0) 105 for _, def := range astSchemaDoc.Definitions { 106 for _, field := range def.Fields { 107 if field.Type.Elem != nil { 108 schemaTypeNames = append(schemaTypeNames, field.Type.Elem.NamedType) 109 } else if field.Type.NamedType != "" { 110 schemaTypeNames = append(schemaTypeNames, field.Type.NamedType) 111 } 112 } 113 } 114 115 for _, p := range g.packages { 116 for _, schemaTypeName := range schemaTypeNames { 117 if schemaType := p.Types.Scope().Lookup(schemaTypeName); schemaType != nil { 118 if validateType(p, schemaType) { 119 if g.addModel(schemaType, schemaType, reference) { 120 fmt.Printf("adding schemaType: %+v\n", schemaTypeName) 121 } 122 } else { 123 fmt.Printf("ignoring schemaType: %+v, %+v\n", p.PkgPath, schemaTypeName) 124 } 125 } 126 } 127 } 128 129 // Add all models from all packages 130 loop := 0 131 for { 132 loop++ 133 modelAdded := 0 134 135 // Add references & composition 136 for _, mo := range g.modelObjects { 137 switch typ := mo.Model.Type().Underlying().(type) { 138 case *types.Struct: 139 for i := 0; i < typ.NumFields(); i++ { 140 field := typ.Field(i) 141 tagValue, _, hidden := getJSONTagValue(typ.Tag(i)) 142 if hidden { 143 continue 144 } 145 146 if g.preAddModel(mo, field, tagValue) { 147 modelAdded++ 148 } 149 } 150 case *types.Basic: // enum 151 if g.addEnum(mo) { 152 modelAdded++ 153 } 154 default: 155 fmt.Printf("Not supported type: %+v, %+v, %+v\n", mo.Model.Name(), typ.String(), reflect.TypeOf(mo.Model.Type().Underlying())) 156 } 157 } 158 fmt.Printf("Loop # %+v - models added: %+v\n", loop, modelAdded) 159 160 if modelAdded == 0 { 161 break 162 } 163 } 164 165 // Add all fields for every modelObject 166 modelObjects := make(map[string]*modelObject, 0) // keep single model by name, models with same name will have their fields consolidated 167 for _, mo := range g.modelObjects { 168 moToAddFields := mo 169 if obj, ok := modelObjects[mo.Model.Name()]; ok { 170 moToAddFields = obj // when same name exist, consolidate all the fields to the same object 171 172 // keep the one with references 173 if len(obj.References) == 0 && len(mo.References) > 0 { 174 mo.Fields = obj.Fields 175 moToAddFields = mo 176 modelObjects[mo.Model.Name()] = mo 177 } 178 } else { 179 modelObjects[mo.Model.Name()] = mo 180 } 181 182 g.addFieldsToModelObject(mo.Model.Type().Underlying(), *moToAddFields, false) 183 } 184 185 // Complete composition with missing fields 186 loop = 0 187 for { 188 loop++ 189 fieldAdded := 0 190 for _, v := range modelObjects { 191 for c := range v.Compositions { 192 if obj, exist := modelObjects[c]; exist { 193 for _, f := range v.Fields { 194 if _, fieldExist := obj.Fields[f.TagValue]; !fieldExist { 195 obj.Fields[f.TagValue] = &fieldObject{ 196 Field: f.Field, 197 TagValue: f.TagValue, 198 TypeName: f.TypeName, 199 Required: f.Required, 200 } 201 202 fieldAdded++ 203 } 204 } 205 } 206 } 207 } 208 fmt.Printf("Loop # %+v - fields added: %+v\n", loop, fieldAdded) 209 210 if fieldAdded == 0 { 211 break 212 } 213 } 214 215 // Sort modelObjects 216 modelObjectsSlice := make([]*modelObject, 0, len(modelObjects)) 217 for _, v := range modelObjects { 218 modelObjectsSlice = append(modelObjectsSlice, v) 219 } 220 sort.Slice(modelObjectsSlice, func(i, j int) bool { 221 return modelObjectsSlice[i].Model.Name() < modelObjectsSlice[j].Model.Name() 222 }) 223 // Sort modelObjects fields 224 for _, mo := range modelObjectsSlice { 225 fields := make(map[string]*fieldObject, len(mo.Fields)) 226 fieldKeys := make([]string, 0, len(mo.Fields)) 227 for k := range mo.Fields { 228 fieldKeys = append(fieldKeys, k) 229 } 230 sort.Strings((fieldKeys)) 231 232 for _, fieldKey := range fieldKeys { 233 field := mo.Fields[fieldKey] 234 fields[fieldKey] = field 235 } 236 mo.Fields = fields 237 } 238 239 // Sort modelEnums 240 modelEnumsSlice := make([]*enumObject, 0, len(g.modelEnums)) 241 for _, v := range g.modelEnums { 242 modelEnumsSlice = append(modelEnumsSlice, v) 243 } 244 sort.Slice(modelEnumsSlice, func(i, j int) bool { 245 return modelEnumsSlice[i].Model.Name() < modelEnumsSlice[j].Model.Name() 246 }) 247 // Sort modelEnums values 248 for _, v := range modelEnumsSlice { 249 sort.Strings(v.Values) 250 } 251 252 // Create input 253 input := g.renderTemplate(&templateData{ 254 SortedEnumsSlice: modelEnumsSlice, 255 SortedObjectsSlice: modelObjectsSlice, 256 }) 257 258 // Create source 259 fmt.Printf("Generating schema file ...\n") 260 source := &ast.Source{ 261 Name: g.filePath, 262 BuiltIn: false, 263 Input: input, 264 } 265 // fmt.Printf("%s\n", source.Input) 266 267 return source 268 } 269 270 func (g *generator) GenerateCode(data *codegen.Data) error { 271 for _, s := range data.Config.Sources { 272 if s.Name == g.filePath { 273 f, err := os.OpenFile(g.filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755) 274 if err != nil { 275 log.Fatal(err) 276 } 277 defer f.Close() 278 f.WriteString(s.Input) 279 fmt.Printf("Generated schema file location: %s", g.filePath) 280 break 281 } 282 } 283 284 return nil 285 } 286 287 func (g *generator) preAddModel(mo *modelObject, field *types.Var, tagValue string) bool { 288 addModelType := reference 289 addModelTypeName := "reference" 290 if tagValue == "" { 291 addModelType = composition 292 addModelTypeName = "composition" 293 } 294 295 pkgPath, fullTypeName, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), false) 296 if addModelType == reference && isBasic { 297 return false 298 } 299 300 added := false 301 exist := false 302 for _, p := range g.packages { 303 if p.PkgPath != pkgPath { 304 continue 305 } 306 307 if modelType := p.Types.Scope().Lookup(baseTypeName); modelType != nil { 308 if !exist { 309 if g.addModel(modelType, mo.Model, addModelType) { 310 fmt.Printf("adding %s: %+v\n", addModelTypeName, modelType.Name()) 311 added = true 312 } 313 exist = true 314 } else { 315 fmt.Printf("%s found but already exist: %+v, %+v\n", addModelTypeName, fullTypeName, modelType.Name()) 316 } 317 } 318 } 319 if !added && !exist && !strings.HasPrefix(typeName, "Map") { 320 fmt.Printf("missing %s - model: %+v - field: %+v - pkgPath: %+v\n", addModelTypeName, mo.Model.Name(), field.Name(), pkgPath) 321 } 322 323 return added 324 } 325 326 func (g *generator) addEnum(mo *modelObject) bool { 327 addModelTypeName := "enum" 328 329 pkgPath, fullTypeName, baseTypeName, typeName, isBasic := g.getEnumName(mo.Model.Type(), false) 330 if isBasic { 331 return false 332 } 333 334 added := false 335 exist := false 336 for _, p := range g.packages { 337 if p.PkgPath != pkgPath { 338 continue 339 } 340 341 if modelType := p.Types.Scope().Lookup(baseTypeName); modelType != nil { 342 if !exist { 343 if _, ok := g.modelEnums[typeName]; !ok { 344 g.modelEnums[typeName] = &enumObject{ 345 Model: mo.Model, 346 Values: make([]string, 0), 347 } 348 fmt.Printf("adding %s: %+v\n", addModelTypeName, modelType.Name()) 349 added = true 350 351 for _, n := range p.Types.Scope().Names() { 352 if o := p.Types.Scope().Lookup(n); modelType != nil { 353 if typ, ok := o.(*types.Const); ok { 354 if typ.Type().String() == fullTypeName && typ.Type().Underlying().String() == "string" { 355 g.modelEnums[typeName].Values = append(g.modelEnums[typeName].Values, constant.StringVal(typ.Val())) 356 } 357 } 358 } 359 } 360 } 361 exist = true 362 } else { 363 fmt.Printf("%s found but already exist: %+v, %+v\n", addModelTypeName, fullTypeName, modelType.Name()) 364 } 365 } 366 } 367 if !added && !exist { 368 for k := range mo.References { 369 fmt.Printf("missing %s - model: %+v - field: %+v - pkgPath: %+v\n", addModelTypeName, mo.Model.Name(), k, pkgPath) 370 break 371 } 372 } 373 374 return added 375 } 376 377 func (g *generator) addModel(model types.Object, fromModel types.Object, addModelType addingModelType) bool { 378 if _, exist := g.mapDataTypeExclusions[model.Pkg().Name()+"."+model.Name()]; exist { 379 if _, exist := g.excludedModelObjects[model.Name()]; !exist { 380 g.excludedModelObjects[model.Name()] = &model 381 fmt.Printf("excluded model: %+v\n", model.Name()) 382 } 383 return false 384 } 385 386 added := false 387 typeName := fmt.Sprintf("%s.%s", model.Pkg().Path(), model.Name()) 388 if _, exist := g.modelObjects[typeName]; !exist { 389 g.modelObjects[typeName] = &modelObject{ 390 Model: model, 391 Fields: make(map[string]*fieldObject, 0), 392 Compositions: make(map[string]struct{}, 0), 393 References: make(map[string]struct{}, 0), 394 } 395 added = true 396 } 397 398 mo := g.modelObjects[typeName] 399 switch addModelType { 400 case composition: 401 if _, exist := mo.Compositions[fromModel.Name()]; !exist { 402 mo.Compositions[fromModel.Name()] = struct{}{} 403 } 404 case reference: 405 if _, exist := mo.References[fromModel.Name()]; !exist { 406 mo.References[fromModel.Name()] = struct{}{} 407 } 408 } 409 410 return added 411 } 412 413 func (g *generator) getTypeName(fieldName string, obj types.Type, isSlice bool) (string, string, string, string, bool) { 414 pkgPath := "" 415 fullTypeName := "" 416 baseTypeName := "" 417 typeName := "" 418 isBasic := false 419 420 switch typ := obj.(type) { 421 case *types.Basic: 422 pkgPath = typ.Underlying().String() 423 fullTypeName = typ.Underlying().String() + "." + typ.Name() 424 baseTypeName = typ.Name() 425 isBasic = true 426 case *types.Named: 427 if _, ok := typ.Obj().Type().Underlying().(*types.Map); ok { 428 pkgPath = fmt.Sprintf("Map%s", typ.Obj().Name()) 429 fullTypeName = pkgPath 430 baseTypeName = pkgPath 431 } else { 432 pkgPath = typ.Obj().Pkg().Path() 433 fullTypeName = typ.Obj().Pkg().Path() + "." + typ.Obj().Name() 434 baseTypeName = typ.Obj().Pkg().Name() + "." + typ.Obj().Name() 435 if _, ok := g.mapDataTypes[baseTypeName]; !ok { 436 baseTypeName = typ.Obj().Name() 437 if _, ok := typ.Obj().Type().Underlying().(*types.Slice); ok { 438 isSlice = true 439 } 440 } 441 } 442 case *types.Pointer: 443 pkgPath, fullTypeName, baseTypeName, typeName, isBasic = g.getTypeName(fieldName, typ.Elem(), isSlice) 444 return pkgPath, fullTypeName, baseTypeName, typeName, isBasic 445 case *types.Slice: 446 pkgPath, fullTypeName, baseTypeName, typeName, isBasic = g.getTypeName(fieldName, typ.Elem(), true) 447 return pkgPath, fullTypeName, baseTypeName, typeName, isBasic 448 case *types.Map: 449 pkgPath = fmt.Sprintf("Map%s", typ.Elem().String()) 450 fullTypeName = pkgPath 451 baseTypeName = pkgPath 452 default: 453 fmt.Printf("type not supported - field: %+v - type: %+v\n", fieldName, reflect.TypeOf(obj)) 454 } 455 456 if mapName, ok := g.mapDataTypes[baseTypeName]; ok { 457 baseTypeName = mapName 458 isBasic = true 459 } 460 if isSlice { 461 typeName = "[" + baseTypeName + "!]" 462 } else { 463 typeName = baseTypeName 464 } 465 return pkgPath, fullTypeName, baseTypeName, typeName, isBasic 466 } 467 468 func (g *generator) getEnumName(obj types.Type, isSlice bool) (string, string, string, string, bool) { 469 pkgPath := "" 470 fullTypeName := "" 471 baseTypeName := "" 472 typeName := "" 473 isBasic := false 474 475 switch typ := obj.(type) { 476 case *types.Named: 477 pkgPath = typ.Obj().Pkg().Path() 478 fullTypeName = typ.Obj().Pkg().Path() + "." + typ.Obj().Name() 479 baseTypeName = typ.Obj().Pkg().Name() + "." + typ.Obj().Name() 480 if _, ok := g.mapDataTypes[baseTypeName]; !ok { 481 baseTypeName = typ.Obj().Name() 482 if _, ok := typ.Obj().Type().Underlying().(*types.Slice); ok { 483 isSlice = true 484 } 485 } 486 default: 487 fmt.Printf("type not supported - field: %+v - type: %+v\n", obj.String(), reflect.TypeOf(obj)) 488 } 489 490 if mapName, ok := g.mapDataTypes[baseTypeName]; ok { 491 baseTypeName = mapName 492 isBasic = true 493 } 494 if isSlice { 495 typeName = "[" + baseTypeName + "!]" 496 } else { 497 typeName = baseTypeName 498 } 499 return pkgPath, fullTypeName, baseTypeName, typeName, isBasic 500 } 501 502 func (g *generator) addFieldsToModelObject(obj types.Type, modelObject modelObject, isSlice bool) { 503 switch typ := obj.(type) { 504 case *types.Struct: 505 for i := 0; i < typ.NumFields(); i++ { 506 field := typ.Field(i) 507 tagValue, required, hidden := getJSONTagValue(typ.Tag(i)) 508 if hidden { 509 continue 510 } 511 512 _, fullTypeName, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), isSlice) 513 if tagValue != "" { 514 if isBasic { 515 if baseTypeName != "" { 516 modelObject.addField(field, tagValue, typeName, required) 517 } 518 } else { 519 _, ok := g.modelObjects[fullTypeName] 520 if ok || strings.HasPrefix(typeName, "Map") { 521 modelObject.addField(field, tagValue, typeName, required) 522 } 523 } 524 } else { 525 // first level composition 526 if o, ok := g.modelObjects[fullTypeName]; ok { 527 if u, ok := o.Model.Type().Underlying().(*types.Struct); ok { 528 for i := 0; i < u.NumFields(); i++ { 529 field := u.Field(i) 530 tagValue, required, hidden := getJSONTagValue(u.Tag(i)) 531 _, _, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), isSlice) 532 if tagValue != "" && !hidden && isBasic && baseTypeName != "" { 533 modelObject.addField(field, tagValue, typeName, required) 534 } 535 } 536 } 537 } 538 } 539 } 540 case *types.Slice: 541 g.addFieldsToModelObject(typ.Elem().Underlying(), modelObject, true) 542 return 543 case *types.Pointer: 544 g.addFieldsToModelObject(typ.Elem().Underlying(), modelObject, false) 545 return 546 case *types.Basic: // enum 547 return // managed in addEnum() 548 default: 549 fmt.Printf("field not supported - model: %s - field: %s - type: %+v\n", modelObject.Model.Name(), obj.String(), reflect.TypeOf(obj)) 550 } 551 } 552 553 func (g *generator) deleteSource() { 554 // Delete existing source to avoid type collisions 555 for i, s := range g.config.Sources { 556 if s.Name == g.filePath { 557 g.config.Sources = append(g.config.Sources[:i], g.config.Sources[i+1:]...) 558 break 559 } 560 } 561 } 562 563 func (g *generator) renderTemplate(templateData *templateData) string { 564 var buf bytes.Buffer 565 if err := template.Must(template.New("test.graphqls").Funcs(template.FuncMap{ 566 "GetSchemaField": GetSchemaFieldGotpl, 567 }).Parse(fileTemplate)).Execute(&buf, templateData); err != nil { 568 panic(err) 569 } 570 return buf.String() 571 } 572 573 func (o *modelObject) addField(field *types.Var, tagValue, typeName string, required bool) bool { 574 if _, exist := o.Fields[tagValue]; !exist { 575 o.Fields[tagValue] = &fieldObject{ 576 Field: field, 577 TagValue: tagValue, 578 TypeName: typeName, 579 Required: required, 580 } 581 return true 582 } 583 return false 584 } 585 586 func validateType(pkg *packages.Package, obj types.Object) bool { 587 if named, ok := obj.Type().(*types.Named); ok { 588 if s, ok := named.Underlying().(*types.Struct); ok { 589 for i := 0; i < s.NumFields(); i++ { 590 if s.Field(i).Pkg().Path() == pkg.PkgPath { 591 return true 592 } 593 } 594 } 595 } 596 return false 597 } 598 599 func getJSONTagValue(tag string) (string, bool, bool) { 600 required := false 601 hidden := false 602 if tag != "" && strings.Contains(tag, "json:") { 603 required := !strings.Contains(tag, "omitempty") 604 re := regexp.MustCompile(`json:\"(.*?)\"`) 605 res := re.FindAllStringSubmatch(tag, -1) 606 propertyName := strings.ReplaceAll(res[0][1], ",omitempty", "") 607 propertyName = strings.ReplaceAll(propertyName, ",", "") 608 609 if propertyName == "-" { 610 hidden = true 611 } 612 return propertyName, required, hidden 613 } 614 return "", required, hidden 615 } 616 617 // GetSchemaFieldGotpl is used in the go template 618 func GetSchemaFieldGotpl(tagValue, typeName string, required bool) string { 619 field := fmt.Sprintf("%s: %s", tagValue, typeName) 620 if required { 621 field = fmt.Sprintf("%s: %s!", tagValue, typeName) 622 } 623 return field 624 } 625 626 func getFilePath(filename string) string { 627 return fmt.Sprintf("%s%s", basePath, filename) 628 }