github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/plugin/modelgen/models_test.go (about)

     1  package modelgen
     2  
     3  import (
     4  	"errors"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"sort"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/spread-ai/gqlgen/plugin/modelgen/out_struct_pointers"
    16  
    17  	"github.com/spread-ai/gqlgen/codegen/config"
    18  	"github.com/spread-ai/gqlgen/plugin/modelgen/out"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestModelGeneration(t *testing.T) {
    24  	cfg, err := config.LoadConfig("testdata/gqlgen.yml")
    25  	require.NoError(t, err)
    26  	require.NoError(t, cfg.Init())
    27  	p := Plugin{
    28  		MutateHook: mutateHook,
    29  		FieldHook:  defaultFieldMutateHook,
    30  	}
    31  	require.NoError(t, p.MutateConfig(cfg))
    32  	require.NoError(t, goBuild(t, "./out/"))
    33  
    34  	require.True(t, cfg.Models.UserDefined("MissingTypeNotNull"))
    35  	require.True(t, cfg.Models.UserDefined("MissingTypeNullable"))
    36  	require.True(t, cfg.Models.UserDefined("MissingEnum"))
    37  	require.True(t, cfg.Models.UserDefined("MissingUnion"))
    38  	require.True(t, cfg.Models.UserDefined("MissingInterface"))
    39  	require.True(t, cfg.Models.UserDefined("TypeWithDescription"))
    40  	require.True(t, cfg.Models.UserDefined("EnumWithDescription"))
    41  	require.True(t, cfg.Models.UserDefined("InterfaceWithDescription"))
    42  	require.True(t, cfg.Models.UserDefined("UnionWithDescription"))
    43  
    44  	t.Run("no pointer pointers", func(t *testing.T) {
    45  		generated, err := os.ReadFile("./out/generated.go")
    46  		require.NoError(t, err)
    47  		require.NotContains(t, string(generated), "**")
    48  	})
    49  
    50  	t.Run("description is generated", func(t *testing.T) {
    51  		node, err := parser.ParseFile(token.NewFileSet(), "./out/generated.go", nil, parser.ParseComments)
    52  		require.NoError(t, err)
    53  		for _, commentGroup := range node.Comments {
    54  			text := commentGroup.Text()
    55  			words := strings.Split(text, " ")
    56  			require.True(t, len(words) > 1, "expected description %q to have more than one word", text)
    57  		}
    58  	})
    59  
    60  	t.Run("tags are applied", func(t *testing.T) {
    61  		file, err := os.ReadFile("./out/generated.go")
    62  		require.NoError(t, err)
    63  
    64  		fileText := string(file)
    65  
    66  		expectedTags := []string{
    67  			`json:"missing2" database:"MissingTypeNotNullmissing2"`,
    68  			`json:"name" database:"MissingInputname"`,
    69  			`json:"missing2" database:"MissingTypeNullablemissing2"`,
    70  			`json:"name" database:"TypeWithDescriptionname"`,
    71  		}
    72  
    73  		for _, tag := range expectedTags {
    74  			require.True(t, strings.Contains(fileText, tag))
    75  		}
    76  	})
    77  
    78  	t.Run("field hooks are applied", func(t *testing.T) {
    79  		file, err := os.ReadFile("./out/generated.go")
    80  		require.NoError(t, err)
    81  
    82  		fileText := string(file)
    83  
    84  		expectedTags := []string{
    85  			`json:"name" anotherTag:"tag"`,
    86  			`json:"enum" yetAnotherTag:"12"`,
    87  			`json:"noVal" yaml:"noVal"`,
    88  			`json:"repeated" someTag:"value" repeated:"true"`,
    89  		}
    90  
    91  		for _, tag := range expectedTags {
    92  			require.True(t, strings.Contains(fileText, tag))
    93  		}
    94  	})
    95  
    96  	t.Run("concrete types implement interface", func(t *testing.T) {
    97  		var _ out.FooBarer = out.FooBarr{}
    98  	})
    99  
   100  	t.Run("implemented interfaces", func(t *testing.T) {
   101  		pkg, err := parseAst("out")
   102  		require.NoError(t, err)
   103  
   104  		path := filepath.Join("out", "generated.go")
   105  		generated := pkg.Files[path]
   106  
   107  		type field struct {
   108  			typ  string
   109  			name string
   110  		}
   111  		cases := []struct {
   112  			name       string
   113  			wantFields []field
   114  		}{
   115  			{
   116  				name: "A",
   117  				wantFields: []field{
   118  					{
   119  						typ:  "method",
   120  						name: "IsA",
   121  					},
   122  					{
   123  						typ:  "method",
   124  						name: "GetA",
   125  					},
   126  				},
   127  			},
   128  			{
   129  				name: "B",
   130  				wantFields: []field{
   131  					{
   132  						typ:  "method",
   133  						name: "IsB",
   134  					},
   135  					{
   136  						typ:  "method",
   137  						name: "GetB",
   138  					},
   139  				},
   140  			},
   141  			{
   142  				name: "C",
   143  				wantFields: []field{
   144  					{
   145  						typ:  "method",
   146  						name: "IsA",
   147  					},
   148  					{
   149  						typ:  "method",
   150  						name: "IsC",
   151  					},
   152  					{
   153  						typ:  "method",
   154  						name: "GetA",
   155  					},
   156  					{
   157  						typ:  "method",
   158  						name: "GetC",
   159  					},
   160  				},
   161  			},
   162  			{
   163  				name: "D",
   164  				wantFields: []field{
   165  					{
   166  						typ:  "method",
   167  						name: "IsA",
   168  					},
   169  					{
   170  						typ:  "method",
   171  						name: "IsB",
   172  					},
   173  					{
   174  						typ:  "method",
   175  						name: "IsD",
   176  					},
   177  					{
   178  						typ:  "method",
   179  						name: "GetA",
   180  					},
   181  					{
   182  						typ:  "method",
   183  						name: "GetB",
   184  					},
   185  					{
   186  						typ:  "method",
   187  						name: "GetD",
   188  					},
   189  				},
   190  			},
   191  		}
   192  		for _, tc := range cases {
   193  			tc := tc
   194  			t.Run(tc.name, func(t *testing.T) {
   195  				typeSpec, ok := generated.Scope.Lookup(tc.name).Decl.(*ast.TypeSpec)
   196  				require.True(t, ok)
   197  
   198  				fields := typeSpec.Type.(*ast.InterfaceType).Methods.List
   199  				for i, want := range tc.wantFields {
   200  					if want.typ == "ident" {
   201  						ident, ok := fields[i].Type.(*ast.Ident)
   202  						require.True(t, ok)
   203  						assert.Equal(t, want.name, ident.Name)
   204  					}
   205  					if want.typ == "method" {
   206  						require.GreaterOrEqual(t, 1, len(fields[i].Names))
   207  						name := fields[i].Names[0].Name
   208  						assert.Equal(t, want.name, name)
   209  					}
   210  				}
   211  			})
   212  		}
   213  	})
   214  
   215  	t.Run("implemented interfaces type CDImplemented", func(t *testing.T) {
   216  		pkg, err := parseAst("out")
   217  		require.NoError(t, err)
   218  
   219  		path := filepath.Join("out", "generated.go")
   220  		generated := pkg.Files[path]
   221  
   222  		wantMethods := []string{
   223  			"IsA",
   224  			"IsB",
   225  			"IsC",
   226  			"IsD",
   227  		}
   228  
   229  		gots := make([]string, 0, len(wantMethods))
   230  		for _, decl := range generated.Decls {
   231  			if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   232  				switch funcDecl.Name.Name {
   233  				case "IsA", "IsB", "IsC", "IsD":
   234  					gots = append(gots, funcDecl.Name.Name)
   235  					require.Len(t, funcDecl.Recv.List, 1)
   236  					recvIdent, ok := funcDecl.Recv.List[0].Type.(*ast.Ident)
   237  					require.True(t, ok)
   238  					require.Equal(t, "CDImplemented", recvIdent.Name)
   239  				}
   240  			}
   241  		}
   242  
   243  		sort.Strings(gots)
   244  		require.Equal(t, wantMethods, gots)
   245  	})
   246  
   247  	t.Run("cyclical struct fields become pointers", func(t *testing.T) {
   248  		require.Nil(t, out.CyclicalA{}.FieldOne)
   249  		require.Nil(t, out.CyclicalA{}.FieldTwo)
   250  		require.Nil(t, out.CyclicalA{}.FieldThree)
   251  		require.NotNil(t, out.CyclicalA{}.FieldFour)
   252  		require.Nil(t, out.CyclicalB{}.FieldOne)
   253  		require.Nil(t, out.CyclicalB{}.FieldTwo)
   254  		require.Nil(t, out.CyclicalB{}.FieldThree)
   255  		require.Nil(t, out.CyclicalB{}.FieldFour)
   256  		require.NotNil(t, out.CyclicalB{}.FieldFive)
   257  	})
   258  
   259  	t.Run("non-cyclical struct fields become pointers", func(t *testing.T) {
   260  		require.NotNil(t, out.NotCyclicalB{}.FieldOne)
   261  		require.Nil(t, out.NotCyclicalB{}.FieldTwo)
   262  	})
   263  
   264  	t.Run("recursive struct fields become pointers", func(t *testing.T) {
   265  		require.Nil(t, out.Recursive{}.FieldOne)
   266  		require.Nil(t, out.Recursive{}.FieldTwo)
   267  		require.Nil(t, out.Recursive{}.FieldThree)
   268  		require.NotNil(t, out.Recursive{}.FieldFour)
   269  	})
   270  
   271  	t.Run("overridden struct field names use same capitalization as config", func(t *testing.T) {
   272  		require.NotNil(t, out.RenameFieldTest{}.GOODnaME)
   273  	})
   274  }
   275  
   276  func TestModelGenerationStructFieldPointers(t *testing.T) {
   277  	cfg, err := config.LoadConfig("testdata/gqlgen_struct_field_pointers.yml")
   278  	require.NoError(t, err)
   279  	require.NoError(t, cfg.Init())
   280  	p := Plugin{
   281  		MutateHook: mutateHook,
   282  		FieldHook:  defaultFieldMutateHook,
   283  	}
   284  	require.NoError(t, p.MutateConfig(cfg))
   285  
   286  	t.Run("no pointer pointers", func(t *testing.T) {
   287  		generated, err := os.ReadFile("./out_struct_pointers/generated.go")
   288  		require.NoError(t, err)
   289  		require.NotContains(t, string(generated), "**")
   290  	})
   291  
   292  	t.Run("cyclical struct fields become pointers", func(t *testing.T) {
   293  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldOne)
   294  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldTwo)
   295  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldThree)
   296  		require.NotNil(t, out_struct_pointers.CyclicalA{}.FieldFour)
   297  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldOne)
   298  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldTwo)
   299  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldThree)
   300  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldFour)
   301  		require.NotNil(t, out_struct_pointers.CyclicalB{}.FieldFive)
   302  	})
   303  
   304  	t.Run("non-cyclical struct fields do not become pointers", func(t *testing.T) {
   305  		require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldOne)
   306  		require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldTwo)
   307  	})
   308  
   309  	t.Run("recursive struct fields become pointers", func(t *testing.T) {
   310  		require.Nil(t, out_struct_pointers.Recursive{}.FieldOne)
   311  		require.Nil(t, out_struct_pointers.Recursive{}.FieldTwo)
   312  		require.Nil(t, out_struct_pointers.Recursive{}.FieldThree)
   313  		require.NotNil(t, out_struct_pointers.Recursive{}.FieldFour)
   314  	})
   315  
   316  	t.Run("no getters", func(t *testing.T) {
   317  		generated, err := os.ReadFile("./out_struct_pointers/generated.go")
   318  		require.NoError(t, err)
   319  		require.NotContains(t, string(generated), "func (this")
   320  	})
   321  }
   322  
   323  func mutateHook(b *ModelBuild) *ModelBuild {
   324  	for _, model := range b.Models {
   325  		for _, field := range model.Fields {
   326  			field.Tag += ` database:"` + model.Name + field.Name + `"`
   327  		}
   328  	}
   329  
   330  	return b
   331  }
   332  
   333  func parseAst(path string) (*ast.Package, error) {
   334  	// test setup to parse the types
   335  	fset := token.NewFileSet()
   336  	pkgs, err := parser.ParseDir(fset, path, nil, parser.AllErrors)
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  	return pkgs["out"], nil
   341  }
   342  
   343  func goBuild(t *testing.T, path string) error {
   344  	t.Helper()
   345  	cmd := exec.Command("go", "build", path)
   346  	out, err := cmd.CombinedOutput()
   347  	if err != nil {
   348  		return errors.New(string(out))
   349  	}
   350  
   351  	return nil
   352  }