github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  
     8  	"github.com/99designs/gqlgen/codegen/config"
     9  	"github.com/99designs/gqlgen/codegen/templates"
    10  	"github.com/99designs/gqlgen/plugin"
    11  	"github.com/vektah/gqlparser/v2/ast"
    12  )
    13  
    14  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    15  
    16  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    17  	return b
    18  }
    19  
    20  type ModelBuild struct {
    21  	PackageName string
    22  	Interfaces  []*Interface
    23  	Models      []*Object
    24  	Enums       []*Enum
    25  	Scalars     []string
    26  }
    27  
    28  type Interface struct {
    29  	Description string
    30  	Name        string
    31  }
    32  
    33  type Object struct {
    34  	Description string
    35  	Name        string
    36  	Fields      []*Field
    37  	Implements  []string
    38  }
    39  
    40  type Field struct {
    41  	Description string
    42  	Name        string
    43  	Type        types.Type
    44  	Tag         string
    45  }
    46  
    47  type Enum struct {
    48  	Description string
    49  	Name        string
    50  	Values      []*EnumValue
    51  }
    52  
    53  type EnumValue struct {
    54  	Description string
    55  	Name        string
    56  }
    57  
    58  func New() plugin.Plugin {
    59  	return &Plugin{
    60  		MutateHook: defaultBuildMutateHook,
    61  	}
    62  }
    63  
    64  type Plugin struct {
    65  	MutateHook BuildMutateHook
    66  }
    67  
    68  var _ plugin.ConfigMutator = &Plugin{}
    69  
    70  func (m *Plugin) Name() string {
    71  	return "modelgen"
    72  }
    73  
    74  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    75  	binder := cfg.NewBinder()
    76  
    77  	b := &ModelBuild{
    78  		PackageName: cfg.Model.Package,
    79  	}
    80  
    81  	for _, schemaType := range cfg.Schema.Types {
    82  		if cfg.Models.UserDefined(schemaType.Name) {
    83  			continue
    84  		}
    85  		switch schemaType.Kind {
    86  		case ast.Interface, ast.Union:
    87  			it := &Interface{
    88  				Description: schemaType.Description,
    89  				Name:        schemaType.Name,
    90  			}
    91  
    92  			b.Interfaces = append(b.Interfaces, it)
    93  		case ast.Object, ast.InputObject:
    94  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
    95  				continue
    96  			}
    97  			it := &Object{
    98  				Description: schemaType.Description,
    99  				Name:        schemaType.Name,
   100  			}
   101  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   102  				it.Implements = append(it.Implements, implementor.Name)
   103  			}
   104  
   105  			for _, field := range schemaType.Fields {
   106  				var typ types.Type
   107  				fieldDef := cfg.Schema.Types[field.Type.Name()]
   108  
   109  				if cfg.Models.UserDefined(field.Type.Name()) {
   110  					var err error
   111  					typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   112  					if err != nil {
   113  						return err
   114  					}
   115  				} else {
   116  					switch fieldDef.Kind {
   117  					case ast.Scalar:
   118  						// no user defined model, referencing a default scalar
   119  						typ = types.NewNamed(
   120  							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   121  							nil,
   122  							nil,
   123  						)
   124  
   125  					case ast.Interface, ast.Union:
   126  						// no user defined model, referencing a generated interface type
   127  						typ = types.NewNamed(
   128  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   129  							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   130  							nil,
   131  						)
   132  
   133  					case ast.Enum:
   134  						// no user defined model, must reference a generated enum
   135  						typ = types.NewNamed(
   136  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   137  							nil,
   138  							nil,
   139  						)
   140  
   141  					case ast.Object, ast.InputObject:
   142  						// no user defined model, must reference a generated struct
   143  						typ = types.NewNamed(
   144  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   145  							types.NewStruct(nil, nil),
   146  							nil,
   147  						)
   148  
   149  					default:
   150  						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   151  					}
   152  				}
   153  
   154  				name := field.Name
   155  				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   156  					name = nameOveride
   157  				}
   158  
   159  				typ = binder.CopyModifiersFromAst(field.Type, typ)
   160  
   161  				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   162  					typ = types.NewPointer(typ)
   163  				}
   164  
   165  				it.Fields = append(it.Fields, &Field{
   166  					Name:        name,
   167  					Type:        typ,
   168  					Description: field.Description,
   169  					Tag:         `json:"` + field.Name + `"`,
   170  				})
   171  			}
   172  
   173  			b.Models = append(b.Models, it)
   174  		case ast.Enum:
   175  			it := &Enum{
   176  				Name:        schemaType.Name,
   177  				Description: schemaType.Description,
   178  			}
   179  
   180  			for _, v := range schemaType.EnumValues {
   181  				it.Values = append(it.Values, &EnumValue{
   182  					Name:        v.Name,
   183  					Description: v.Description,
   184  				})
   185  			}
   186  
   187  			b.Enums = append(b.Enums, it)
   188  		case ast.Scalar:
   189  			b.Scalars = append(b.Scalars, schemaType.Name)
   190  		}
   191  	}
   192  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   193  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   194  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   195  
   196  	for _, it := range b.Enums {
   197  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   198  	}
   199  	for _, it := range b.Models {
   200  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   201  	}
   202  	for _, it := range b.Interfaces {
   203  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   204  	}
   205  	for _, it := range b.Scalars {
   206  		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
   207  	}
   208  
   209  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   210  		return nil
   211  	}
   212  
   213  	if m.MutateHook != nil {
   214  		b = m.MutateHook(b)
   215  	}
   216  
   217  	return templates.Render(templates.Options{
   218  		PackageName:     cfg.Model.Package,
   219  		Filename:        cfg.Model.Filename,
   220  		Data:            b,
   221  		GeneratedHeader: true,
   222  		Packages:        cfg.Packages,
   223  	})
   224  }
   225  
   226  func isStruct(t types.Type) bool {
   227  	_, is := t.Underlying().(*types.Struct)
   228  	return is
   229  }