github.com/senomas/gqlgen@v0.17.11-0.20220626120754-9aee61b0716a/plugin/modelgen/models_test.go (about)

     1  package modelgen
     2  
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"testing"

	"github.com/99designs/gqlgen/codegen/config"
	"github.com/99designs/gqlgen/plugin/modelgen/out"
	"github.com/99designs/gqlgen/plugin/modelgen/out_struct_pointers"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
    20  
    21  func TestModelGeneration(t *testing.T) {
    22  	cfg, err := config.LoadConfig("testdata/gqlgen.yml")
    23  	require.NoError(t, err)
    24  	require.NoError(t, cfg.Init())
    25  	p := Plugin{
    26  		MutateHook: mutateHook,
    27  		FieldHook:  defaultFieldMutateHook,
    28  	}
    29  	require.NoError(t, p.MutateConfig(cfg))
    30  
    31  	require.True(t, cfg.Models.UserDefined("MissingTypeNotNull"))
    32  	require.True(t, cfg.Models.UserDefined("MissingTypeNullable"))
    33  	require.True(t, cfg.Models.UserDefined("MissingEnum"))
    34  	require.True(t, cfg.Models.UserDefined("MissingUnion"))
    35  	require.True(t, cfg.Models.UserDefined("MissingInterface"))
    36  	require.True(t, cfg.Models.UserDefined("TypeWithDescription"))
    37  	require.True(t, cfg.Models.UserDefined("EnumWithDescription"))
    38  	require.True(t, cfg.Models.UserDefined("InterfaceWithDescription"))
    39  	require.True(t, cfg.Models.UserDefined("UnionWithDescription"))
    40  
    41  	t.Run("no pointer pointers", func(t *testing.T) {
    42  		generated, err := os.ReadFile("./out/generated.go")
    43  		require.NoError(t, err)
    44  		require.NotContains(t, string(generated), "**")
    45  	})
    46  
    47  	t.Run("description is generated", func(t *testing.T) {
    48  		node, err := parser.ParseFile(token.NewFileSet(), "./out/generated.go", nil, parser.ParseComments)
    49  		require.NoError(t, err)
    50  		for _, commentGroup := range node.Comments {
    51  			text := commentGroup.Text()
    52  			words := strings.Split(text, " ")
    53  			require.True(t, len(words) > 1, "expected description %q to have more than one word", text)
    54  		}
    55  	})
    56  
    57  	t.Run("tags are applied", func(t *testing.T) {
    58  		file, err := os.ReadFile("./out/generated.go")
    59  		require.NoError(t, err)
    60  
    61  		fileText := string(file)
    62  
    63  		expectedTags := []string{
    64  			`json:"missing2" database:"MissingTypeNotNullmissing2"`,
    65  			`json:"name" database:"MissingInputname"`,
    66  			`json:"missing2" database:"MissingTypeNullablemissing2"`,
    67  			`json:"name" database:"TypeWithDescriptionname"`,
    68  		}
    69  
    70  		for _, tag := range expectedTags {
    71  			require.True(t, strings.Contains(fileText, tag))
    72  		}
    73  	})
    74  
    75  	t.Run("field hooks are applied", func(t *testing.T) {
    76  		file, err := os.ReadFile("./out/generated.go")
    77  		require.NoError(t, err)
    78  
    79  		fileText := string(file)
    80  
    81  		expectedTags := []string{
    82  			`json:"name" anotherTag:"tag"`,
    83  			`json:"enum" yetAnotherTag:"12"`,
    84  			`json:"noVal" yaml:"noVal"`,
    85  			`json:"repeated" someTag:"value" repeated:"true"`,
    86  		}
    87  
    88  		for _, tag := range expectedTags {
    89  			require.True(t, strings.Contains(fileText, tag))
    90  		}
    91  	})
    92  
    93  	t.Run("concrete types implement interface", func(t *testing.T) {
    94  		var _ out.FooBarer = out.FooBarr{}
    95  	})
    96  
    97  	t.Run("implemented interfaces", func(t *testing.T) {
    98  		pkg, err := parseAst("out")
    99  		require.NoError(t, err)
   100  
   101  		path := filepath.Join("out", "generated.go")
   102  		generated := pkg.Files[path]
   103  
   104  		type field struct {
   105  			typ  string
   106  			name string
   107  		}
   108  		cases := []struct {
   109  			name       string
   110  			wantFields []field
   111  		}{
   112  			{
   113  				name: "A",
   114  				wantFields: []field{
   115  					{
   116  						typ:  "method",
   117  						name: "IsA",
   118  					},
   119  				},
   120  			},
   121  			{
   122  				name: "B",
   123  				wantFields: []field{
   124  					{
   125  						typ:  "method",
   126  						name: "IsB",
   127  					},
   128  				},
   129  			},
   130  			{
   131  				name: "C",
   132  				wantFields: []field{
   133  					{
   134  						typ:  "ident",
   135  						name: "A",
   136  					},
   137  					{
   138  						typ:  "method",
   139  						name: "IsC",
   140  					},
   141  				},
   142  			},
   143  			{
   144  				name: "D",
   145  				wantFields: []field{
   146  					{
   147  						typ:  "ident",
   148  						name: "A",
   149  					},
   150  					{
   151  						typ:  "ident",
   152  						name: "B",
   153  					},
   154  					{
   155  						typ:  "method",
   156  						name: "IsD",
   157  					},
   158  				},
   159  			},
   160  		}
   161  		for _, tc := range cases {
   162  			tc := tc
   163  			t.Run(tc.name, func(t *testing.T) {
   164  				typeSpec, ok := generated.Scope.Lookup(tc.name).Decl.(*ast.TypeSpec)
   165  				require.True(t, ok)
   166  
   167  				fields := typeSpec.Type.(*ast.InterfaceType).Methods.List
   168  				for i, want := range tc.wantFields {
   169  					if want.typ == "ident" {
   170  						ident, ok := fields[i].Type.(*ast.Ident)
   171  						require.True(t, ok)
   172  						assert.Equal(t, want.name, ident.Name)
   173  					}
   174  					if want.typ == "method" {
   175  						require.GreaterOrEqual(t, 1, len(fields[i].Names))
   176  						name := fields[i].Names[0].Name
   177  						assert.Equal(t, want.name, name)
   178  					}
   179  				}
   180  			})
   181  		}
   182  	})
   183  
   184  	t.Run("implemented interfaces type CDImplemented", func(t *testing.T) {
   185  		pkg, err := parseAst("out")
   186  		require.NoError(t, err)
   187  
   188  		path := filepath.Join("out", "generated.go")
   189  		generated := pkg.Files[path]
   190  
   191  		wantMethods := []string{
   192  			"IsA",
   193  			"IsB",
   194  			"IsC",
   195  			"IsD",
   196  		}
   197  
   198  		gots := make([]string, 0, len(wantMethods))
   199  		for _, decl := range generated.Decls {
   200  			if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   201  				switch funcDecl.Name.Name {
   202  				case "IsA", "IsB", "IsC", "IsD":
   203  					gots = append(gots, funcDecl.Name.Name)
   204  					require.Len(t, funcDecl.Recv.List, 1)
   205  					recvIdent, ok := funcDecl.Recv.List[0].Type.(*ast.Ident)
   206  					require.True(t, ok)
   207  					require.Equal(t, "CDImplemented", recvIdent.Name)
   208  				}
   209  			}
   210  		}
   211  
   212  		sort.Strings(gots)
   213  		require.Equal(t, wantMethods, gots)
   214  	})
   215  
   216  	t.Run("cyclical struct fields become pointers", func(t *testing.T) {
   217  		require.Nil(t, out.CyclicalA{}.FieldOne)
   218  		require.Nil(t, out.CyclicalA{}.FieldTwo)
   219  		require.Nil(t, out.CyclicalA{}.FieldThree)
   220  		require.NotNil(t, out.CyclicalA{}.FieldFour)
   221  		require.Nil(t, out.CyclicalB{}.FieldOne)
   222  		require.Nil(t, out.CyclicalB{}.FieldTwo)
   223  		require.Nil(t, out.CyclicalB{}.FieldThree)
   224  		require.Nil(t, out.CyclicalB{}.FieldFour)
   225  		require.NotNil(t, out.CyclicalB{}.FieldFive)
   226  	})
   227  
   228  	t.Run("non-cyclical struct fields become pointers", func(t *testing.T) {
   229  		require.NotNil(t, out.NotCyclicalB{}.FieldOne)
   230  		require.Nil(t, out.NotCyclicalB{}.FieldTwo)
   231  	})
   232  
   233  	t.Run("recursive struct fields become pointers", func(t *testing.T) {
   234  		require.Nil(t, out.Recursive{}.FieldOne)
   235  		require.Nil(t, out.Recursive{}.FieldTwo)
   236  		require.Nil(t, out.Recursive{}.FieldThree)
   237  		require.NotNil(t, out.Recursive{}.FieldFour)
   238  	})
   239  
   240  	t.Run("overridden struct field names use same capitalization as config", func(t *testing.T) {
   241  		require.NotNil(t, out.RenameFieldTest{}.GOODnaME)
   242  	})
   243  }
   244  
   245  func TestModelGenerationStructFieldPointers(t *testing.T) {
   246  	cfg, err := config.LoadConfig("testdata/gqlgen_struct_field_pointers.yml")
   247  	require.NoError(t, err)
   248  	require.NoError(t, cfg.Init())
   249  	p := Plugin{
   250  		MutateHook: mutateHook,
   251  		FieldHook:  defaultFieldMutateHook,
   252  	}
   253  	require.NoError(t, p.MutateConfig(cfg))
   254  
   255  	t.Run("no pointer pointers", func(t *testing.T) {
   256  		generated, err := os.ReadFile("./out_struct_pointers/generated.go")
   257  		require.NoError(t, err)
   258  		require.NotContains(t, string(generated), "**")
   259  	})
   260  
   261  	t.Run("cyclical struct fields become pointers", func(t *testing.T) {
   262  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldOne)
   263  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldTwo)
   264  		require.Nil(t, out_struct_pointers.CyclicalA{}.FieldThree)
   265  		require.NotNil(t, out_struct_pointers.CyclicalA{}.FieldFour)
   266  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldOne)
   267  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldTwo)
   268  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldThree)
   269  		require.Nil(t, out_struct_pointers.CyclicalB{}.FieldFour)
   270  		require.NotNil(t, out_struct_pointers.CyclicalB{}.FieldFive)
   271  	})
   272  
   273  	t.Run("non-cyclical struct fields do not become pointers", func(t *testing.T) {
   274  		require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldOne)
   275  		require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldTwo)
   276  	})
   277  
   278  	t.Run("recursive struct fields become pointers", func(t *testing.T) {
   279  		require.Nil(t, out_struct_pointers.Recursive{}.FieldOne)
   280  		require.Nil(t, out_struct_pointers.Recursive{}.FieldTwo)
   281  		require.Nil(t, out_struct_pointers.Recursive{}.FieldThree)
   282  		require.NotNil(t, out_struct_pointers.Recursive{}.FieldFour)
   283  	})
   284  }
   285  
   286  func mutateHook(b *ModelBuild) *ModelBuild {
   287  	for _, model := range b.Models {
   288  		for _, field := range model.Fields {
   289  			field.Tag += ` database:"` + model.Name + field.Name + `"`
   290  		}
   291  	}
   292  
   293  	return b
   294  }
   295  
   296  func parseAst(path string) (*ast.Package, error) {
   297  	// test setup to parse the types
   298  	fset := token.NewFileSet()
   299  	pkgs, err := parser.ParseDir(fset, path, nil, parser.AllErrors)
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  	return pkgs["out"], nil
   304  }