github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  
     8  	"github.com/animeshon/gqlgen/codegen/config"
     9  	"github.com/animeshon/gqlgen/codegen/templates"
    10  	"github.com/animeshon/gqlgen/plugin"
    11  	"github.com/vektah/gqlparser/v2/ast"
    12  )
    13  
    14  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    15  
    16  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    17  	return b
    18  }
    19  
    20  type ModelBuild struct {
    21  	PackageName string
    22  	Interfaces  []*Interface
    23  	Models      []*Object
    24  	Enums       []*Enum
    25  	Scalars     []string
    26  }
    27  
    28  type Interface struct {
    29  	Description string
    30  	Name        string
    31  }
    32  
    33  type Object struct {
    34  	Description string
    35  	Name        string
    36  	Fields      []*Field
    37  	Implements  []string
    38  }
    39  
    40  type Field struct {
    41  	Description string
    42  	Name        string
    43  	Type        types.Type
    44  	Tag         string
    45  }
    46  
    47  func (f *Field) Nullable(name string) {
    48  	f.Tag = `json:"` + name + `,omitempty"`
    49  }
    50  
    51  type Enum struct {
    52  	Description string
    53  	Name        string
    54  	Values      []*EnumValue
    55  }
    56  
    57  type EnumValue struct {
    58  	Description string
    59  	Name        string
    60  }
    61  
    62  func New() plugin.Plugin {
    63  	return &Plugin{
    64  		MutateHook: defaultBuildMutateHook,
    65  	}
    66  }
    67  
    68  type Plugin struct {
    69  	MutateHook BuildMutateHook
    70  }
    71  
    72  var _ plugin.ConfigMutator = &Plugin{}
    73  
    74  func (m *Plugin) Name() string {
    75  	return "modelgen"
    76  }
    77  
    78  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    79  	binder := cfg.NewBinder()
    80  
    81  	b := &ModelBuild{
    82  		PackageName: cfg.Model.Package,
    83  	}
    84  
    85  	for _, schemaType := range cfg.Schema.Types {
    86  		if cfg.Models.UserDefined(schemaType.Name) {
    87  			continue
    88  		}
    89  		switch schemaType.Kind {
    90  		case ast.Interface, ast.Union:
    91  			it := &Interface{
    92  				Description: schemaType.Description,
    93  				Name:        schemaType.Name,
    94  			}
    95  
    96  			b.Interfaces = append(b.Interfaces, it)
    97  		case ast.Object, ast.InputObject:
    98  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
    99  				continue
   100  			}
   101  			it := &Object{
   102  				Description: schemaType.Description,
   103  				Name:        schemaType.Name,
   104  			}
   105  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   106  				it.Implements = append(it.Implements, implementor.Name)
   107  			}
   108  
   109  			for _, field := range schemaType.Fields {
   110  				var typ types.Type
   111  				fieldDef := cfg.Schema.Types[field.Type.Name()]
   112  
   113  				if cfg.Models.UserDefined(field.Type.Name()) {
   114  					var err error
   115  					typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   116  					if err != nil {
   117  						return err
   118  					}
   119  				} else {
   120  					switch fieldDef.Kind {
   121  					case ast.Scalar:
   122  						// no user defined model, referencing a default scalar
   123  						typ = types.NewNamed(
   124  							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   125  							nil,
   126  							nil,
   127  						)
   128  
   129  					case ast.Interface, ast.Union:
   130  						// no user defined model, referencing a generated interface type
   131  						typ = types.NewNamed(
   132  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   133  							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   134  							nil,
   135  						)
   136  
   137  					case ast.Enum:
   138  						// no user defined model, must reference a generated enum
   139  						typ = types.NewNamed(
   140  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   141  							nil,
   142  							nil,
   143  						)
   144  
   145  					case ast.Object, ast.InputObject:
   146  						// no user defined model, must reference a generated struct
   147  						typ = types.NewNamed(
   148  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   149  							types.NewStruct(nil, nil),
   150  							nil,
   151  						)
   152  
   153  					default:
   154  						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   155  					}
   156  				}
   157  
   158  				name := field.Name
   159  				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   160  					name = nameOveride
   161  				}
   162  
   163  				typ = binder.CopyModifiersFromAst(field.Type, typ)
   164  
   165  				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   166  					typ = types.NewPointer(typ)
   167  				}
   168  
   169  				f := &Field{
   170  					Name:        name,
   171  					Type:        typ,
   172  					Description: field.Description,
   173  					Tag:         `json:"` + field.Name + `"`,
   174  				}
   175  				// if field is nullable, omit in json marshall to save bandwidth and avoid unexpected behaviour sending null
   176  				if !field.Type.NonNull {
   177  					f.Nullable(field.Name)
   178  				}
   179  
   180  				it.Fields = append(it.Fields, f)
   181  			}
   182  
   183  			b.Models = append(b.Models, it)
   184  		case ast.Enum:
   185  			it := &Enum{
   186  				Name:        schemaType.Name,
   187  				Description: schemaType.Description,
   188  			}
   189  
   190  			for _, v := range schemaType.EnumValues {
   191  				it.Values = append(it.Values, &EnumValue{
   192  					Name:        v.Name,
   193  					Description: v.Description,
   194  				})
   195  			}
   196  
   197  			b.Enums = append(b.Enums, it)
   198  		case ast.Scalar:
   199  			b.Scalars = append(b.Scalars, schemaType.Name)
   200  		}
   201  	}
   202  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   203  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   204  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   205  
   206  	for _, it := range b.Enums {
   207  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   208  	}
   209  	for _, it := range b.Models {
   210  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   211  	}
   212  	for _, it := range b.Interfaces {
   213  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   214  	}
   215  	for _, it := range b.Scalars {
   216  		cfg.Models.Add(it, "github.com/animeshon/gqlgen/graphql.String")
   217  	}
   218  
   219  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   220  		return nil
   221  	}
   222  
   223  	if m.MutateHook != nil {
   224  		b = m.MutateHook(b)
   225  	}
   226  
   227  	return templates.Render(templates.Options{
   228  		PackageName:     cfg.Model.Package,
   229  		Filename:        cfg.Model.Filename,
   230  		Data:            b,
   231  		GeneratedHeader: true,
   232  		Packages:        cfg.Packages,
   233  	})
   234  }
   235  
   236  func isStruct(t types.Type) bool {
   237  	_, is := t.Underlying().(*types.Struct)
   238  	return is
   239  }