github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/server_test.go (about)

     1  package handler_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"testing"
    10  
    11  	"github.com/99designs/gqlgen/graphql"
    12  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    13  	"github.com/99designs/gqlgen/graphql/handler/transport"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  	"github.com/vektah/gqlparser/v2/ast"
    17  	"github.com/vektah/gqlparser/v2/gqlerror"
    18  	"github.com/vektah/gqlparser/v2/parser"
    19  )
    20  
    21  func TestServer(t *testing.T) {
    22  	srv := testserver.New()
    23  	srv.AddTransport(&transport.GET{})
    24  
    25  	t.Run("returns an error if no transport matches", func(t *testing.T) {
    26  		resp := post(srv, "/foo", "application/json")
    27  		assert.Equal(t, http.StatusBadRequest, resp.Code)
    28  		assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String())
    29  	})
    30  
    31  	t.Run("calls query on executable schema", func(t *testing.T) {
    32  		resp := get(srv, "/foo?query={name}")
    33  		assert.Equal(t, http.StatusOK, resp.Code)
    34  		assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    35  	})
    36  
    37  	t.Run("mutations are forbidden", func(t *testing.T) {
    38  		resp := get(srv, "/foo?query=mutation{name}")
    39  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    40  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    41  	})
    42  
    43  	t.Run("subscriptions are forbidden", func(t *testing.T) {
    44  		resp := get(srv, "/foo?query=subscription{name}")
    45  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    46  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    47  	})
    48  
    49  	t.Run("invokes operation middleware in order", func(t *testing.T) {
    50  		var calls []string
    51  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    52  			calls = append(calls, "first")
    53  			return next(ctx)
    54  		})
    55  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    56  			calls = append(calls, "second")
    57  			return next(ctx)
    58  		})
    59  
    60  		resp := get(srv, "/foo?query={name}")
    61  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    62  		assert.Equal(t, []string{"first", "second"}, calls)
    63  	})
    64  
    65  	t.Run("invokes response middleware in order", func(t *testing.T) {
    66  		var calls []string
    67  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    68  			calls = append(calls, "first")
    69  			return next(ctx)
    70  		})
    71  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    72  			calls = append(calls, "second")
    73  			return next(ctx)
    74  		})
    75  
    76  		resp := get(srv, "/foo?query={name}")
    77  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    78  		assert.Equal(t, []string{"first", "second"}, calls)
    79  	})
    80  
    81  	t.Run("invokes field middleware in order", func(t *testing.T) {
    82  		var calls []string
    83  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    84  			calls = append(calls, "first")
    85  			return next(ctx)
    86  		})
    87  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    88  			calls = append(calls, "second")
    89  			return next(ctx)
    90  		})
    91  
    92  		resp := get(srv, "/foo?query={name}")
    93  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    94  		assert.Equal(t, []string{"first", "second"}, calls)
    95  	})
    96  
    97  	t.Run("get query parse error in AroundResponses", func(t *testing.T) {
    98  		var errors1 gqlerror.List
    99  		var errors2 gqlerror.List
   100  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   101  			resp := next(ctx)
   102  			errors1 = graphql.GetErrors(ctx)
   103  			errors2 = resp.Errors
   104  			return resp
   105  		})
   106  
   107  		resp := get(srv, "/foo?query=invalid")
   108  		assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
   109  		assert.Equal(t, 1, len(errors1))
   110  		assert.Equal(t, 1, len(errors2))
   111  	})
   112  
   113  	t.Run("query caching", func(t *testing.T) {
   114  		ctx := context.Background()
   115  		cache := &graphql.MapCache{}
   116  		srv.SetQueryCache(cache)
   117  		qry := `query Foo {name}`
   118  
   119  		t.Run("cache miss populates cache", func(t *testing.T) {
   120  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   121  			assert.Equal(t, http.StatusOK, resp.Code)
   122  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   123  
   124  			cacheDoc, ok := cache.Get(ctx, qry)
   125  			require.True(t, ok)
   126  			require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   127  		})
   128  
   129  		t.Run("cache hits use document from cache", func(t *testing.T) {
   130  			doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`})
   131  			require.Nil(t, err)
   132  			cache.Add(ctx, qry, doc)
   133  
   134  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   135  			assert.Equal(t, http.StatusOK, resp.Code)
   136  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   137  
   138  			cacheDoc, ok := cache.Get(ctx, qry)
   139  			require.True(t, ok)
   140  			require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   141  		})
   142  	})
   143  }
   144  
   145  func TestErrorServer(t *testing.T) {
   146  	srv := testserver.NewError()
   147  	srv.AddTransport(&transport.GET{})
   148  
   149  	t.Run("get resolver error in AroundResponses", func(t *testing.T) {
   150  		var errors1 gqlerror.List
   151  		var errors2 gqlerror.List
   152  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   153  			resp := next(ctx)
   154  			errors1 = graphql.GetErrors(ctx)
   155  			errors2 = resp.Errors
   156  			return resp
   157  		})
   158  
   159  		resp := get(srv, "/foo?query={name}")
   160  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
   161  		assert.Equal(t, 1, len(errors1))
   162  		assert.Equal(t, 1, len(errors2))
   163  	})
   164  }
   165  
   166  type panicTransport struct{}
   167  
   168  func (t panicTransport) Supports(r *http.Request) bool {
   169  	return true
   170  }
   171  
   172  func (h panicTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
   173  	panic(fmt.Errorf("panic in transport"))
   174  }
   175  
   176  func TestRecover(t *testing.T) {
   177  	srv := testserver.New()
   178  	srv.AddTransport(&panicTransport{})
   179  
   180  	t.Run("recover from panic", func(t *testing.T) {
   181  		resp := get(srv, "/foo?query={name}")
   182  
   183  		assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
   184  	})
   185  }
   186  
   187  func get(handler http.Handler, target string) *httptest.ResponseRecorder {
   188  	r := httptest.NewRequest("GET", target, nil)
   189  	w := httptest.NewRecorder()
   190  
   191  	handler.ServeHTTP(w, r)
   192  	return w
   193  }
   194  
   195  func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder {
   196  	r := httptest.NewRequest("POST", target, nil)
   197  	r.Header.Set("Content-Type", contentType)
   198  	w := httptest.NewRecorder()
   199  
   200  	handler.ServeHTTP(w, r)
   201  	return w
   202  }