git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/graphql/handler/server_test.go (about)

     1  package handler_test
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"testing"
     9  
    10  	"git.sr.ht/~sircmpwn/gqlgen/graphql"
    11  	"git.sr.ht/~sircmpwn/gqlgen/graphql/handler/testserver"
    12  	"git.sr.ht/~sircmpwn/gqlgen/graphql/handler/transport"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  	"github.com/vektah/gqlparser/v2/ast"
    16  	"github.com/vektah/gqlparser/v2/gqlerror"
    17  	"github.com/vektah/gqlparser/v2/parser"
    18  )
    19  
    20  func TestServer(t *testing.T) {
    21  	srv := testserver.New()
    22  	srv.AddTransport(&transport.GET{})
    23  
    24  	t.Run("returns an error if no transport matches", func(t *testing.T) {
    25  		resp := post(srv, "/foo", "application/json")
    26  		assert.Equal(t, http.StatusBadRequest, resp.Code)
    27  		assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String())
    28  	})
    29  
    30  	t.Run("calls query on executable schema", func(t *testing.T) {
    31  		resp := get(srv, "/foo?query={name}")
    32  		assert.Equal(t, http.StatusOK, resp.Code)
    33  		assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    34  	})
    35  
    36  	t.Run("mutations are forbidden", func(t *testing.T) {
    37  		resp := get(srv, "/foo?query=mutation{name}")
    38  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    39  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    40  	})
    41  
    42  	t.Run("subscriptions are forbidden", func(t *testing.T) {
    43  		resp := get(srv, "/foo?query=subscription{name}")
    44  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    45  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    46  	})
    47  
    48  	t.Run("invokes operation middleware in order", func(t *testing.T) {
    49  		var calls []string
    50  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    51  			calls = append(calls, "first")
    52  			return next(ctx)
    53  		})
    54  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    55  			calls = append(calls, "second")
    56  			return next(ctx)
    57  		})
    58  
    59  		resp := get(srv, "/foo?query={name}")
    60  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    61  		assert.Equal(t, []string{"first", "second"}, calls)
    62  	})
    63  
    64  	t.Run("invokes response middleware in order", func(t *testing.T) {
    65  		var calls []string
    66  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    67  			calls = append(calls, "first")
    68  			return next(ctx)
    69  		})
    70  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    71  			calls = append(calls, "second")
    72  			return next(ctx)
    73  		})
    74  
    75  		resp := get(srv, "/foo?query={name}")
    76  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    77  		assert.Equal(t, []string{"first", "second"}, calls)
    78  	})
    79  
    80  	t.Run("invokes field middleware in order", func(t *testing.T) {
    81  		var calls []string
    82  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    83  			calls = append(calls, "first")
    84  			return next(ctx)
    85  		})
    86  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    87  			calls = append(calls, "second")
    88  			return next(ctx)
    89  		})
    90  
    91  		resp := get(srv, "/foo?query={name}")
    92  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    93  		assert.Equal(t, []string{"first", "second"}, calls)
    94  	})
    95  
    96  	t.Run("get query parse error in AroundResponses", func(t *testing.T) {
    97  		var errors1 gqlerror.List
    98  		var errors2 gqlerror.List
    99  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   100  			resp := next(ctx)
   101  			errors1 = graphql.GetErrors(ctx)
   102  			errors2 = resp.Errors
   103  			return resp
   104  		})
   105  
   106  		resp := get(srv, "/foo?query=invalid")
   107  		assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
   108  		assert.Equal(t, 1, len(errors1))
   109  		assert.Equal(t, 1, len(errors2))
   110  	})
   111  
   112  	t.Run("query caching", func(t *testing.T) {
   113  		ctx := context.Background()
   114  		cache := &graphql.MapCache{}
   115  		srv.SetQueryCache(cache)
   116  		qry := `query Foo {name}`
   117  
   118  		t.Run("cache miss populates cache", func(t *testing.T) {
   119  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   120  			assert.Equal(t, http.StatusOK, resp.Code)
   121  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   122  
   123  			cacheDoc, ok := cache.Get(ctx, qry)
   124  			require.True(t, ok)
   125  			require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   126  		})
   127  
   128  		t.Run("cache hits use document from cache", func(t *testing.T) {
   129  			doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`})
   130  			require.Nil(t, err)
   131  			cache.Add(ctx, qry, doc)
   132  
   133  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   134  			assert.Equal(t, http.StatusOK, resp.Code)
   135  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   136  
   137  			cacheDoc, ok := cache.Get(ctx, qry)
   138  			require.True(t, ok)
   139  			require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   140  		})
   141  	})
   142  }
   143  
   144  func TestErrorServer(t *testing.T) {
   145  	srv := testserver.NewError()
   146  	srv.AddTransport(&transport.GET{})
   147  
   148  	t.Run("get resolver error in AroundResponses", func(t *testing.T) {
   149  		var errors1 gqlerror.List
   150  		var errors2 gqlerror.List
   151  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   152  			resp := next(ctx)
   153  			errors1 = graphql.GetErrors(ctx)
   154  			errors2 = resp.Errors
   155  			return resp
   156  		})
   157  
   158  		resp := get(srv, "/foo?query={name}")
   159  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
   160  		assert.Equal(t, 1, len(errors1))
   161  		assert.Equal(t, 1, len(errors2))
   162  	})
   163  }
   164  
   165  func get(handler http.Handler, target string) *httptest.ResponseRecorder {
   166  	r := httptest.NewRequest("GET", target, nil)
   167  	w := httptest.NewRecorder()
   168  
   169  	handler.ServeHTTP(w, r)
   170  	return w
   171  }
   172  
   173  func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder {
   174  	r := httptest.NewRequest("POST", target, nil)
   175  	r.Header.Set("Content-Type", contentType)
   176  	w := httptest.NewRecorder()
   177  
   178  	handler.ServeHTTP(w, r)
   179  	return w
   180  }