github.com/deliveroo/gqlgen@v0.7.2/complexity/complexity.go (about)

     1  package complexity
     2  
     3  import (
     4  	"github.com/99designs/gqlgen/graphql"
     5  	"github.com/vektah/gqlparser/ast"
     6  )
     7  
     8  func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
     9  	walker := complexityWalker{
    10  		es:     es,
    11  		schema: es.Schema(),
    12  		vars:   vars,
    13  	}
    14  	return walker.selectionSetComplexity(op.SelectionSet)
    15  }
    16  
    17  type complexityWalker struct {
    18  	es     graphql.ExecutableSchema
    19  	schema *ast.Schema
    20  	vars   map[string]interface{}
    21  }
    22  
    23  func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
    24  	var complexity int
    25  	for _, selection := range selectionSet {
    26  		switch s := selection.(type) {
    27  		case *ast.Field:
    28  			fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
    29  			var childComplexity int
    30  			switch fieldDefinition.Kind {
    31  			case ast.Object, ast.Interface, ast.Union:
    32  				childComplexity = cw.selectionSetComplexity(s.SelectionSet)
    33  			}
    34  
    35  			args := s.ArgumentMap(cw.vars)
    36  			var fieldComplexity int
    37  			if s.ObjectDefinition.Kind == ast.Interface {
    38  				fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
    39  			} else {
    40  				fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
    41  			}
    42  			complexity = safeAdd(complexity, fieldComplexity)
    43  
    44  		case *ast.FragmentSpread:
    45  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
    46  
    47  		case *ast.InlineFragment:
    48  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
    49  		}
    50  	}
    51  	return complexity
    52  }
    53  
    54  func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
    55  	// Interfaces don't have their own separate field costs, so they have to assume the worst case.
    56  	// We iterate over all implementors and choose the most expensive one.
    57  	maxComplexity := 0
    58  	implementors := cw.schema.GetPossibleTypes(def)
    59  	for _, t := range implementors {
    60  		fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
    61  		if fieldComplexity > maxComplexity {
    62  			maxComplexity = fieldComplexity
    63  		}
    64  	}
    65  	return maxComplexity
    66  }
    67  
    68  func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
    69  	if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
    70  		return customComplexity
    71  	}
    72  	// default complexity calculation
    73  	return safeAdd(1, childComplexity)
    74  }
    75  
    76  const maxInt = int(^uint(0) >> 1)
    77  
    78  // safeAdd is a saturating add of a and b that ignores negative operands.
    79  // If a + b would overflow through normal Go addition,
    80  // it returns the maximum integer value instead.
    81  //
    82  // Adding complexities with this function prevents attackers from intentionally
    83  // overflowing the complexity calculation to allow overly-complex queries.
    84  //
    85  // It also helps mitigate the impact of custom complexities that accidentally
    86  // return negative values.
    87  func safeAdd(a, b int) int {
    88  	// Ignore negative operands.
    89  	if a < 0 {
    90  		if b < 0 {
    91  			return 1
    92  		}
    93  		return b
    94  	} else if b < 0 {
    95  		return a
    96  	}
    97  
    98  	c := a + b
    99  	if c < a {
   100  		// Set c to maximum integer instead of overflowing.
   101  		c = maxInt
   102  	}
   103  	return c
   104  }