github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/server_test.go (about) 1 package handler_test 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "net/url" 9 "testing" 10 11 "github.com/99designs/gqlgen/graphql" 12 "github.com/99designs/gqlgen/graphql/handler/testserver" 13 "github.com/99designs/gqlgen/graphql/handler/transport" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 "github.com/vektah/gqlparser/v2/ast" 17 "github.com/vektah/gqlparser/v2/gqlerror" 18 "github.com/vektah/gqlparser/v2/parser" 19 ) 20 21 func TestServer(t *testing.T) { 22 srv := testserver.New() 23 srv.AddTransport(&transport.GET{}) 24 25 t.Run("returns an error if no transport matches", func(t *testing.T) { 26 resp := post(srv, "/foo", "application/json") 27 assert.Equal(t, http.StatusBadRequest, resp.Code) 28 assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String()) 29 }) 30 31 t.Run("calls query on executable schema", func(t *testing.T) { 32 resp := get(srv, "/foo?query={name}") 33 assert.Equal(t, http.StatusOK, resp.Code) 34 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 35 }) 36 37 t.Run("mutations are forbidden", func(t *testing.T) { 38 resp := get(srv, "/foo?query=mutation{name}") 39 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 40 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 41 }) 42 43 t.Run("subscriptions are forbidden", func(t *testing.T) { 44 resp := get(srv, "/foo?query=subscription{name}") 45 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 46 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 47 }) 48 49 t.Run("invokes operation middleware in order", func(t *testing.T) { 50 var calls []string 51 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 52 calls = append(calls, "first") 53 return next(ctx) 54 }) 55 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 56 calls = append(calls, "second") 57 return next(ctx) 58 }) 59 60 resp := get(srv, "/foo?query={name}") 61 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 62 assert.Equal(t, []string{"first", "second"}, calls) 63 }) 64 65 t.Run("invokes response middleware in order", func(t *testing.T) { 66 var calls []string 67 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 68 calls = append(calls, "first") 69 return next(ctx) 70 }) 71 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 72 calls = append(calls, "second") 73 return next(ctx) 74 }) 75 76 resp := get(srv, "/foo?query={name}") 77 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 78 assert.Equal(t, []string{"first", "second"}, calls) 79 }) 80 81 t.Run("invokes field middleware in order", func(t *testing.T) { 82 var calls []string 83 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 84 calls = append(calls, "first") 85 return next(ctx) 86 }) 87 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 88 calls = append(calls, "second") 89 return next(ctx) 90 }) 91 92 resp := get(srv, "/foo?query={name}") 93 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 94 assert.Equal(t, []string{"first", "second"}, calls) 95 }) 96 97 t.Run("get query parse error in AroundResponses", func(t *testing.T) { 98 var errors1 gqlerror.List 99 var errors2 gqlerror.List 100 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 101 resp := next(ctx) 102 errors1 = graphql.GetErrors(ctx) 103 errors2 = resp.Errors 104 return resp 105 }) 106 107 resp := get(srv, "/foo?query=invalid") 108 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) 109 assert.Equal(t, 1, len(errors1)) 110 assert.Equal(t, 1, len(errors2)) 111 }) 112 113 t.Run("query caching", func(t *testing.T) { 114 ctx := context.Background() 115 cache := &graphql.MapCache{} 116 srv.SetQueryCache(cache) 117 qry := `query Foo {name}` 118 119 t.Run("cache miss populates cache", func(t *testing.T) { 120 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 121 assert.Equal(t, http.StatusOK, resp.Code) 122 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 123 124 cacheDoc, ok := cache.Get(ctx, qry) 125 require.True(t, ok) 126 require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 127 }) 128 129 t.Run("cache hits use document from cache", func(t *testing.T) { 130 doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`}) 131 require.Nil(t, err) 132 cache.Add(ctx, qry, doc) 133 134 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 135 assert.Equal(t, http.StatusOK, resp.Code) 136 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 137 138 cacheDoc, ok := cache.Get(ctx, qry) 139 require.True(t, ok) 140 require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 141 }) 142 }) 143 } 144 145 func TestErrorServer(t *testing.T) { 146 srv := testserver.NewError() 147 srv.AddTransport(&transport.GET{}) 148 149 t.Run("get resolver error in AroundResponses", func(t *testing.T) { 150 var errors1 gqlerror.List 151 var errors2 gqlerror.List 152 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 153 resp := next(ctx) 154 errors1 = graphql.GetErrors(ctx) 155 errors2 = resp.Errors 156 return resp 157 }) 158 159 resp := get(srv, "/foo?query={name}") 160 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 161 assert.Equal(t, 1, len(errors1)) 162 assert.Equal(t, 1, len(errors2)) 163 }) 164 } 165 166 type panicTransport struct{} 167 168 func (t panicTransport) Supports(r *http.Request) bool { 169 return true 170 } 171 172 func (h panicTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 173 panic(fmt.Errorf("panic in transport")) 174 } 175 176 func TestRecover(t *testing.T) { 177 srv := testserver.New() 178 srv.AddTransport(&panicTransport{}) 179 180 t.Run("recover from panic", func(t *testing.T) { 181 resp := get(srv, "/foo?query={name}") 182 183 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) 184 }) 185 } 186 187 func get(handler http.Handler, target string) *httptest.ResponseRecorder { 188 r := httptest.NewRequest("GET", target, nil) 189 w := httptest.NewRecorder() 190 191 handler.ServeHTTP(w, r) 192 return w 193 } 194 195 func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder { 196 r := httptest.NewRequest("POST", target, nil) 197 r.Header.Set("Content-Type", contentType) 198 w := httptest.NewRecorder() 199 200 handler.ServeHTTP(w, r) 201 return w 202 }