github.com/pouriasharifi/gqlgen@v0.7.2/codegen/testserver/generated_test.go (about)

     1  //go:generate rm -f resolver.go
     2  //go:generate gorunpkg github.com/99designs/gqlgen
     3  
     4  package testserver
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"reflect"
    12  	"runtime"
    13  	"sort"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/99designs/gqlgen/graphql/introspection"
    19  
    20  	"github.com/99designs/gqlgen/client"
    21  	"github.com/99designs/gqlgen/graphql"
    22  	"github.com/99designs/gqlgen/handler"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  )
    26  
    27  func TestGeneratedResolversAreValid(t *testing.T) {
    28  	http.Handle("/query", handler.GraphQL(NewExecutableSchema(Config{
    29  		Resolvers: &Resolver{},
    30  	})))
    31  }
    32  
    33  func TestForcedResolverFieldIsPointer(t *testing.T) {
    34  	field, ok := reflect.TypeOf((*ForcedResolverResolver)(nil)).Elem().MethodByName("Field")
    35  	require.True(t, ok)
    36  	require.Equal(t, "*testserver.Circle", field.Type.Out(0).String())
    37  }
    38  
    39  func TestGeneratedServer(t *testing.T) {
    40  	resolvers := &testResolver{tick: make(chan string, 1)}
    41  
    42  	srv := httptest.NewServer(
    43  		handler.GraphQL(
    44  			NewExecutableSchema(Config{Resolvers: resolvers}),
    45  			handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    46  				path, _ := ctx.Value("path").([]int)
    47  				return next(context.WithValue(ctx, "path", append(path, 1)))
    48  			}),
    49  			handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    50  				path, _ := ctx.Value("path").([]int)
    51  				return next(context.WithValue(ctx, "path", append(path, 2)))
    52  			}),
    53  		))
    54  	c := client.New(srv.URL)
    55  
    56  	t.Run("null bubbling", func(t *testing.T) {
    57  		t.Run("when function errors on non required field", func(t *testing.T) {
    58  			var resp struct {
    59  				Valid       string
    60  				ErrorBubble *struct {
    61  					Id                      string
    62  					ErrorOnNonRequiredField *string
    63  				}
    64  			}
    65  			err := c.Post(`query { valid, errorBubble { id, errorOnNonRequiredField } }`, &resp)
    66  
    67  			require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnNonRequiredField"]}]`)
    68  			require.Equal(t, "E1234", resp.ErrorBubble.Id)
    69  			require.Nil(t, resp.ErrorBubble.ErrorOnNonRequiredField)
    70  			require.Equal(t, "Ok", resp.Valid)
    71  		})
    72  
    73  		t.Run("when function errors", func(t *testing.T) {
    74  			var resp struct {
    75  				Valid       string
    76  				ErrorBubble *struct {
    77  					NilOnRequiredField string
    78  				}
    79  			}
    80  			err := c.Post(`query { valid, errorBubble { id, errorOnRequiredField } }`, &resp)
    81  
    82  			require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnRequiredField"]}]`)
    83  			require.Nil(t, resp.ErrorBubble)
    84  			require.Equal(t, "Ok", resp.Valid)
    85  		})
    86  
    87  		t.Run("when user returns null on required field", func(t *testing.T) {
    88  			var resp struct {
    89  				Valid       string
    90  				ErrorBubble *struct {
    91  					NilOnRequiredField string
    92  				}
    93  			}
    94  			err := c.Post(`query { valid, errorBubble { id, nilOnRequiredField } }`, &resp)
    95  
    96  			require.EqualError(t, err, `[{"message":"must not be null","path":["errorBubble","nilOnRequiredField"]}]`)
    97  			require.Nil(t, resp.ErrorBubble)
    98  			require.Equal(t, "Ok", resp.Valid)
    99  		})
   100  
   101  	})
   102  
   103  	t.Run("middleware", func(t *testing.T) {
   104  		var resp struct {
   105  			User struct {
   106  				ID      int
   107  				Friends []struct {
   108  					ID int
   109  				}
   110  			}
   111  		}
   112  
   113  		called := false
   114  		resolvers.userFriends = func(ctx context.Context, obj *User) ([]User, error) {
   115  			assert.Equal(t, []int{1, 2, 1, 2}, ctx.Value("path"))
   116  			called = true
   117  			return []User{}, nil
   118  		}
   119  
   120  		err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp)
   121  
   122  		require.NoError(t, err)
   123  		require.True(t, called)
   124  	})
   125  
   126  	t.Run("subscriptions", func(t *testing.T) {
   127  		t.Run("wont leak goroutines", func(t *testing.T) {
   128  			initialGoroutineCount := runtime.NumGoroutine()
   129  
   130  			sub := c.Websocket(`subscription { updated }`)
   131  
   132  			resolvers.tick <- "message"
   133  
   134  			var msg struct {
   135  				resp struct {
   136  					Updated string
   137  				}
   138  			}
   139  
   140  			err := sub.Next(&msg.resp)
   141  			require.NoError(t, err)
   142  			require.Equal(t, "message", msg.resp.Updated)
   143  			sub.Close()
   144  
   145  			// need a little bit of time for goroutines to settle
   146  			start := time.Now()
   147  			for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
   148  				time.Sleep(5 * time.Millisecond)
   149  			}
   150  
   151  			require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
   152  		})
   153  
   154  		t.Run("will parse init payload", func(t *testing.T) {
   155  			sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{
   156  				"Authorization": "Bearer of the curse",
   157  				"number":        32,
   158  				"strings":       []string{"hello", "world"},
   159  			})
   160  
   161  			var msg struct {
   162  				resp struct {
   163  					InitPayload string
   164  				}
   165  			}
   166  
   167  			err := sub.Next(&msg.resp)
   168  			require.NoError(t, err)
   169  			require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload)
   170  			err = sub.Next(&msg.resp)
   171  			require.NoError(t, err)
   172  			require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload)
   173  			err = sub.Next(&msg.resp)
   174  			require.NoError(t, err)
   175  			require.Equal(t, "number = 32", msg.resp.InitPayload)
   176  			err = sub.Next(&msg.resp)
   177  			require.NoError(t, err)
   178  			require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload)
   179  			sub.Close()
   180  		})
   181  	})
   182  
   183  	t.Run("null args", func(t *testing.T) {
   184  		var resp struct {
   185  			NullableArg *string
   186  		}
   187  		err := c.Post(`query { nullableArg(arg: null) }`, &resp)
   188  		require.Nil(t, err)
   189  		require.Equal(t, "Ok", *resp.NullableArg)
   190  	})
   191  }
   192  
   193  func TestIntrospection(t *testing.T) {
   194  	t.Run("disabled", func(t *testing.T) {
   195  		resolvers := &testResolver{tick: make(chan string, 1)}
   196  
   197  		srv := httptest.NewServer(
   198  			handler.GraphQL(
   199  				NewExecutableSchema(Config{Resolvers: resolvers}),
   200  				handler.IntrospectionEnabled(false),
   201  			),
   202  		)
   203  
   204  		c := client.New(srv.URL)
   205  
   206  		var resp interface{}
   207  		err := c.Post(introspection.Query, &resp)
   208  		require.EqualError(t, err, "[{\"message\":\"introspection disabled\",\"path\":[\"__schema\"]}]")
   209  	})
   210  
   211  	t.Run("enabled by default", func(t *testing.T) {
   212  		resolvers := &testResolver{tick: make(chan string, 1)}
   213  
   214  		srv := httptest.NewServer(
   215  			handler.GraphQL(
   216  				NewExecutableSchema(Config{Resolvers: resolvers}),
   217  			),
   218  		)
   219  
   220  		c := client.New(srv.URL)
   221  
   222  		var resp interface{}
   223  		err := c.Post(introspection.Query, &resp)
   224  		require.NoError(t, err)
   225  
   226  		t.Run("does not return empty deprecation strings", func(t *testing.T) {
   227  			q := `{
   228  			  __type(name:"InnerObject") {
   229  			    fields {
   230  			      name
   231  			      deprecationReason
   232  			    }
   233  			  }
   234  			}`
   235  
   236  			c := client.New(srv.URL)
   237  			var resp struct {
   238  				Type struct {
   239  					Fields []struct {
   240  						Name              string
   241  						DeprecationReason *string
   242  					}
   243  				} `json:"__type"`
   244  			}
   245  			err := c.Post(q, &resp)
   246  			require.NoError(t, err)
   247  
   248  			require.Equal(t, "id", resp.Type.Fields[0].Name)
   249  			require.Nil(t, resp.Type.Fields[0].DeprecationReason)
   250  		})
   251  	})
   252  
   253  	t.Run("disabled by middleware", func(t *testing.T) {
   254  		resolvers := &testResolver{tick: make(chan string, 1)}
   255  
   256  		srv := httptest.NewServer(
   257  			handler.GraphQL(
   258  				NewExecutableSchema(Config{Resolvers: resolvers}),
   259  				handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   260  					graphql.GetRequestContext(ctx).DisableIntrospection = true
   261  
   262  					return next(ctx)
   263  				}),
   264  			),
   265  		)
   266  
   267  		c := client.New(srv.URL)
   268  
   269  		var resp interface{}
   270  		err := c.Post(introspection.Query, &resp)
   271  		require.EqualError(t, err, "[{\"message\":\"introspection disabled\",\"path\":[\"__schema\"]}]")
   272  	})
   273  
   274  }
   275  
   276  var _ graphql.Tracer = (*testTracer)(nil)
   277  
   278  type testTracer struct {
   279  	id     int
   280  	append func(string)
   281  }
   282  
   283  func (tt *testTracer) StartOperationParsing(ctx context.Context) context.Context {
   284  	line := fmt.Sprintf("op:p:start:%d", tt.id)
   285  
   286  	tracerLogs, _ := ctx.Value("tracer").([]string)
   287  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   288  	tt.append(line)
   289  
   290  	return ctx
   291  }
   292  
   293  func (tt *testTracer) EndOperationParsing(ctx context.Context) {
   294  	tt.append(fmt.Sprintf("op:p:end:%d", tt.id))
   295  }
   296  
   297  func (tt *testTracer) StartOperationValidation(ctx context.Context) context.Context {
   298  	line := fmt.Sprintf("op:v:start:%d", tt.id)
   299  
   300  	tracerLogs, _ := ctx.Value("tracer").([]string)
   301  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   302  	tt.append(line)
   303  
   304  	return ctx
   305  }
   306  
   307  func (tt *testTracer) EndOperationValidation(ctx context.Context) {
   308  	tt.append(fmt.Sprintf("op:v:end:%d", tt.id))
   309  }
   310  
   311  func (tt *testTracer) StartOperationExecution(ctx context.Context) context.Context {
   312  	line := fmt.Sprintf("op:e:start:%d", tt.id)
   313  
   314  	tracerLogs, _ := ctx.Value("tracer").([]string)
   315  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   316  	tt.append(line)
   317  
   318  	return ctx
   319  }
   320  
   321  func (tt *testTracer) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
   322  	line := fmt.Sprintf("field'a:e:start:%d:%s", tt.id, field.Name)
   323  
   324  	tracerLogs, _ := ctx.Value("tracer").([]string)
   325  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   326  	tt.append(line)
   327  
   328  	return ctx
   329  }
   330  
   331  func (tt *testTracer) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
   332  	line := fmt.Sprintf("field'b:e:start:%d:%v", tt.id, rc.Path())
   333  
   334  	tracerLogs, _ := ctx.Value("tracer").([]string)
   335  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   336  	tt.append(line)
   337  
   338  	return ctx
   339  }
   340  
   341  func (tt *testTracer) StartFieldChildExecution(ctx context.Context) context.Context {
   342  	line := fmt.Sprintf("field'c:e:start:%d", tt.id)
   343  
   344  	tracerLogs, _ := ctx.Value("tracer").([]string)
   345  	ctx = context.WithValue(ctx, "tracer", append(append([]string{}, tracerLogs...), line))
   346  	tt.append(line)
   347  
   348  	return ctx
   349  }
   350  
   351  func (tt *testTracer) EndFieldExecution(ctx context.Context) {
   352  	tt.append(fmt.Sprintf("field:e:end:%d", tt.id))
   353  }
   354  
   355  func (tt *testTracer) EndOperationExecution(ctx context.Context) {
   356  	tt.append(fmt.Sprintf("op:e:end:%d", tt.id))
   357  }
   358  
   359  var _ graphql.Tracer = (*configurableTracer)(nil)
   360  
   361  type configurableTracer struct {
   362  	StartOperationParsingCallback       func(ctx context.Context) context.Context
   363  	EndOperationParsingCallback         func(ctx context.Context)
   364  	StartOperationValidationCallback    func(ctx context.Context) context.Context
   365  	EndOperationValidationCallback      func(ctx context.Context)
   366  	StartOperationExecutionCallback     func(ctx context.Context) context.Context
   367  	StartFieldExecutionCallback         func(ctx context.Context, field graphql.CollectedField) context.Context
   368  	StartFieldResolverExecutionCallback func(ctx context.Context, rc *graphql.ResolverContext) context.Context
   369  	StartFieldChildExecutionCallback    func(ctx context.Context) context.Context
   370  	EndFieldExecutionCallback           func(ctx context.Context)
   371  	EndOperationExecutionCallback       func(ctx context.Context)
   372  }
   373  
   374  func (ct *configurableTracer) StartOperationParsing(ctx context.Context) context.Context {
   375  	if f := ct.StartOperationParsingCallback; f != nil {
   376  		ctx = f(ctx)
   377  	}
   378  	return ctx
   379  }
   380  
   381  func (ct *configurableTracer) EndOperationParsing(ctx context.Context) {
   382  	if f := ct.EndOperationParsingCallback; f != nil {
   383  		f(ctx)
   384  	}
   385  }
   386  
   387  func (ct *configurableTracer) StartOperationValidation(ctx context.Context) context.Context {
   388  	if f := ct.StartOperationValidationCallback; f != nil {
   389  		ctx = f(ctx)
   390  	}
   391  	return ctx
   392  }
   393  
   394  func (ct *configurableTracer) EndOperationValidation(ctx context.Context) {
   395  	if f := ct.EndOperationValidationCallback; f != nil {
   396  		f(ctx)
   397  	}
   398  }
   399  
   400  func (ct *configurableTracer) StartOperationExecution(ctx context.Context) context.Context {
   401  	if f := ct.StartOperationExecutionCallback; f != nil {
   402  		ctx = f(ctx)
   403  	}
   404  	return ctx
   405  }
   406  
   407  func (ct *configurableTracer) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
   408  	if f := ct.StartFieldExecutionCallback; f != nil {
   409  		ctx = f(ctx, field)
   410  	}
   411  	return ctx
   412  }
   413  
   414  func (ct *configurableTracer) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
   415  	if f := ct.StartFieldResolverExecutionCallback; f != nil {
   416  		ctx = f(ctx, rc)
   417  	}
   418  	return ctx
   419  }
   420  
   421  func (ct *configurableTracer) StartFieldChildExecution(ctx context.Context) context.Context {
   422  	if f := ct.StartFieldChildExecutionCallback; f != nil {
   423  		ctx = f(ctx)
   424  	}
   425  	return ctx
   426  }
   427  
   428  func (ct *configurableTracer) EndFieldExecution(ctx context.Context) {
   429  	if f := ct.EndFieldExecutionCallback; f != nil {
   430  		f(ctx)
   431  	}
   432  }
   433  
   434  func (ct *configurableTracer) EndOperationExecution(ctx context.Context) {
   435  	if f := ct.EndOperationExecutionCallback; f != nil {
   436  		f(ctx)
   437  	}
   438  }
   439  
   440  func TestTracer(t *testing.T) {
   441  	t.Run("called in the correct order", func(t *testing.T) {
   442  		resolvers := &testResolver{tick: make(chan string, 1)}
   443  
   444  		var tracerLog []string
   445  		var mu sync.Mutex
   446  
   447  		srv := httptest.NewServer(
   448  			handler.GraphQL(
   449  				NewExecutableSchema(Config{Resolvers: resolvers}),
   450  				handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   451  					path, _ := ctx.Value("path").([]int)
   452  					return next(context.WithValue(ctx, "path", append(path, 1)))
   453  				}),
   454  				handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   455  					path, _ := ctx.Value("path").([]int)
   456  					return next(context.WithValue(ctx, "path", append(path, 2)))
   457  				}),
   458  				handler.Tracer(&testTracer{
   459  					id: 1,
   460  					append: func(s string) {
   461  						mu.Lock()
   462  						defer mu.Unlock()
   463  						tracerLog = append(tracerLog, s)
   464  					},
   465  				}),
   466  				handler.Tracer(&testTracer{
   467  					id: 2,
   468  					append: func(s string) {
   469  						mu.Lock()
   470  						defer mu.Unlock()
   471  						tracerLog = append(tracerLog, s)
   472  					},
   473  				}),
   474  			))
   475  		defer srv.Close()
   476  		c := client.New(srv.URL)
   477  
   478  		var resp struct {
   479  			User struct {
   480  				ID      int
   481  				Friends []struct {
   482  					ID int
   483  				}
   484  			}
   485  		}
   486  
   487  		called := false
   488  		resolvers.userFriends = func(ctx context.Context, obj *User) ([]User, error) {
   489  			assert.Equal(t, []string{
   490  				"op:p:start:1", "op:p:start:2",
   491  				"op:v:start:1", "op:v:start:2",
   492  				"op:e:start:1", "op:e:start:2",
   493  				"field'a:e:start:1:user", "field'a:e:start:2:user",
   494  				"field'b:e:start:1:[user]", "field'b:e:start:2:[user]",
   495  				"field'c:e:start:1", "field'c:e:start:2",
   496  				"field'a:e:start:1:friends", "field'a:e:start:2:friends",
   497  				"field'b:e:start:1:[user friends]", "field'b:e:start:2:[user friends]",
   498  			}, ctx.Value("tracer"))
   499  			called = true
   500  			return []User{}, nil
   501  		}
   502  
   503  		err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp)
   504  
   505  		require.NoError(t, err)
   506  		require.True(t, called)
   507  		mu.Lock()
   508  		defer mu.Unlock()
   509  		assert.Equal(t, []string{
   510  			"op:p:start:1", "op:p:start:2",
   511  			"op:p:end:2", "op:p:end:1",
   512  
   513  			"op:v:start:1", "op:v:start:2",
   514  			"op:v:end:2", "op:v:end:1",
   515  
   516  			"op:e:start:1", "op:e:start:2",
   517  
   518  			"field'a:e:start:1:user", "field'a:e:start:2:user",
   519  			"field'b:e:start:1:[user]", "field'b:e:start:2:[user]",
   520  			"field'c:e:start:1", "field'c:e:start:2",
   521  			"field'a:e:start:1:id", "field'a:e:start:2:id",
   522  			"field'b:e:start:1:[user id]", "field'b:e:start:2:[user id]",
   523  			"field'c:e:start:1", "field'c:e:start:2",
   524  			"field:e:end:2", "field:e:end:1",
   525  			"field'a:e:start:1:friends", "field'a:e:start:2:friends",
   526  			"field'b:e:start:1:[user friends]", "field'b:e:start:2:[user friends]",
   527  			"field'c:e:start:1", "field'c:e:start:2",
   528  			"field:e:end:2", "field:e:end:1",
   529  			"field:e:end:2", "field:e:end:1",
   530  
   531  			"op:e:end:2", "op:e:end:1",
   532  		}, tracerLog)
   533  	})
   534  
   535  	t.Run("take ctx over from prev step", func(t *testing.T) {
   536  		resolvers := &testResolver{tick: make(chan string, 1)}
   537  
   538  		configurableTracer := &configurableTracer{
   539  			StartOperationParsingCallback: func(ctx context.Context) context.Context {
   540  				return context.WithValue(ctx, "StartOperationParsing", true)
   541  			},
   542  			EndOperationParsingCallback: func(ctx context.Context) {
   543  				assert.NotNil(t, ctx.Value("StartOperationParsing"))
   544  			},
   545  
   546  			StartOperationValidationCallback: func(ctx context.Context) context.Context {
   547  				return context.WithValue(ctx, "StartOperationValidation", true)
   548  			},
   549  			EndOperationValidationCallback: func(ctx context.Context) {
   550  				assert.NotNil(t, ctx.Value("StartOperationParsing"))
   551  				assert.NotNil(t, ctx.Value("StartOperationValidation"))
   552  			},
   553  
   554  			StartOperationExecutionCallback: func(ctx context.Context) context.Context {
   555  				return context.WithValue(ctx, "StartOperationExecution", true)
   556  			},
   557  			StartFieldExecutionCallback: func(ctx context.Context, field graphql.CollectedField) context.Context {
   558  				return context.WithValue(ctx, "StartFieldExecution", true)
   559  			},
   560  			StartFieldResolverExecutionCallback: func(ctx context.Context, rc *graphql.ResolverContext) context.Context {
   561  				return context.WithValue(ctx, "StartFieldResolverExecution", true)
   562  			},
   563  			StartFieldChildExecutionCallback: func(ctx context.Context) context.Context {
   564  				return context.WithValue(ctx, "StartFieldChildExecution", true)
   565  			},
   566  			EndFieldExecutionCallback: func(ctx context.Context) {
   567  				assert.NotNil(t, ctx.Value("StartOperationParsing"))
   568  				assert.NotNil(t, ctx.Value("StartOperationValidation"))
   569  				assert.NotNil(t, ctx.Value("StartOperationExecution"))
   570  				assert.NotNil(t, ctx.Value("StartFieldExecution"))
   571  				assert.NotNil(t, ctx.Value("StartFieldResolverExecution"))
   572  				assert.NotNil(t, ctx.Value("StartFieldChildExecution"))
   573  			},
   574  
   575  			EndOperationExecutionCallback: func(ctx context.Context) {
   576  				assert.NotNil(t, ctx.Value("StartOperationParsing"))
   577  				assert.NotNil(t, ctx.Value("StartOperationValidation"))
   578  				assert.NotNil(t, ctx.Value("StartOperationExecution"))
   579  			},
   580  		}
   581  
   582  		srv := httptest.NewServer(
   583  			handler.GraphQL(
   584  				NewExecutableSchema(Config{Resolvers: resolvers}),
   585  				handler.Tracer(configurableTracer),
   586  			))
   587  		defer srv.Close()
   588  		c := client.New(srv.URL)
   589  
   590  		var resp struct {
   591  			User struct {
   592  				ID      int
   593  				Friends []struct {
   594  					ID int
   595  				}
   596  			}
   597  		}
   598  
   599  		called := false
   600  		resolvers.userFriends = func(ctx context.Context, obj *User) ([]User, error) {
   601  			called = true
   602  			return []User{}, nil
   603  		}
   604  
   605  		err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp)
   606  
   607  		require.NoError(t, err)
   608  		require.True(t, called)
   609  	})
   610  
   611  	t.Run("model methods", func(t *testing.T) {
   612  		srv := httptest.NewServer(
   613  			handler.GraphQL(
   614  				NewExecutableSchema(Config{Resolvers: &testResolver{}}),
   615  			))
   616  		defer srv.Close()
   617  		c := client.New(srv.URL)
   618  		t.Run("without context", func(t *testing.T) {
   619  			var resp struct {
   620  				ModelMethods struct {
   621  					NoContext bool
   622  				}
   623  			}
   624  			err := c.Post(`query { modelMethods{ noContext } }`, &resp)
   625  			require.NoError(t, err)
   626  			require.True(t, resp.ModelMethods.NoContext)
   627  		})
   628  		t.Run("with context", func(t *testing.T) {
   629  			var resp struct {
   630  				ModelMethods struct {
   631  					WithContext bool
   632  				}
   633  			}
   634  			err := c.Post(`query { modelMethods{ withContext } }`, &resp)
   635  			require.NoError(t, err)
   636  			require.True(t, resp.ModelMethods.WithContext)
   637  		})
   638  	})
   639  }
   640  
   641  func TestResponseExtension(t *testing.T) {
   642  	srv := httptest.NewServer(handler.GraphQL(
   643  		NewExecutableSchema(Config{
   644  			Resolvers: &testResolver{},
   645  		}),
   646  		handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   647  			rctx := graphql.GetRequestContext(ctx)
   648  			if err := rctx.RegisterExtension("example", "value"); err != nil {
   649  				panic(err)
   650  			}
   651  			return next(ctx)
   652  		}),
   653  	))
   654  	c := client.New(srv.URL)
   655  
   656  	raw, _ := c.RawPost(`query { valid }`)
   657  	require.Equal(t, raw.Extensions["example"], "value")
   658  }
   659  
   660  type testResolver struct {
   661  	tick        chan string
   662  	userFriends func(ctx context.Context, obj *User) ([]User, error)
   663  }
   664  
   665  func (r *testResolver) ForcedResolver() ForcedResolverResolver {
   666  	return &forcedResolverResolver{nil}
   667  }
   668  
   669  func (r *testResolver) User() UserResolver {
   670  	return &testUserResolver{r}
   671  }
   672  
   673  func (r *testResolver) Query() QueryResolver {
   674  	return &testQueryResolver{}
   675  }
   676  func (r *testResolver) ModelMethods() ModelMethodsResolver {
   677  	return &testModelMethodsResolver{}
   678  }
   679  
   680  type testModelMethodsResolver struct{}
   681  
   682  func (r *testModelMethodsResolver) ResolverField(ctx context.Context, obj *ModelMethods) (bool, error) {
   683  	return true, nil
   684  }
   685  
   686  type testQueryResolver struct{ queryResolver }
   687  
   688  func (r *testQueryResolver) ErrorBubble(ctx context.Context) (*Error, error) {
   689  	return &Error{ID: "E1234"}, nil
   690  }
   691  
   692  func (r *testQueryResolver) Valid(ctx context.Context) (string, error) {
   693  	return "Ok", nil
   694  }
   695  
   696  func (r *testQueryResolver) User(ctx context.Context, id int) (User, error) {
   697  	return User{ID: 1}, nil
   698  }
   699  
   700  func (r *testQueryResolver) NullableArg(ctx context.Context, arg *int) (*string, error) {
   701  	s := "Ok"
   702  	return &s, nil
   703  }
   704  
   705  func (r *testQueryResolver) ModelMethods(ctx context.Context) (*ModelMethods, error) {
   706  	return &ModelMethods{}, nil
   707  }
   708  
   709  func (r *testResolver) Subscription() SubscriptionResolver {
   710  	return &testSubscriptionResolver{r}
   711  }
   712  
   713  type testUserResolver struct{ *testResolver }
   714  
   715  func (r *testResolver) Friends(ctx context.Context, obj *User) ([]User, error) {
   716  	return r.userFriends(ctx, obj)
   717  }
   718  
   719  type testSubscriptionResolver struct{ *testResolver }
   720  
   721  func (r *testSubscriptionResolver) Updated(ctx context.Context) (<-chan string, error) {
   722  	res := make(chan string, 1)
   723  
   724  	go func() {
   725  		for {
   726  			select {
   727  			case t := <-r.tick:
   728  				res <- t
   729  			case <-ctx.Done():
   730  				close(res)
   731  				return
   732  			}
   733  		}
   734  	}()
   735  	return res, nil
   736  }
   737  
   738  func (r *testSubscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) {
   739  	payload := handler.GetInitPayload(ctx)
   740  	channel := make(chan string, len(payload)+1)
   741  
   742  	go func() {
   743  		<-ctx.Done()
   744  		close(channel)
   745  	}()
   746  
   747  	// Test the helper function separately
   748  	auth := payload.Authorization()
   749  	if auth != "" {
   750  		channel <- "AUTH:" + auth
   751  	} else {
   752  		channel <- "AUTH:NONE"
   753  	}
   754  
   755  	// Send them over the channel in alphabetic order
   756  	keys := make([]string, 0, len(payload))
   757  	for key := range payload {
   758  		keys = append(keys, key)
   759  	}
   760  	sort.Strings(keys)
   761  	for _, key := range keys {
   762  		channel <- fmt.Sprintf("%s = %#+v", key, payload[key])
   763  	}
   764  
   765  	return channel, nil
   766  }