github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/extension/complexity_test.go (about)

     1  package extension_test
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"strings"
     8  	"testing"
     9  
    10  	"github.com/99designs/gqlgen/graphql"
    11  	"github.com/99designs/gqlgen/graphql/handler/extension"
    12  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    13  	"github.com/99designs/gqlgen/graphql/handler/transport"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestHandlerComplexity(t *testing.T) {
    18  	h := testserver.New()
    19  	h.Use(&extension.ComplexityLimit{
    20  		Func: func(ctx context.Context, rc *graphql.OperationContext) int {
    21  			if rc.RawQuery == "{ ok: name }" {
    22  				return 4
    23  			}
    24  			return 2
    25  		},
    26  	})
    27  	h.AddTransport(&transport.POST{})
    28  	var stats *extension.ComplexityStats
    29  	h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    30  		stats = extension.GetComplexityStats(ctx)
    31  		return next(ctx)
    32  	})
    33  
    34  	t.Run("below complexity limit", func(t *testing.T) {
    35  		stats = nil
    36  		h.SetCalculatedComplexity(2)
    37  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    38  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    39  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    40  
    41  		require.Equal(t, 2, stats.ComplexityLimit)
    42  		require.Equal(t, 2, stats.Complexity)
    43  	})
    44  
    45  	t.Run("above complexity limit", func(t *testing.T) {
    46  		stats = nil
    47  		h.SetCalculatedComplexity(4)
    48  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    49  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    50  		require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
    51  
    52  		require.Equal(t, 2, stats.ComplexityLimit)
    53  		require.Equal(t, 4, stats.Complexity)
    54  	})
    55  
    56  	t.Run("within dynamic complexity limit", func(t *testing.T) {
    57  		stats = nil
    58  		h.SetCalculatedComplexity(4)
    59  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ ok: name }"}`)
    60  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    61  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    62  
    63  		require.Equal(t, 4, stats.ComplexityLimit)
    64  		require.Equal(t, 4, stats.Complexity)
    65  	})
    66  }
    67  
    68  func TestFixedComplexity(t *testing.T) {
    69  	h := testserver.New()
    70  	h.Use(extension.FixedComplexityLimit(2))
    71  	h.AddTransport(&transport.POST{})
    72  
    73  	var stats *extension.ComplexityStats
    74  	h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    75  		stats = extension.GetComplexityStats(ctx)
    76  		return next(ctx)
    77  	})
    78  
    79  	t.Run("below complexity limit", func(t *testing.T) {
    80  		h.SetCalculatedComplexity(2)
    81  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    82  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    83  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    84  
    85  		require.Equal(t, 2, stats.ComplexityLimit)
    86  		require.Equal(t, 2, stats.Complexity)
    87  	})
    88  
    89  	t.Run("above complexity limit", func(t *testing.T) {
    90  		h.SetCalculatedComplexity(4)
    91  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    92  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    93  		require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
    94  
    95  		require.Equal(t, 2, stats.ComplexityLimit)
    96  		require.Equal(t, 4, stats.Complexity)
    97  	})
    98  }
    99  
   100  func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder {
   101  	r := httptest.NewRequest(method, target, strings.NewReader(body))
   102  	r.Header.Set("Content-Type", "application/json")
   103  	w := httptest.NewRecorder()
   104  
   105  	handler.ServeHTTP(w, r)
   106  	return w
   107  }