git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/graphql/handler/server_test.go (about) 1 package handler_test 2 3 import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "net/url" 8 "testing" 9 10 "git.sr.ht/~sircmpwn/gqlgen/graphql" 11 "git.sr.ht/~sircmpwn/gqlgen/graphql/handler/testserver" 12 "git.sr.ht/~sircmpwn/gqlgen/graphql/handler/transport" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15 "github.com/vektah/gqlparser/v2/ast" 16 "github.com/vektah/gqlparser/v2/gqlerror" 17 "github.com/vektah/gqlparser/v2/parser" 18 ) 19 20 func TestServer(t *testing.T) { 21 srv := testserver.New() 22 srv.AddTransport(&transport.GET{}) 23 24 t.Run("returns an error if no transport matches", func(t *testing.T) { 25 resp := post(srv, "/foo", "application/json") 26 assert.Equal(t, http.StatusBadRequest, resp.Code) 27 assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String()) 28 }) 29 30 t.Run("calls query on executable schema", func(t *testing.T) { 31 resp := get(srv, "/foo?query={name}") 32 assert.Equal(t, http.StatusOK, resp.Code) 33 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 34 }) 35 36 t.Run("mutations are forbidden", func(t *testing.T) { 37 resp := get(srv, "/foo?query=mutation{name}") 38 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 39 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 40 }) 41 42 t.Run("subscriptions are forbidden", func(t *testing.T) { 43 resp := get(srv, "/foo?query=subscription{name}") 44 assert.Equal(t, http.StatusNotAcceptable, resp.Code) 45 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) 46 }) 47 48 t.Run("invokes operation middleware in order", func(t *testing.T) { 49 var calls []string 50 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 51 calls = append(calls, "first") 52 return next(ctx) 53 }) 54 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { 55 calls = append(calls, "second") 56 return next(ctx) 57 }) 58 59 resp := get(srv, "/foo?query={name}") 60 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 61 assert.Equal(t, []string{"first", "second"}, calls) 62 }) 63 64 t.Run("invokes response middleware in order", func(t *testing.T) { 65 var calls []string 66 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 67 calls = append(calls, "first") 68 return next(ctx) 69 }) 70 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 71 calls = append(calls, "second") 72 return next(ctx) 73 }) 74 75 resp := get(srv, "/foo?query={name}") 76 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 77 assert.Equal(t, []string{"first", "second"}, calls) 78 }) 79 80 t.Run("invokes field middleware in order", func(t *testing.T) { 81 var calls []string 82 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 83 calls = append(calls, "first") 84 return next(ctx) 85 }) 86 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 87 calls = append(calls, "second") 88 return next(ctx) 89 }) 90 91 resp := get(srv, "/foo?query={name}") 92 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 93 assert.Equal(t, []string{"first", "second"}, calls) 94 }) 95 96 t.Run("get query parse error in AroundResponses", func(t *testing.T) { 97 var errors1 gqlerror.List 98 var errors2 gqlerror.List 99 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 100 resp := next(ctx) 101 errors1 = graphql.GetErrors(ctx) 102 errors2 = resp.Errors 103 return resp 104 }) 105 106 resp := get(srv, "/foo?query=invalid") 107 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) 108 assert.Equal(t, 1, len(errors1)) 109 assert.Equal(t, 1, len(errors2)) 110 }) 111 112 t.Run("query caching", func(t *testing.T) { 113 ctx := context.Background() 114 cache := &graphql.MapCache{} 115 srv.SetQueryCache(cache) 116 qry := `query Foo {name}` 117 118 t.Run("cache miss populates cache", func(t *testing.T) { 119 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 120 assert.Equal(t, http.StatusOK, resp.Code) 121 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 122 123 cacheDoc, ok := cache.Get(ctx, qry) 124 require.True(t, ok) 125 require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 126 }) 127 128 t.Run("cache hits use document from cache", func(t *testing.T) { 129 doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`}) 130 require.Nil(t, err) 131 cache.Add(ctx, qry, doc) 132 133 resp := get(srv, "/foo?query="+url.QueryEscape(qry)) 134 assert.Equal(t, http.StatusOK, resp.Code) 135 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 136 137 cacheDoc, ok := cache.Get(ctx, qry) 138 require.True(t, ok) 139 require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name) 140 }) 141 }) 142 } 143 144 func TestErrorServer(t *testing.T) { 145 srv := testserver.NewError() 146 srv.AddTransport(&transport.GET{}) 147 148 t.Run("get resolver error in AroundResponses", func(t *testing.T) { 149 var errors1 gqlerror.List 150 var errors2 gqlerror.List 151 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 152 resp := next(ctx) 153 errors1 = graphql.GetErrors(ctx) 154 errors2 = resp.Errors 155 return resp 156 }) 157 158 resp := get(srv, "/foo?query={name}") 159 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 160 assert.Equal(t, 1, len(errors1)) 161 assert.Equal(t, 1, len(errors2)) 162 }) 163 } 164 165 func get(handler http.Handler, target string) *httptest.ResponseRecorder { 166 r := httptest.NewRequest("GET", target, nil) 167 w := httptest.NewRecorder() 168 169 handler.ServeHTTP(w, r) 170 return w 171 } 172 173 func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder { 174 r := httptest.NewRequest("POST", target, nil) 175 r.Header.Set("Content-Type", contentType) 176 w := httptest.NewRecorder() 177 178 handler.ServeHTTP(w, r) 179 return w 180 }