github.com/operandinc/gqlgen@v0.16.1/plugin/modelgen/models.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/operandinc/gqlgen/codegen/config"
    10  	"github.com/operandinc/gqlgen/codegen/templates"
    11  	"github.com/operandinc/gqlgen/plugin"
    12  	"github.com/vektah/gqlparser/v2/ast"
    13  )
    14  
    15  type BuildMutateHook = func(b *ModelBuild) *ModelBuild
    16  
    17  type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
    18  
    19  // defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook.
    20  func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
    21  	return GoTagFieldHook(td, fd, f)
    22  }
    23  
    24  func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
    25  	return b
    26  }
    27  
    28  type ModelBuild struct {
    29  	PackageName string
    30  	Interfaces  []*Interface
    31  	Models      []*Object
    32  	Enums       []*Enum
    33  	Scalars     []string
    34  }
    35  
    36  type Interface struct {
    37  	Description string
    38  	Name        string
    39  }
    40  
    41  type Object struct {
    42  	Description string
    43  	Name        string
    44  	Fields      []*Field
    45  	Implements  []string
    46  }
    47  
    48  type Field struct {
    49  	Description string
    50  	Name        string
    51  	Type        types.Type
    52  	Tag         string
    53  }
    54  
    55  type Enum struct {
    56  	Description string
    57  	Name        string
    58  	Values      []*EnumValue
    59  }
    60  
    61  type EnumValue struct {
    62  	Description string
    63  	Name        string
    64  }
    65  
    66  func New() plugin.Plugin {
    67  	return &Plugin{
    68  		MutateHook: defaultBuildMutateHook,
    69  		FieldHook:  defaultFieldMutateHook,
    70  	}
    71  }
    72  
    73  type Plugin struct {
    74  	MutateHook BuildMutateHook
    75  	FieldHook  FieldMutateHook
    76  }
    77  
    78  var _ plugin.ConfigMutator = &Plugin{}
    79  
    80  func (m *Plugin) Name() string {
    81  	return "modelgen"
    82  }
    83  
    84  func (m *Plugin) MutateConfig(cfg *config.Config) error {
    85  	binder := cfg.NewBinder()
    86  
    87  	b := &ModelBuild{
    88  		PackageName: cfg.Model.Package,
    89  	}
    90  
    91  	for _, schemaType := range cfg.Schema.Types {
    92  		if cfg.Models.UserDefined(schemaType.Name) {
    93  			continue
    94  		}
    95  		switch schemaType.Kind {
    96  		case ast.Interface, ast.Union:
    97  			it := &Interface{
    98  				Description: schemaType.Description,
    99  				Name:        schemaType.Name,
   100  			}
   101  
   102  			b.Interfaces = append(b.Interfaces, it)
   103  		case ast.Object, ast.InputObject:
   104  			if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
   105  				continue
   106  			}
   107  			it := &Object{
   108  				Description: schemaType.Description,
   109  				Name:        schemaType.Name,
   110  			}
   111  			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
   112  				it.Implements = append(it.Implements, implementor.Name)
   113  			}
   114  
   115  			for _, field := range schemaType.Fields {
   116  				var typ types.Type
   117  				fieldDef := cfg.Schema.Types[field.Type.Name()]
   118  
   119  				if cfg.Models.UserDefined(field.Type.Name()) {
   120  					var err error
   121  					typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
   122  					if err != nil {
   123  						return err
   124  					}
   125  				} else {
   126  					switch fieldDef.Kind {
   127  					case ast.Scalar:
   128  						// no user defined model, referencing a default scalar
   129  						typ = types.NewNamed(
   130  							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
   131  							nil,
   132  							nil,
   133  						)
   134  
   135  					case ast.Interface, ast.Union:
   136  						// no user defined model, referencing a generated interface type
   137  						typ = types.NewNamed(
   138  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   139  							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
   140  							nil,
   141  						)
   142  
   143  					case ast.Enum:
   144  						// no user defined model, must reference a generated enum
   145  						typ = types.NewNamed(
   146  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   147  							nil,
   148  							nil,
   149  						)
   150  
   151  					case ast.Object, ast.InputObject:
   152  						// no user defined model, must reference a generated struct
   153  						typ = types.NewNamed(
   154  							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
   155  							types.NewStruct(nil, nil),
   156  							nil,
   157  						)
   158  
   159  					default:
   160  						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
   161  					}
   162  				}
   163  
   164  				name := field.Name
   165  				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
   166  					name = nameOveride
   167  				}
   168  
   169  				typ = binder.CopyModifiersFromAst(field.Type, typ)
   170  
   171  				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
   172  					typ = types.NewPointer(typ)
   173  				}
   174  
   175  				f := &Field{
   176  					Name:        name,
   177  					Type:        typ,
   178  					Description: field.Description,
   179  					Tag:         `json:"` + field.Name + `"`,
   180  				}
   181  
   182  				if m.FieldHook != nil {
   183  					mf, err := m.FieldHook(schemaType, field, f)
   184  					if err != nil {
   185  						return fmt.Errorf("generror: field %v.%v: %w", it.Name, field.Name, err)
   186  					}
   187  					f = mf
   188  				}
   189  
   190  				it.Fields = append(it.Fields, f)
   191  			}
   192  
   193  			b.Models = append(b.Models, it)
   194  		case ast.Enum:
   195  			it := &Enum{
   196  				Name:        schemaType.Name,
   197  				Description: schemaType.Description,
   198  			}
   199  
   200  			for _, v := range schemaType.EnumValues {
   201  				it.Values = append(it.Values, &EnumValue{
   202  					Name:        v.Name,
   203  					Description: v.Description,
   204  				})
   205  			}
   206  
   207  			b.Enums = append(b.Enums, it)
   208  		case ast.Scalar:
   209  			b.Scalars = append(b.Scalars, schemaType.Name)
   210  		}
   211  	}
   212  	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
   213  	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
   214  	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
   215  
   216  	for _, it := range b.Enums {
   217  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   218  	}
   219  	for _, it := range b.Models {
   220  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   221  	}
   222  	for _, it := range b.Interfaces {
   223  		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
   224  	}
   225  	for _, it := range b.Scalars {
   226  		cfg.Models.Add(it, "github.com/operandinc/gqlgen/graphql.String")
   227  	}
   228  
   229  	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
   230  		return nil
   231  	}
   232  
   233  	if m.MutateHook != nil {
   234  		b = m.MutateHook(b)
   235  	}
   236  
   237  	err := templates.Render(templates.Options{
   238  		PackageName:     cfg.Model.Package,
   239  		Filename:        cfg.Model.Filename,
   240  		Data:            b,
   241  		GeneratedHeader: true,
   242  		Packages:        cfg.Packages,
   243  	})
   244  	if err != nil {
   245  		return err
   246  	}
   247  
   248  	// We may have generated code in a package we already loaded, so we reload all packages
   249  	// to allow packages to be compared correctly
   250  	cfg.ReloadAllPackages()
   251  
   252  	return nil
   253  }
   254  
   255  // GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
   256  // name is used when no value argument is present.
   257  func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
   258  	args := make([]string, 0)
   259  	for _, goTag := range fd.Directives.ForNames("goTag") {
   260  		key := ""
   261  		value := fd.Name
   262  
   263  		if arg := goTag.Arguments.ForName("key"); arg != nil {
   264  			if k, err := arg.Value.Value(nil); err == nil {
   265  				key = k.(string)
   266  			}
   267  		}
   268  
   269  		if arg := goTag.Arguments.ForName("value"); arg != nil {
   270  			if v, err := arg.Value.Value(nil); err == nil {
   271  				value = v.(string)
   272  			}
   273  		}
   274  
   275  		args = append(args, key+":\""+value+"\"")
   276  	}
   277  
   278  	if len(args) > 0 {
   279  		f.Tag = f.Tag + " " + strings.Join(args, " ")
   280  	}
   281  
   282  	return f, nil
   283  }
   284  
   285  func isStruct(t types.Type) bool {
   286  	_, is := t.Underlying().(*types.Struct)
   287  	return is
   288  }