github.com/vetcher/gqlgen@v0.6.0/codegen/testserver/generated_test.go (about) 1 //go:generate rm -f resolver.go 2 //go:generate gorunpkg github.com/99designs/gqlgen 3 4 package testserver 5 6 import ( 7 "context" 8 "fmt" 9 "net/http" 10 "net/http/httptest" 11 "reflect" 12 "runtime" 13 "sort" 14 "testing" 15 "time" 16 17 "github.com/stretchr/testify/assert" 18 19 "github.com/99designs/gqlgen/graphql" 20 21 "github.com/99designs/gqlgen/client" 22 "github.com/99designs/gqlgen/handler" 23 "github.com/stretchr/testify/require" 24 ) 25 26 func TestGeneratedResolversAreValid(t *testing.T) { 27 http.Handle("/query", handler.GraphQL(NewExecutableSchema(Config{ 28 Resolvers: &Resolver{}, 29 }))) 30 } 31 32 func TestForcedResolverFieldIsPointer(t *testing.T) { 33 field, ok := reflect.TypeOf((*ForcedResolverResolver)(nil)).Elem().MethodByName("Field") 34 require.True(t, ok) 35 require.Equal(t, "*testserver.Circle", field.Type.Out(0).String()) 36 } 37 38 func TestGeneratedServer(t *testing.T) { 39 resolvers := &testResolver{tick: make(chan string, 1)} 40 41 srv := httptest.NewServer( 42 handler.GraphQL( 43 NewExecutableSchema(Config{Resolvers: resolvers}), 44 handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 45 path, _ := ctx.Value("path").([]int) 46 return next(context.WithValue(ctx, "path", append(path, 1))) 47 }), 48 handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 49 path, _ := ctx.Value("path").([]int) 50 return next(context.WithValue(ctx, "path", append(path, 2))) 51 }), 52 )) 53 c := client.New(srv.URL) 54 55 t.Run("null bubbling", func(t *testing.T) { 56 t.Run("when function errors on non required field", func(t *testing.T) { 57 var resp struct { 58 Valid string 59 ErrorBubble *struct { 60 Id string 61 ErrorOnNonRequiredField *string 62 } 63 } 64 err := c.Post(`query { valid, errorBubble { id, errorOnNonRequiredField } }`, &resp) 65 66 require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnNonRequiredField"]}]`) 67 require.Equal(t, "E1234", resp.ErrorBubble.Id) 68 require.Nil(t, resp.ErrorBubble.ErrorOnNonRequiredField) 69 require.Equal(t, "Ok", resp.Valid) 70 }) 71 72 t.Run("when function errors", func(t *testing.T) { 73 var resp struct { 74 Valid string 75 ErrorBubble *struct { 76 NilOnRequiredField string 77 } 78 } 79 err := c.Post(`query { valid, errorBubble { id, errorOnRequiredField } }`, &resp) 80 81 require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnRequiredField"]}]`) 82 require.Nil(t, resp.ErrorBubble) 83 require.Equal(t, "Ok", resp.Valid) 84 }) 85 86 t.Run("when user returns null on required field", func(t *testing.T) { 87 var resp struct { 88 Valid string 89 ErrorBubble *struct { 90 NilOnRequiredField string 91 } 92 } 93 err := c.Post(`query { valid, errorBubble { id, nilOnRequiredField } }`, &resp) 94 95 require.EqualError(t, err, `[{"message":"must not be null","path":["errorBubble","nilOnRequiredField"]}]`) 96 require.Nil(t, resp.ErrorBubble) 97 require.Equal(t, "Ok", resp.Valid) 98 }) 99 100 }) 101 102 t.Run("middleware", func(t *testing.T) { 103 var resp struct { 104 User struct { 105 ID int 106 Friends []struct { 107 ID int 108 } 109 } 110 } 111 112 called := false 113 resolvers.userFriends = func(ctx context.Context, obj *User) ([]User, error) { 114 assert.Equal(t, []int{1, 2, 1, 2}, ctx.Value("path")) 115 called = true 116 return []User{}, nil 117 } 118 119 err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp) 120 121 require.NoError(t, err) 122 require.True(t, called) 123 }) 124 125 t.Run("subscriptions", func(t *testing.T) { 126 t.Run("wont leak goroutines", func(t *testing.T) { 127 initialGoroutineCount := runtime.NumGoroutine() 128 129 sub := c.Websocket(`subscription { updated }`) 130 131 resolvers.tick <- "message" 132 133 var msg struct { 134 resp struct { 135 Updated string 136 } 137 } 138 139 err := sub.Next(&msg.resp) 140 require.NoError(t, err) 141 require.Equal(t, "message", msg.resp.Updated) 142 sub.Close() 143 144 // need a little bit of time for goroutines to settle 145 time.Sleep(200 * time.Millisecond) 146 147 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 148 }) 149 150 t.Run("will parse init payload", func(t *testing.T) { 151 sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{ 152 "Authorization": "Bearer of the curse", 153 "number": 32, 154 "strings": []string{"hello", "world"}, 155 }) 156 157 var msg struct { 158 resp struct { 159 InitPayload string 160 } 161 } 162 163 err := sub.Next(&msg.resp) 164 require.NoError(t, err) 165 require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload) 166 err = sub.Next(&msg.resp) 167 require.NoError(t, err) 168 require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload) 169 err = sub.Next(&msg.resp) 170 require.NoError(t, err) 171 require.Equal(t, "number = 32", msg.resp.InitPayload) 172 err = sub.Next(&msg.resp) 173 require.NoError(t, err) 174 require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) 175 sub.Close() 176 }) 177 }) 178 179 t.Run("null args", func(t *testing.T) { 180 var resp struct { 181 NullableArg *string 182 } 183 err := c.Post(`query { nullableArg(arg: null) }`, &resp) 184 require.Nil(t, err) 185 require.Equal(t, "Ok", *resp.NullableArg) 186 }) 187 } 188 189 func TestResponseExtension(t *testing.T) { 190 srv := httptest.NewServer(handler.GraphQL( 191 NewExecutableSchema(Config{ 192 Resolvers: &testResolver{}, 193 }), 194 handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { 195 rctx := graphql.GetRequestContext(ctx) 196 if err := rctx.RegisterExtension("example", "value"); err != nil { 197 panic(err) 198 } 199 return next(ctx) 200 }), 201 )) 202 c := client.New(srv.URL) 203 204 raw, _ := c.RawPost(`query { valid }`) 205 require.Equal(t, raw.Extensions["example"], "value") 206 } 207 208 type testResolver struct { 209 tick chan string 210 userFriends func(ctx context.Context, obj *User) ([]User, error) 211 } 212 213 func (r *testResolver) ForcedResolver() ForcedResolverResolver { 214 return &forcedResolverResolver{nil} 215 } 216 217 func (r *testResolver) User() UserResolver { 218 return &testUserResolver{r} 219 } 220 221 func (r *testResolver) Query() QueryResolver { 222 return &testQueryResolver{} 223 } 224 225 type testQueryResolver struct{ queryResolver } 226 227 func (r *testQueryResolver) ErrorBubble(ctx context.Context) (*Error, error) { 228 return &Error{ID: "E1234"}, nil 229 } 230 231 func (r *testQueryResolver) Valid(ctx context.Context) (string, error) { 232 return "Ok", nil 233 } 234 235 func (r *testQueryResolver) User(ctx context.Context, id int) (User, error) { 236 return User{ID: 1}, nil 237 } 238 239 func (r *testQueryResolver) NullableArg(ctx context.Context, arg *int) (*string, error) { 240 s := "Ok" 241 return &s, nil 242 } 243 244 func (r *testResolver) Subscription() SubscriptionResolver { 245 return &testSubscriptionResolver{r} 246 } 247 248 type testUserResolver struct{ *testResolver } 249 250 func (r *testResolver) Friends(ctx context.Context, obj *User) ([]User, error) { 251 return r.userFriends(ctx, obj) 252 } 253 254 type testSubscriptionResolver struct{ *testResolver } 255 256 func (r *testSubscriptionResolver) Updated(ctx context.Context) (<-chan string, error) { 257 res := make(chan string, 1) 258 259 go func() { 260 for { 261 select { 262 case t := <-r.tick: 263 res <- t 264 case <-ctx.Done(): 265 close(res) 266 return 267 } 268 } 269 }() 270 return res, nil 271 } 272 273 func (r *testSubscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) { 274 payload := handler.GetInitPayload(ctx) 275 channel := make(chan string, len(payload)+1) 276 277 go func() { 278 <-ctx.Done() 279 close(channel) 280 }() 281 282 // Test the helper function separately 283 auth := payload.Authorization() 284 if auth != "" { 285 channel <- "AUTH:" + auth 286 } else { 287 channel <- "AUTH:NONE" 288 } 289 290 // Send them over the channel in alphabetic order 291 keys := make([]string, 0, len(payload)) 292 for key := range payload { 293 keys = append(keys, key) 294 } 295 sort.Strings(keys) 296 for _, key := range keys { 297 channel <- fmt.Sprintf("%s = %#+v", key, payload[key]) 298 } 299 300 return channel, nil 301 }