github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/extension/complexity.go (about)

     1  package extension
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/99designs/gqlgen/complexity"
     8  	"github.com/99designs/gqlgen/graphql"
     9  	"github.com/99designs/gqlgen/graphql/errcode"
    10  	"github.com/vektah/gqlparser/v2/gqlerror"
    11  )
    12  
    13  const errComplexityLimit = "COMPLEXITY_LIMIT_EXCEEDED"
    14  
    15  // ComplexityLimit allows you to define a limit on query complexity
    16  //
    17  // If a query is submitted that exceeds the limit, a 422 status code will be returned.
    18  type ComplexityLimit struct {
    19  	Func func(ctx context.Context, rc *graphql.OperationContext) int
    20  
    21  	es graphql.ExecutableSchema
    22  }
    23  
    24  var _ interface {
    25  	graphql.OperationContextMutator
    26  	graphql.HandlerExtension
    27  } = &ComplexityLimit{}
    28  
    29  const complexityExtension = "ComplexityLimit"
    30  
    31  type ComplexityStats struct {
    32  	// The calculated complexity for this request
    33  	Complexity int
    34  
    35  	// The complexity limit for this request returned by the extension func
    36  	ComplexityLimit int
    37  }
    38  
    39  // FixedComplexityLimit sets a complexity limit that does not change
    40  func FixedComplexityLimit(limit int) *ComplexityLimit {
    41  	return &ComplexityLimit{
    42  		Func: func(ctx context.Context, rc *graphql.OperationContext) int {
    43  			return limit
    44  		},
    45  	}
    46  }
    47  
    48  func (c ComplexityLimit) ExtensionName() string {
    49  	return complexityExtension
    50  }
    51  
    52  func (c *ComplexityLimit) Validate(schema graphql.ExecutableSchema) error {
    53  	if c.Func == nil {
    54  		return fmt.Errorf("ComplexityLimit func can not be nil")
    55  	}
    56  	c.es = schema
    57  	return nil
    58  }
    59  
    60  func (c ComplexityLimit) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
    61  	op := rc.Doc.Operations.ForName(rc.OperationName)
    62  	complexity := complexity.Calculate(c.es, op, rc.Variables)
    63  
    64  	limit := c.Func(ctx, rc)
    65  
    66  	rc.Stats.SetExtension(complexityExtension, &ComplexityStats{
    67  		Complexity:      complexity,
    68  		ComplexityLimit: limit,
    69  	})
    70  
    71  	if complexity > limit {
    72  		err := gqlerror.Errorf("operation has complexity %d, which exceeds the limit of %d", complexity, limit)
    73  		errcode.Set(err, errComplexityLimit)
    74  		return err
    75  	}
    76  
    77  	return nil
    78  }
    79  
    80  func GetComplexityStats(ctx context.Context) *ComplexityStats {
    81  	rc := graphql.GetOperationContext(ctx)
    82  	if rc == nil {
    83  		return nil
    84  	}
    85  
    86  	s, _ := rc.Stats.GetExtension(complexityExtension).(*ComplexityStats)
    87  	return s
    88  }