github.com/vetcher/gqlgen@v0.6.0/codegen/testserver/generated_test.go (about)

     1  //go:generate rm -f resolver.go
     2  //go:generate gorunpkg github.com/99designs/gqlgen
     3  
     4  package testserver
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"reflect"
    12  	"runtime"
    13  	"sort"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  
    19  	"github.com/99designs/gqlgen/graphql"
    20  
    21  	"github.com/99designs/gqlgen/client"
    22  	"github.com/99designs/gqlgen/handler"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestGeneratedResolversAreValid(t *testing.T) {
    27  	http.Handle("/query", handler.GraphQL(NewExecutableSchema(Config{
    28  		Resolvers: &Resolver{},
    29  	})))
    30  }
    31  
    32  func TestForcedResolverFieldIsPointer(t *testing.T) {
    33  	field, ok := reflect.TypeOf((*ForcedResolverResolver)(nil)).Elem().MethodByName("Field")
    34  	require.True(t, ok)
    35  	require.Equal(t, "*testserver.Circle", field.Type.Out(0).String())
    36  }
    37  
    38  func TestGeneratedServer(t *testing.T) {
    39  	resolvers := &testResolver{tick: make(chan string, 1)}
    40  
    41  	srv := httptest.NewServer(
    42  		handler.GraphQL(
    43  			NewExecutableSchema(Config{Resolvers: resolvers}),
    44  			handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    45  				path, _ := ctx.Value("path").([]int)
    46  				return next(context.WithValue(ctx, "path", append(path, 1)))
    47  			}),
    48  			handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    49  				path, _ := ctx.Value("path").([]int)
    50  				return next(context.WithValue(ctx, "path", append(path, 2)))
    51  			}),
    52  		))
    53  	c := client.New(srv.URL)
    54  
    55  	t.Run("null bubbling", func(t *testing.T) {
    56  		t.Run("when function errors on non required field", func(t *testing.T) {
    57  			var resp struct {
    58  				Valid       string
    59  				ErrorBubble *struct {
    60  					Id                      string
    61  					ErrorOnNonRequiredField *string
    62  				}
    63  			}
    64  			err := c.Post(`query { valid, errorBubble { id, errorOnNonRequiredField } }`, &resp)
    65  
    66  			require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnNonRequiredField"]}]`)
    67  			require.Equal(t, "E1234", resp.ErrorBubble.Id)
    68  			require.Nil(t, resp.ErrorBubble.ErrorOnNonRequiredField)
    69  			require.Equal(t, "Ok", resp.Valid)
    70  		})
    71  
    72  		t.Run("when function errors", func(t *testing.T) {
    73  			var resp struct {
    74  				Valid       string
    75  				ErrorBubble *struct {
    76  					NilOnRequiredField string
    77  				}
    78  			}
    79  			err := c.Post(`query { valid, errorBubble { id, errorOnRequiredField } }`, &resp)
    80  
    81  			require.EqualError(t, err, `[{"message":"boom","path":["errorBubble","errorOnRequiredField"]}]`)
    82  			require.Nil(t, resp.ErrorBubble)
    83  			require.Equal(t, "Ok", resp.Valid)
    84  		})
    85  
    86  		t.Run("when user returns null on required field", func(t *testing.T) {
    87  			var resp struct {
    88  				Valid       string
    89  				ErrorBubble *struct {
    90  					NilOnRequiredField string
    91  				}
    92  			}
    93  			err := c.Post(`query { valid, errorBubble { id, nilOnRequiredField } }`, &resp)
    94  
    95  			require.EqualError(t, err, `[{"message":"must not be null","path":["errorBubble","nilOnRequiredField"]}]`)
    96  			require.Nil(t, resp.ErrorBubble)
    97  			require.Equal(t, "Ok", resp.Valid)
    98  		})
    99  
   100  	})
   101  
   102  	t.Run("middleware", func(t *testing.T) {
   103  		var resp struct {
   104  			User struct {
   105  				ID      int
   106  				Friends []struct {
   107  					ID int
   108  				}
   109  			}
   110  		}
   111  
   112  		called := false
   113  		resolvers.userFriends = func(ctx context.Context, obj *User) ([]User, error) {
   114  			assert.Equal(t, []int{1, 2, 1, 2}, ctx.Value("path"))
   115  			called = true
   116  			return []User{}, nil
   117  		}
   118  
   119  		err := c.Post(`query { user(id: 1) { id, friends { id } } }`, &resp)
   120  
   121  		require.NoError(t, err)
   122  		require.True(t, called)
   123  	})
   124  
   125  	t.Run("subscriptions", func(t *testing.T) {
   126  		t.Run("wont leak goroutines", func(t *testing.T) {
   127  			initialGoroutineCount := runtime.NumGoroutine()
   128  
   129  			sub := c.Websocket(`subscription { updated }`)
   130  
   131  			resolvers.tick <- "message"
   132  
   133  			var msg struct {
   134  				resp struct {
   135  					Updated string
   136  				}
   137  			}
   138  
   139  			err := sub.Next(&msg.resp)
   140  			require.NoError(t, err)
   141  			require.Equal(t, "message", msg.resp.Updated)
   142  			sub.Close()
   143  
   144  			// need a little bit of time for goroutines to settle
   145  			time.Sleep(200 * time.Millisecond)
   146  
   147  			require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
   148  		})
   149  
   150  		t.Run("will parse init payload", func(t *testing.T) {
   151  			sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{
   152  				"Authorization": "Bearer of the curse",
   153  				"number":        32,
   154  				"strings":       []string{"hello", "world"},
   155  			})
   156  
   157  			var msg struct {
   158  				resp struct {
   159  					InitPayload string
   160  				}
   161  			}
   162  
   163  			err := sub.Next(&msg.resp)
   164  			require.NoError(t, err)
   165  			require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload)
   166  			err = sub.Next(&msg.resp)
   167  			require.NoError(t, err)
   168  			require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload)
   169  			err = sub.Next(&msg.resp)
   170  			require.NoError(t, err)
   171  			require.Equal(t, "number = 32", msg.resp.InitPayload)
   172  			err = sub.Next(&msg.resp)
   173  			require.NoError(t, err)
   174  			require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload)
   175  			sub.Close()
   176  		})
   177  	})
   178  
   179  	t.Run("null args", func(t *testing.T) {
   180  		var resp struct {
   181  			NullableArg *string
   182  		}
   183  		err := c.Post(`query { nullableArg(arg: null) }`, &resp)
   184  		require.Nil(t, err)
   185  		require.Equal(t, "Ok", *resp.NullableArg)
   186  	})
   187  }
   188  
   189  func TestResponseExtension(t *testing.T) {
   190  	srv := httptest.NewServer(handler.GraphQL(
   191  		NewExecutableSchema(Config{
   192  			Resolvers: &testResolver{},
   193  		}),
   194  		handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   195  			rctx := graphql.GetRequestContext(ctx)
   196  			if err := rctx.RegisterExtension("example", "value"); err != nil {
   197  				panic(err)
   198  			}
   199  			return next(ctx)
   200  		}),
   201  	))
   202  	c := client.New(srv.URL)
   203  
   204  	raw, _ := c.RawPost(`query { valid }`)
   205  	require.Equal(t, raw.Extensions["example"], "value")
   206  }
   207  
   208  type testResolver struct {
   209  	tick        chan string
   210  	userFriends func(ctx context.Context, obj *User) ([]User, error)
   211  }
   212  
   213  func (r *testResolver) ForcedResolver() ForcedResolverResolver {
   214  	return &forcedResolverResolver{nil}
   215  }
   216  
   217  func (r *testResolver) User() UserResolver {
   218  	return &testUserResolver{r}
   219  }
   220  
   221  func (r *testResolver) Query() QueryResolver {
   222  	return &testQueryResolver{}
   223  }
   224  
   225  type testQueryResolver struct{ queryResolver }
   226  
   227  func (r *testQueryResolver) ErrorBubble(ctx context.Context) (*Error, error) {
   228  	return &Error{ID: "E1234"}, nil
   229  }
   230  
   231  func (r *testQueryResolver) Valid(ctx context.Context) (string, error) {
   232  	return "Ok", nil
   233  }
   234  
   235  func (r *testQueryResolver) User(ctx context.Context, id int) (User, error) {
   236  	return User{ID: 1}, nil
   237  }
   238  
   239  func (r *testQueryResolver) NullableArg(ctx context.Context, arg *int) (*string, error) {
   240  	s := "Ok"
   241  	return &s, nil
   242  }
   243  
   244  func (r *testResolver) Subscription() SubscriptionResolver {
   245  	return &testSubscriptionResolver{r}
   246  }
   247  
   248  type testUserResolver struct{ *testResolver }
   249  
   250  func (r *testResolver) Friends(ctx context.Context, obj *User) ([]User, error) {
   251  	return r.userFriends(ctx, obj)
   252  }
   253  
   254  type testSubscriptionResolver struct{ *testResolver }
   255  
   256  func (r *testSubscriptionResolver) Updated(ctx context.Context) (<-chan string, error) {
   257  	res := make(chan string, 1)
   258  
   259  	go func() {
   260  		for {
   261  			select {
   262  			case t := <-r.tick:
   263  				res <- t
   264  			case <-ctx.Done():
   265  				close(res)
   266  				return
   267  			}
   268  		}
   269  	}()
   270  	return res, nil
   271  }
   272  
   273  func (r *testSubscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) {
   274  	payload := handler.GetInitPayload(ctx)
   275  	channel := make(chan string, len(payload)+1)
   276  
   277  	go func() {
   278  		<-ctx.Done()
   279  		close(channel)
   280  	}()
   281  
   282  	// Test the helper function separately
   283  	auth := payload.Authorization()
   284  	if auth != "" {
   285  		channel <- "AUTH:" + auth
   286  	} else {
   287  		channel <- "AUTH:NONE"
   288  	}
   289  
   290  	// Send them over the channel in alphabetic order
   291  	keys := make([]string, 0, len(payload))
   292  	for key := range payload {
   293  		keys = append(keys, key)
   294  	}
   295  	sort.Strings(keys)
   296  	for _, key := range keys {
   297  		channel <- fmt.Sprintf("%s = %#+v", key, payload[key])
   298  	}
   299  
   300  	return channel, nil
   301  }