github.com/99designs/gqlgen@v0.17.45/graphql/handler/extension/complexity_test.go (about) 1 package extension_test 2 3 import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "strings" 8 "testing" 9 10 "github.com/stretchr/testify/require" 11 12 "github.com/99designs/gqlgen/graphql" 13 "github.com/99designs/gqlgen/graphql/handler/extension" 14 "github.com/99designs/gqlgen/graphql/handler/testserver" 15 "github.com/99designs/gqlgen/graphql/handler/transport" 16 ) 17 18 func TestHandlerComplexity(t *testing.T) { 19 h := testserver.New() 20 h.Use(&extension.ComplexityLimit{ 21 Func: func(ctx context.Context, rc *graphql.OperationContext) int { 22 if rc.RawQuery == "{ ok: name }" { 23 return 4 24 } 25 return 2 26 }, 27 }) 28 h.AddTransport(&transport.POST{}) 29 var stats *extension.ComplexityStats 30 h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 31 stats = extension.GetComplexityStats(ctx) 32 return next(ctx) 33 }) 34 35 t.Run("below complexity limit", func(t *testing.T) { 36 stats = nil 37 h.SetCalculatedComplexity(2) 38 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) 39 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 40 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 41 42 require.Equal(t, 2, stats.ComplexityLimit) 43 require.Equal(t, 2, stats.Complexity) 44 }) 45 46 t.Run("above complexity limit", func(t *testing.T) { 47 stats = nil 48 h.SetCalculatedComplexity(4) 49 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) 50 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 51 require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String()) 52 53 require.Equal(t, 2, stats.ComplexityLimit) 54 require.Equal(t, 4, stats.Complexity) 55 }) 56 57 t.Run("within dynamic complexity limit", func(t *testing.T) { 58 stats = nil 59 h.SetCalculatedComplexity(4) 60 resp := doRequest(h, "POST", "/graphql", `{"query":"{ ok: name }"}`) 61 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 62 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 63 64 require.Equal(t, 4, stats.ComplexityLimit) 65 require.Equal(t, 4, stats.Complexity) 66 }) 67 } 68 69 func TestFixedComplexity(t *testing.T) { 70 h := testserver.New() 71 h.Use(extension.FixedComplexityLimit(2)) 72 h.AddTransport(&transport.POST{}) 73 74 var stats *extension.ComplexityStats 75 h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 76 stats = extension.GetComplexityStats(ctx) 77 return next(ctx) 78 }) 79 80 t.Run("below complexity limit", func(t *testing.T) { 81 h.SetCalculatedComplexity(2) 82 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) 83 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 84 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 85 86 require.Equal(t, 2, stats.ComplexityLimit) 87 require.Equal(t, 2, stats.Complexity) 88 }) 89 90 t.Run("above complexity limit", func(t *testing.T) { 91 h.SetCalculatedComplexity(4) 92 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) 93 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 94 require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String()) 95 96 require.Equal(t, 2, stats.ComplexityLimit) 97 require.Equal(t, 4, stats.Complexity) 98 }) 99 100 t.Run("bypass __schema field", func(t *testing.T) { 101 h.SetCalculatedComplexity(4) 102 resp := doRequest(h, "POST", "/graphql", `{ "operationName":"IntrospectionQuery", "query":"query IntrospectionQuery { __schema { queryType { name } mutationType { name }}}"}`) 103 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) 104 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) 105 106 require.Equal(t, 2, stats.ComplexityLimit) 107 require.Equal(t, 0, stats.Complexity) 108 }) 109 } 110 111 func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder { 112 r := httptest.NewRequest(method, target, strings.NewReader(body)) 113 r.Header.Set("Content-Type", "application/json") 114 w := httptest.NewRecorder() 115 116 handler.ServeHTTP(w, r) 117 return w 118 }