github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/plugin/schemagen/schema.go (about)

     1  package schemagen
     2  
     3  import (
     4  	"bytes"
     5  	// Embedded file
     6  	_ "embed"
     7  	"fmt"
     8  	"go/constant"
     9  	"go/types"
    10  	"html/template"
    11  	"log"
    12  	"os"
    13  	"reflect"
    14  	"regexp"
    15  	"sort"
    16  	"strings"
    17  
    18  	"github.com/mstephano/gqlgen-schemagen/codegen"
    19  	"github.com/mstephano/gqlgen-schemagen/codegen/config"
    20  	"github.com/mstephano/gqlgen-schemagen/internal/code"
    21  	"github.com/mstephano/gqlgen-schemagen/plugin"
    22  	"github.com/vektah/gqlparser/v2/ast"
    23  	"github.com/vektah/gqlparser/v2/parser"
    24  	"golang.org/x/tools/go/packages"
    25  )
    26  
    27  //go:embed schema.gotpl
    28  var fileTemplate string
    29  
    30  const basePath = "graph/schema/"
    31  
    32  type addingModelType int
    33  
    34  const (
    35  	composition addingModelType = iota
    36  	reference
    37  )
    38  
    39  type generator struct {
    40  	config                *config.Config
    41  	filePath              string
    42  	mapDataTypes          map[string]string
    43  	mapDataTypeExclusions map[string]string
    44  	modelObjects          map[string]*modelObject
    45  	modelEnums            map[string]*enumObject
    46  	excludedModelObjects  map[string]*types.Object
    47  	packages              []*packages.Package
    48  }
    49  
    50  type modelObject struct {
    51  	Model        types.Object
    52  	Fields       map[string]*fieldObject
    53  	Compositions map[string]struct{}
    54  	References   map[string]struct{}
    55  }
    56  
    57  type enumObject struct {
    58  	Model  types.Object
    59  	Values []string
    60  }
    61  
    62  type fieldObject struct {
    63  	Field    *types.Var
    64  	TagValue string
    65  	TypeName string
    66  	Required bool
    67  }
    68  
    69  type templateData struct {
    70  	SortedEnumsSlice   []*enumObject
    71  	SortedObjectsSlice []*modelObject
    72  }
    73  
    74  // New creates a new schemagen plugin
    75  func New(cfg *config.Config, fileName string, mapDataTypes map[string]string, mapDataTypeExclusions map[string]string) plugin.Plugin {
    76  	return &generator{
    77  		config:                cfg,
    78  		filePath:              getFilePath(fileName),
    79  		mapDataTypes:          mapDataTypes,
    80  		mapDataTypeExclusions: mapDataTypeExclusions,
    81  		modelObjects:          make(map[string]*modelObject, 0),
    82  		modelEnums:            make(map[string]*enumObject, 0),
    83  		excludedModelObjects:  make(map[string]*types.Object, 0),
    84  		packages:              make([]*packages.Package, 0),
    85  	}
    86  }
    87  
    88  func (g *generator) Name() string {
    89  	return "schemagen"
    90  }
    91  
    92  func (g *generator) InjectSourceEarly() *ast.Source {
    93  	g.deleteSource()
    94  
    95  	pkgs := &code.Packages{}
    96  	g.packages = pkgs.LoadAll(g.config.AutoBind...)
    97  
    98  	// First level models from schemaTypes
    99  	astSchemaDoc, err := parser.ParseSchemas(g.config.Sources...)
   100  	if err != nil {
   101  		panic(err)
   102  	}
   103  
   104  	schemaTypeNames := make([]string, 0)
   105  	for _, def := range astSchemaDoc.Definitions {
   106  		for _, field := range def.Fields {
   107  			if field.Type.Elem != nil {
   108  				schemaTypeNames = append(schemaTypeNames, field.Type.Elem.NamedType)
   109  			} else if field.Type.NamedType != "" {
   110  				schemaTypeNames = append(schemaTypeNames, field.Type.NamedType)
   111  			}
   112  		}
   113  	}
   114  
   115  	for _, p := range g.packages {
   116  		for _, schemaTypeName := range schemaTypeNames {
   117  			if schemaType := p.Types.Scope().Lookup(schemaTypeName); schemaType != nil {
   118  				if validateType(p, schemaType) {
   119  					if g.addModel(schemaType, schemaType, reference) {
   120  						fmt.Printf("adding schemaType: %+v\n", schemaTypeName)
   121  					}
   122  				} else {
   123  					fmt.Printf("ignoring schemaType: %+v, %+v\n", p.PkgPath, schemaTypeName)
   124  				}
   125  			}
   126  		}
   127  	}
   128  
   129  	// Add all models from all packages
   130  	loop := 0
   131  	for {
   132  		loop++
   133  		modelAdded := 0
   134  
   135  		// Add references & composition
   136  		for _, mo := range g.modelObjects {
   137  			switch typ := mo.Model.Type().Underlying().(type) {
   138  			case *types.Struct:
   139  				for i := 0; i < typ.NumFields(); i++ {
   140  					field := typ.Field(i)
   141  					tagValue, _, hidden := getJSONTagValue(typ.Tag(i))
   142  					if hidden {
   143  						continue
   144  					}
   145  
   146  					if g.preAddModel(mo, field, tagValue) {
   147  						modelAdded++
   148  					}
   149  				}
   150  			case *types.Basic: // enum
   151  				if g.addEnum(mo) {
   152  					modelAdded++
   153  				}
   154  			default:
   155  				fmt.Printf("Not supported type: %+v, %+v, %+v\n", mo.Model.Name(), typ.String(), reflect.TypeOf(mo.Model.Type().Underlying()))
   156  			}
   157  		}
   158  		fmt.Printf("Loop # %+v - models added: %+v\n", loop, modelAdded)
   159  
   160  		if modelAdded == 0 {
   161  			break
   162  		}
   163  	}
   164  
   165  	// Add all fields for every modelObject
   166  	modelObjects := make(map[string]*modelObject, 0) // keep single model by name, models with same name will have their fields consolidated
   167  	for _, mo := range g.modelObjects {
   168  		moToAddFields := mo
   169  		if obj, ok := modelObjects[mo.Model.Name()]; ok {
   170  			moToAddFields = obj // when same name exist, consolidate all the fields to the same object
   171  
   172  			// keep the one with references
   173  			if len(obj.References) == 0 && len(mo.References) > 0 {
   174  				mo.Fields = obj.Fields
   175  				moToAddFields = mo
   176  				modelObjects[mo.Model.Name()] = mo
   177  			}
   178  		} else {
   179  			modelObjects[mo.Model.Name()] = mo
   180  		}
   181  
   182  		g.addFieldsToModelObject(mo.Model.Type().Underlying(), *moToAddFields, false)
   183  	}
   184  
   185  	// Complete composition with missing fields
   186  	loop = 0
   187  	for {
   188  		loop++
   189  		fieldAdded := 0
   190  		for _, v := range modelObjects {
   191  			for c := range v.Compositions {
   192  				if obj, exist := modelObjects[c]; exist {
   193  					for _, f := range v.Fields {
   194  						if _, fieldExist := obj.Fields[f.TagValue]; !fieldExist {
   195  							obj.Fields[f.TagValue] = &fieldObject{
   196  								Field:    f.Field,
   197  								TagValue: f.TagValue,
   198  								TypeName: f.TypeName,
   199  								Required: f.Required,
   200  							}
   201  
   202  							fieldAdded++
   203  						}
   204  					}
   205  				}
   206  			}
   207  		}
   208  		fmt.Printf("Loop # %+v - fields added: %+v\n", loop, fieldAdded)
   209  
   210  		if fieldAdded == 0 {
   211  			break
   212  		}
   213  	}
   214  
   215  	// Sort modelObjects
   216  	modelObjectsSlice := make([]*modelObject, 0, len(modelObjects))
   217  	for _, v := range modelObjects {
   218  		modelObjectsSlice = append(modelObjectsSlice, v)
   219  	}
   220  	sort.Slice(modelObjectsSlice, func(i, j int) bool {
   221  		return modelObjectsSlice[i].Model.Name() < modelObjectsSlice[j].Model.Name()
   222  	})
   223  	// Sort modelObjects fields
   224  	for _, mo := range modelObjectsSlice {
   225  		fields := make(map[string]*fieldObject, len(mo.Fields))
   226  		fieldKeys := make([]string, 0, len(mo.Fields))
   227  		for k := range mo.Fields {
   228  			fieldKeys = append(fieldKeys, k)
   229  		}
   230  		sort.Strings((fieldKeys))
   231  
   232  		for _, fieldKey := range fieldKeys {
   233  			field := mo.Fields[fieldKey]
   234  			fields[fieldKey] = field
   235  		}
   236  		mo.Fields = fields
   237  	}
   238  
   239  	// Sort modelEnums
   240  	modelEnumsSlice := make([]*enumObject, 0, len(g.modelEnums))
   241  	for _, v := range g.modelEnums {
   242  		modelEnumsSlice = append(modelEnumsSlice, v)
   243  	}
   244  	sort.Slice(modelEnumsSlice, func(i, j int) bool {
   245  		return modelEnumsSlice[i].Model.Name() < modelEnumsSlice[j].Model.Name()
   246  	})
   247  	// Sort modelEnums values
   248  	for _, v := range modelEnumsSlice {
   249  		sort.Strings(v.Values)
   250  	}
   251  
   252  	// Create input
   253  	input := g.renderTemplate(&templateData{
   254  		SortedEnumsSlice:   modelEnumsSlice,
   255  		SortedObjectsSlice: modelObjectsSlice,
   256  	})
   257  
   258  	// Create source
   259  	fmt.Printf("Generating schema file ...\n")
   260  	source := &ast.Source{
   261  		Name:    g.filePath,
   262  		BuiltIn: false,
   263  		Input:   input,
   264  	}
   265  	// fmt.Printf("%s\n", source.Input)
   266  
   267  	return source
   268  }
   269  
   270  func (g *generator) GenerateCode(data *codegen.Data) error {
   271  	for _, s := range data.Config.Sources {
   272  		if s.Name == g.filePath {
   273  			f, err := os.OpenFile(g.filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
   274  			if err != nil {
   275  				log.Fatal(err)
   276  			}
   277  			defer f.Close()
   278  			f.WriteString(s.Input)
   279  			fmt.Printf("Generated schema file location: %s", g.filePath)
   280  			break
   281  		}
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  func (g *generator) preAddModel(mo *modelObject, field *types.Var, tagValue string) bool {
   288  	addModelType := reference
   289  	addModelTypeName := "reference"
   290  	if tagValue == "" {
   291  		addModelType = composition
   292  		addModelTypeName = "composition"
   293  	}
   294  
   295  	pkgPath, fullTypeName, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), false)
   296  	if addModelType == reference && isBasic {
   297  		return false
   298  	}
   299  
   300  	added := false
   301  	exist := false
   302  	for _, p := range g.packages {
   303  		if p.PkgPath != pkgPath {
   304  			continue
   305  		}
   306  
   307  		if modelType := p.Types.Scope().Lookup(baseTypeName); modelType != nil {
   308  			if !exist {
   309  				if g.addModel(modelType, mo.Model, addModelType) {
   310  					fmt.Printf("adding %s: %+v\n", addModelTypeName, modelType.Name())
   311  					added = true
   312  				}
   313  				exist = true
   314  			} else {
   315  				fmt.Printf("%s found but already exist: %+v, %+v\n", addModelTypeName, fullTypeName, modelType.Name())
   316  			}
   317  		}
   318  	}
   319  	if !added && !exist && !strings.HasPrefix(typeName, "Map") {
   320  		fmt.Printf("missing %s - model: %+v - field: %+v - pkgPath: %+v\n", addModelTypeName, mo.Model.Name(), field.Name(), pkgPath)
   321  	}
   322  
   323  	return added
   324  }
   325  
   326  func (g *generator) addEnum(mo *modelObject) bool {
   327  	addModelTypeName := "enum"
   328  
   329  	pkgPath, fullTypeName, baseTypeName, typeName, isBasic := g.getEnumName(mo.Model.Type(), false)
   330  	if isBasic {
   331  		return false
   332  	}
   333  
   334  	added := false
   335  	exist := false
   336  	for _, p := range g.packages {
   337  		if p.PkgPath != pkgPath {
   338  			continue
   339  		}
   340  
   341  		if modelType := p.Types.Scope().Lookup(baseTypeName); modelType != nil {
   342  			if !exist {
   343  				if _, ok := g.modelEnums[typeName]; !ok {
   344  					g.modelEnums[typeName] = &enumObject{
   345  						Model:  mo.Model,
   346  						Values: make([]string, 0),
   347  					}
   348  					fmt.Printf("adding %s: %+v\n", addModelTypeName, modelType.Name())
   349  					added = true
   350  
   351  					for _, n := range p.Types.Scope().Names() {
   352  						if o := p.Types.Scope().Lookup(n); modelType != nil {
   353  							if typ, ok := o.(*types.Const); ok {
   354  								if typ.Type().String() == fullTypeName && typ.Type().Underlying().String() == "string" {
   355  									g.modelEnums[typeName].Values = append(g.modelEnums[typeName].Values, constant.StringVal(typ.Val()))
   356  								}
   357  							}
   358  						}
   359  					}
   360  				}
   361  				exist = true
   362  			} else {
   363  				fmt.Printf("%s found but already exist: %+v, %+v\n", addModelTypeName, fullTypeName, modelType.Name())
   364  			}
   365  		}
   366  	}
   367  	if !added && !exist {
   368  		for k := range mo.References {
   369  			fmt.Printf("missing %s - model: %+v - field: %+v - pkgPath: %+v\n", addModelTypeName, mo.Model.Name(), k, pkgPath)
   370  			break
   371  		}
   372  	}
   373  
   374  	return added
   375  }
   376  
   377  func (g *generator) addModel(model types.Object, fromModel types.Object, addModelType addingModelType) bool {
   378  	if _, exist := g.mapDataTypeExclusions[model.Pkg().Name()+"."+model.Name()]; exist {
   379  		if _, exist := g.excludedModelObjects[model.Name()]; !exist {
   380  			g.excludedModelObjects[model.Name()] = &model
   381  			fmt.Printf("excluded model: %+v\n", model.Name())
   382  		}
   383  		return false
   384  	}
   385  
   386  	added := false
   387  	typeName := fmt.Sprintf("%s.%s", model.Pkg().Path(), model.Name())
   388  	if _, exist := g.modelObjects[typeName]; !exist {
   389  		g.modelObjects[typeName] = &modelObject{
   390  			Model:        model,
   391  			Fields:       make(map[string]*fieldObject, 0),
   392  			Compositions: make(map[string]struct{}, 0),
   393  			References:   make(map[string]struct{}, 0),
   394  		}
   395  		added = true
   396  	}
   397  
   398  	mo := g.modelObjects[typeName]
   399  	switch addModelType {
   400  	case composition:
   401  		if _, exist := mo.Compositions[fromModel.Name()]; !exist {
   402  			mo.Compositions[fromModel.Name()] = struct{}{}
   403  		}
   404  	case reference:
   405  		if _, exist := mo.References[fromModel.Name()]; !exist {
   406  			mo.References[fromModel.Name()] = struct{}{}
   407  		}
   408  	}
   409  
   410  	return added
   411  }
   412  
   413  func (g *generator) getTypeName(fieldName string, obj types.Type, isSlice bool) (string, string, string, string, bool) {
   414  	pkgPath := ""
   415  	fullTypeName := ""
   416  	baseTypeName := ""
   417  	typeName := ""
   418  	isBasic := false
   419  
   420  	switch typ := obj.(type) {
   421  	case *types.Basic:
   422  		pkgPath = typ.Underlying().String()
   423  		fullTypeName = typ.Underlying().String() + "." + typ.Name()
   424  		baseTypeName = typ.Name()
   425  		isBasic = true
   426  	case *types.Named:
   427  		if _, ok := typ.Obj().Type().Underlying().(*types.Map); ok {
   428  			pkgPath = fmt.Sprintf("Map%s", typ.Obj().Name())
   429  			fullTypeName = pkgPath
   430  			baseTypeName = pkgPath
   431  		} else {
   432  			pkgPath = typ.Obj().Pkg().Path()
   433  			fullTypeName = typ.Obj().Pkg().Path() + "." + typ.Obj().Name()
   434  			baseTypeName = typ.Obj().Pkg().Name() + "." + typ.Obj().Name()
   435  			if _, ok := g.mapDataTypes[baseTypeName]; !ok {
   436  				baseTypeName = typ.Obj().Name()
   437  				if _, ok := typ.Obj().Type().Underlying().(*types.Slice); ok {
   438  					isSlice = true
   439  				}
   440  			}
   441  		}
   442  	case *types.Pointer:
   443  		pkgPath, fullTypeName, baseTypeName, typeName, isBasic = g.getTypeName(fieldName, typ.Elem(), isSlice)
   444  		return pkgPath, fullTypeName, baseTypeName, typeName, isBasic
   445  	case *types.Slice:
   446  		pkgPath, fullTypeName, baseTypeName, typeName, isBasic = g.getTypeName(fieldName, typ.Elem(), true)
   447  		return pkgPath, fullTypeName, baseTypeName, typeName, isBasic
   448  	case *types.Map:
   449  		pkgPath = fmt.Sprintf("Map%s", typ.Elem().String())
   450  		fullTypeName = pkgPath
   451  		baseTypeName = pkgPath
   452  	default:
   453  		fmt.Printf("type not supported - field: %+v - type: %+v\n", fieldName, reflect.TypeOf(obj))
   454  	}
   455  
   456  	if mapName, ok := g.mapDataTypes[baseTypeName]; ok {
   457  		baseTypeName = mapName
   458  		isBasic = true
   459  	}
   460  	if isSlice {
   461  		typeName = "[" + baseTypeName + "!]"
   462  	} else {
   463  		typeName = baseTypeName
   464  	}
   465  	return pkgPath, fullTypeName, baseTypeName, typeName, isBasic
   466  }
   467  
   468  func (g *generator) getEnumName(obj types.Type, isSlice bool) (string, string, string, string, bool) {
   469  	pkgPath := ""
   470  	fullTypeName := ""
   471  	baseTypeName := ""
   472  	typeName := ""
   473  	isBasic := false
   474  
   475  	switch typ := obj.(type) {
   476  	case *types.Named:
   477  		pkgPath = typ.Obj().Pkg().Path()
   478  		fullTypeName = typ.Obj().Pkg().Path() + "." + typ.Obj().Name()
   479  		baseTypeName = typ.Obj().Pkg().Name() + "." + typ.Obj().Name()
   480  		if _, ok := g.mapDataTypes[baseTypeName]; !ok {
   481  			baseTypeName = typ.Obj().Name()
   482  			if _, ok := typ.Obj().Type().Underlying().(*types.Slice); ok {
   483  				isSlice = true
   484  			}
   485  		}
   486  	default:
   487  		fmt.Printf("type not supported - field: %+v - type: %+v\n", obj.String(), reflect.TypeOf(obj))
   488  	}
   489  
   490  	if mapName, ok := g.mapDataTypes[baseTypeName]; ok {
   491  		baseTypeName = mapName
   492  		isBasic = true
   493  	}
   494  	if isSlice {
   495  		typeName = "[" + baseTypeName + "!]"
   496  	} else {
   497  		typeName = baseTypeName
   498  	}
   499  	return pkgPath, fullTypeName, baseTypeName, typeName, isBasic
   500  }
   501  
   502  func (g *generator) addFieldsToModelObject(obj types.Type, modelObject modelObject, isSlice bool) {
   503  	switch typ := obj.(type) {
   504  	case *types.Struct:
   505  		for i := 0; i < typ.NumFields(); i++ {
   506  			field := typ.Field(i)
   507  			tagValue, required, hidden := getJSONTagValue(typ.Tag(i))
   508  			if hidden {
   509  				continue
   510  			}
   511  
   512  			_, fullTypeName, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), isSlice)
   513  			if tagValue != "" {
   514  				if isBasic {
   515  					if baseTypeName != "" {
   516  						modelObject.addField(field, tagValue, typeName, required)
   517  					}
   518  				} else {
   519  					_, ok := g.modelObjects[fullTypeName]
   520  					if ok || strings.HasPrefix(typeName, "Map") {
   521  						modelObject.addField(field, tagValue, typeName, required)
   522  					}
   523  				}
   524  			} else {
   525  				// first level composition
   526  				if o, ok := g.modelObjects[fullTypeName]; ok {
   527  					if u, ok := o.Model.Type().Underlying().(*types.Struct); ok {
   528  						for i := 0; i < u.NumFields(); i++ {
   529  							field := u.Field(i)
   530  							tagValue, required, hidden := getJSONTagValue(u.Tag(i))
   531  							_, _, baseTypeName, typeName, isBasic := g.getTypeName(field.Name(), field.Type(), isSlice)
   532  							if tagValue != "" && !hidden && isBasic && baseTypeName != "" {
   533  								modelObject.addField(field, tagValue, typeName, required)
   534  							}
   535  						}
   536  					}
   537  				}
   538  			}
   539  		}
   540  	case *types.Slice:
   541  		g.addFieldsToModelObject(typ.Elem().Underlying(), modelObject, true)
   542  		return
   543  	case *types.Pointer:
   544  		g.addFieldsToModelObject(typ.Elem().Underlying(), modelObject, false)
   545  		return
   546  	case *types.Basic: // enum
   547  		return // managed in addEnum()
   548  	default:
   549  		fmt.Printf("field not supported - model: %s - field: %s - type: %+v\n", modelObject.Model.Name(), obj.String(), reflect.TypeOf(obj))
   550  	}
   551  }
   552  
   553  func (g *generator) deleteSource() {
   554  	// Delete existing source to avoid type collisions
   555  	for i, s := range g.config.Sources {
   556  		if s.Name == g.filePath {
   557  			g.config.Sources = append(g.config.Sources[:i], g.config.Sources[i+1:]...)
   558  			break
   559  		}
   560  	}
   561  }
   562  
   563  func (g *generator) renderTemplate(templateData *templateData) string {
   564  	var buf bytes.Buffer
   565  	if err := template.Must(template.New("test.graphqls").Funcs(template.FuncMap{
   566  		"GetSchemaField": GetSchemaFieldGotpl,
   567  	}).Parse(fileTemplate)).Execute(&buf, templateData); err != nil {
   568  		panic(err)
   569  	}
   570  	return buf.String()
   571  }
   572  
   573  func (o *modelObject) addField(field *types.Var, tagValue, typeName string, required bool) bool {
   574  	if _, exist := o.Fields[tagValue]; !exist {
   575  		o.Fields[tagValue] = &fieldObject{
   576  			Field:    field,
   577  			TagValue: tagValue,
   578  			TypeName: typeName,
   579  			Required: required,
   580  		}
   581  		return true
   582  	}
   583  	return false
   584  }
   585  
   586  func validateType(pkg *packages.Package, obj types.Object) bool {
   587  	if named, ok := obj.Type().(*types.Named); ok {
   588  		if s, ok := named.Underlying().(*types.Struct); ok {
   589  			for i := 0; i < s.NumFields(); i++ {
   590  				if s.Field(i).Pkg().Path() == pkg.PkgPath {
   591  					return true
   592  				}
   593  			}
   594  		}
   595  	}
   596  	return false
   597  }
   598  
   599  func getJSONTagValue(tag string) (string, bool, bool) {
   600  	required := false
   601  	hidden := false
   602  	if tag != "" && strings.Contains(tag, "json:") {
   603  		required := !strings.Contains(tag, "omitempty")
   604  		re := regexp.MustCompile(`json:\"(.*?)\"`)
   605  		res := re.FindAllStringSubmatch(tag, -1)
   606  		propertyName := strings.ReplaceAll(res[0][1], ",omitempty", "")
   607  		propertyName = strings.ReplaceAll(propertyName, ",", "")
   608  
   609  		if propertyName == "-" {
   610  			hidden = true
   611  		}
   612  		return propertyName, required, hidden
   613  	}
   614  	return "", required, hidden
   615  }
   616  
   617  // GetSchemaFieldGotpl is used in the go template
   618  func GetSchemaFieldGotpl(tagValue, typeName string, required bool) string {
   619  	field := fmt.Sprintf("%s: %s", tagValue, typeName)
   620  	if required {
   621  		field = fmt.Sprintf("%s: %s!", tagValue, typeName)
   622  	}
   623  	return field
   624  }
   625  
   626  func getFilePath(filename string) string {
   627  	return fmt.Sprintf("%s%s", basePath, filename)
   628  }