github.com/fortexxx/gqlgen@v0.10.3-0.20191216030626-ca5ea8b21ead/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  
     8  	"github.com/99designs/gqlgen/codegen/config"
     9  	"github.com/99designs/gqlgen/codegen/templates"
    10  	"github.com/99designs/gqlgen/plugin"
    11  	"github.com/vektah/gqlparser/ast"
    12  )
    13  
    14  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    15  
    16  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    17  	return b
    18  }
    19  
    20  type ModelBuild struct {
    21  	PackageName string
    22  	Interfaces  []*Interface
    23  	Models      []*Object
    24  	Enums       []*Enum
    25  	Scalars     []string
    26  }
    27  
    28  type Interface struct {
    29  	Description string
    30  	Name        string
    31  }
    32  
    33  type Object struct {
    34  	Description string
    35  	Name        string
    36  	Fields      []*Field
    37  	Implements  []string
    38  }
    39  
    40  type Field struct {
    41  	Description string
    42  	Name        string
    43  	Type        types.Type
    44  	Tag         string
    45  }
    46  
    47  type Enum struct {
    48  	Description string
    49  	Name        string
    50  	Values      []*EnumValue
    51  }
    52  
    53  type EnumValue struct {
    54  	Description string
    55  	Name        string
    56  }
    57  
    58  func New() plugin.Plugin {
    59  	return &Plugin{
    60  		MutateHook: defaultBuildMutateHook,
    61  	}
    62  }
    63  
    64  type Plugin struct {
    65  	MutateHook BuildMutateHook
    66  }
    67  
    68  var _ plugin.ConfigMutator = &Plugin{}
    69  
    70  func (m *Plugin) Name() string {
    71  	return "modelgen"
    72  }
    73  
    74  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    75  	schema, err := cfg.LoadSchema()
    76  	if err != nil {
    77  		return err
    78  	}
    79  
    80  	err = cfg.Autobind(schema)
    81  	if err != nil {
    82  		return err
    83  	}
    84  
    85  	cfg.InjectBuiltins(schema)
    86  
    87  	binder, err := cfg.NewBinder(schema)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	b := &ModelBuild{
    93  		PackageName: cfg.Model.Package,
    94  	}
    95  
    96  	var hasEntity bool
    97  	for _, schemaType := range schema.Types {
    98  		if cfg.Models.UserDefined(schemaType.Name) {
    99  			continue
   100  		}
   101  		var ent bool
   102  		for _, dir := range schemaType.Directives {
   103  			if dir.Name == "key" {
   104  				hasEntity = true
   105  				ent = true
   106  			}
   107  		}
   108  		switch schemaType.Kind {
   109  		case ast.Interface, ast.Union:
   110  			it := &Interface{
   111  				Description: schemaType.Description,
   112  				Name:        schemaType.Name,
   113  			}
   114  
   115  			b.Interfaces = append(b.Interfaces, it)
   116  		case ast.Object, ast.InputObject:
   117  			if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
   118  				continue
   119  			}
   120  			it := &Object{
   121  				Description: schemaType.Description,
   122  				Name:        schemaType.Name,
   123  			}
   124  
   125  			for _, implementor := range schema.GetImplements(schemaType) {
   126  				it.Implements = append(it.Implements, implementor.Name)
   127  			}
   128  			if ent { // only when Object. Directive validation should have occurred on InputObject otherwise.
   129  				it.Implements = append(it.Implements, "_Entity")
   130  			}
   131  
   132  			for _, field := range schemaType.Fields {
   133  				var typ types.Type
   134  				fieldDef := schema.Types[field.Type.Name()]
   135  
   136  				if cfg.Models.UserDefined(field.Type.Name()) {
   137  					typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   138  					if err != nil {
   139  						return err
   140  					}
   141  				} else {
   142  					switch fieldDef.Kind {
   143  					case ast.Scalar:
   144  						// no user defined model, referencing a default scalar
   145  						typ = types.NewNamed(
   146  							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   147  							nil,
   148  							nil,
   149  						)
   150  
   151  					case ast.Interface, ast.Union:
   152  						// no user defined model, referencing a generated interface type
   153  						typ = types.NewNamed(
   154  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   155  							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   156  							nil,
   157  						)
   158  
   159  					case ast.Enum:
   160  						// no user defined model, must reference a generated enum
   161  						typ = types.NewNamed(
   162  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   163  							nil,
   164  							nil,
   165  						)
   166  
   167  					case ast.Object, ast.InputObject:
   168  						// no user defined model, must reference a generated struct
   169  						typ = types.NewNamed(
   170  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   171  							types.NewStruct(nil, nil),
   172  							nil,
   173  						)
   174  
   175  					default:
   176  						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   177  					}
   178  				}
   179  
   180  				name := field.Name
   181  				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   182  					name = nameOveride
   183  				}
   184  
   185  				typ = binder.CopyModifiersFromAst(field.Type, typ)
   186  
   187  				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   188  					typ = types.NewPointer(typ)
   189  				}
   190  
   191  				it.Fields = append(it.Fields, &Field{
   192  					Name:        name,
   193  					Type:        typ,
   194  					Description: field.Description,
   195  					Tag:         `json:"` + field.Name + `"`,
   196  				})
   197  			}
   198  
   199  			b.Models = append(b.Models, it)
   200  		case ast.Enum:
   201  			it := &Enum{
   202  				Name:        schemaType.Name,
   203  				Description: schemaType.Description,
   204  			}
   205  
   206  			for _, v := range schemaType.EnumValues {
   207  				it.Values = append(it.Values, &EnumValue{
   208  					Name:        v.Name,
   209  					Description: v.Description,
   210  				})
   211  			}
   212  
   213  			b.Enums = append(b.Enums, it)
   214  		case ast.Scalar:
   215  			b.Scalars = append(b.Scalars, schemaType.Name)
   216  		}
   217  	}
   218  
   219  	if hasEntity {
   220  		it := &Interface{
   221  			Description: "_Entity represents all types with @key",
   222  			Name:        "_Entity",
   223  		}
   224  		b.Interfaces = append(b.Interfaces, it)
   225  	}
   226  
   227  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   228  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   229  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   230  
   231  	for _, it := range b.Enums {
   232  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   233  	}
   234  	for _, it := range b.Models {
   235  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   236  	}
   237  	for _, it := range b.Interfaces {
   238  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   239  	}
   240  	for _, it := range b.Scalars {
   241  		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
   242  	}
   243  
   244  	if len(b.Models) == 0 && len(b.Enums) == 0 {
   245  		return nil
   246  	}
   247  
   248  	if m.MutateHook != nil {
   249  		b = m.MutateHook(b)
   250  	}
   251  
   252  	return templates.Render(templates.Options{
   253  		PackageName:     cfg.Model.Package,
   254  		Filename:        cfg.Model.Filename,
   255  		Data:            b,
   256  		GeneratedHeader: true,
   257  	})
   258  }
   259  
   260  func isStruct(t types.Type) bool {
   261  	_, is := t.Underlying().(*types.Struct)
   262  	return is
   263  }