github.com/99designs/gqlgen@v0.17.45/graphql/handler/server_test.go (about) 1 package handler_test 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "net/url" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 "github.com/vektah/gqlparser/v2/ast" 14 "github.com/vektah/gqlparser/v2/gqlerror" 15 "github.com/vektah/gqlparser/v2/parser" 16 17 "github.com/99designs/gqlgen/graphql" 18 "github.com/99designs/gqlgen/graphql/handler/testserver" 19 "github.com/99designs/gqlgen/graphql/handler/transport" 20 ) 21 22 func TestServer(t *testing.T) { 23 srv := testserver.New() 24 srv.AddTransport(&transport.GET{}) 25 26 t.Run("returns an error if no transport matches", func(t *testing.T) { 27 resp := post(srv, "/foo", "application/json") 28 assert.Equal(t, http.StatusBadRequest, resp.Code) 29 assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String()) 30 }) 31 32 t.Run("calls query on executable schema", func(t *testing.T) { 33 resp := get(srv, "/foo?query={name}") 34 assert.Equal(t, http.StatusOK, resp.Code) 35 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 36 }) 37 38 t.Run("mutations are forbidden", func(t *testing.T) { 39 resp := get(srv, "/foo?query=mutation{name}") 40 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 41 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 42 }) 43 44 t.Run("subscriptions are forbidden", func(t *testing.T) { 45 resp := get(srv, "/foo?query=subscription{name}") 46 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 47 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 48 }) 49 50 t.Run("invokes operation middleware in order", func(t *testing.T) { 51 var calls []string 52 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 53 calls = append(calls, "first") 54 return next(ctx) 55 }) 56 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 57 calls = append(calls, "second") 58 return next(ctx) 59 }) 60 61 resp := get(srv, "/foo?query={name}") 62 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 63 assert.Equal(t, []string{"first", "second"}, calls) 64 }) 65 66 t.Run("invokes response middleware in order", func(t *testing.T) { 67 var calls []string 68 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 69 calls = append(calls, "first") 70 return next(ctx) 71 }) 72 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 73 calls = append(calls, "second") 74 return next(ctx) 75 }) 76 77 resp := get(srv, "/foo?query={name}") 78 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 79 assert.Equal(t, []string{"first", "second"}, calls) 80 }) 81 82 t.Run("invokes field middleware in order", func(t *testing.T) { 83 var calls []string 84 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 85 calls = append(calls, "first") 86 return next(ctx) 87 }) 88 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 89 calls = append(calls, "second") 90 return next(ctx) 91 }) 92 93 resp := get(srv, "/foo?query={name}") 94 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 95 assert.Equal(t, []string{"first", "second"}, calls) 96 }) 97 98 t.Run("get query parse error in AroundResponses", func(t *testing.T) { 99 var errors1 gqlerror.List 100 var errors2 gqlerror.List 101 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 102 resp := next(ctx) 103 errors1 = graphql.GetErrors(ctx) 104 errors2 = resp.Errors 105 return resp 106 }) 107 108 resp := get(srv, "/foo?query=invalid") 109 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) 110 assert.Equal(t, 1, len(errors1)) 111 assert.Equal(t, 1, len(errors2)) 112 }) 113 114 t.Run("query caching", func(t *testing.T) { 115 ctx := context.Background() 116 cache := &graphql.MapCache{} 117 srv.SetQueryCache(cache) 118 qry := `query Foo {name}` 119 120 t.Run("cache miss populates cache", func(t *testing.T) { 121 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 122 assert.Equal(t, http.StatusOK, resp.Code) 123 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 124 125 cacheDoc, ok := cache.Get(ctx, qry) 126 require.True(t, ok) 127 require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 128 }) 129 130 t.Run("cache hits use document from cache", func(t *testing.T) { 131 doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`}) 132 require.Nil(t, err) 133 cache.Add(ctx, qry, doc) 134 135 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 136 assert.Equal(t, http.StatusOK, resp.Code) 137 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 138 139 cacheDoc, ok := cache.Get(ctx, qry) 140 require.True(t, ok) 141 require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 142 }) 143 }) 144 } 145 146 func TestErrorServer(t *testing.T) { 147 srv := testserver.NewError() 148 srv.AddTransport(&transport.GET{}) 149 150 t.Run("get resolver error in AroundResponses", func(t *testing.T) { 151 var errors1 gqlerror.List 152 var errors2 gqlerror.List 153 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 154 resp := next(ctx) 155 errors1 = graphql.GetErrors(ctx) 156 errors2 = resp.Errors 157 return resp 158 }) 159 160 resp := get(srv, "/foo?query={name}") 161 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 162 assert.Equal(t, 1, len(errors1)) 163 assert.Equal(t, 1, len(errors2)) 164 }) 165 } 166 167 type panicTransport struct{} 168 169 func (t panicTransport) Supports(r *http.Request) bool { 170 return true 171 } 172 173 func (t panicTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 174 panic(fmt.Errorf("panic in transport")) 175 } 176 177 func TestRecover(t *testing.T) { 178 srv := testserver.New() 179 srv.AddTransport(&panicTransport{}) 180 181 t.Run("recover from panic", func(t *testing.T) { 182 resp := get(srv, "/foo?query={name}") 183 184 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) 185 }) 186 } 187 188 func get(handler http.Handler, target string) *httptest.ResponseRecorder { 189 r := httptest.NewRequest("GET", target, nil) 190 w := httptest.NewRecorder() 191 192 handler.ServeHTTP(w, r) 193 return w 194 } 195 196 func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder { 197 r := httptest.NewRequest("POST", target, nil) 198 r.Header.Set("Content-Type", contentType) 199 w := httptest.NewRecorder() 200 201 handler.ServeHTTP(w, r) 202 return w 203 }