github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/complexity/complexity.go (about)

     1  package complexity
     2  
     3  import (
     4  	"github.com/vektah/gqlparser/v2/ast"
     5  
     6  	"github.com/niko0xdev/gqlgen/graphql"
     7  )
     8  
     9  func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
    10  	walker := complexityWalker{
    11  		es:     es,
    12  		schema: es.Schema(),
    13  		vars:   vars,
    14  	}
    15  	return walker.selectionSetComplexity(op.SelectionSet)
    16  }
    17  
    18  type complexityWalker struct {
    19  	es     graphql.ExecutableSchema
    20  	schema *ast.Schema
    21  	vars   map[string]interface{}
    22  }
    23  
    24  func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
    25  	var complexity int
    26  	for _, selection := range selectionSet {
    27  		switch s := selection.(type) {
    28  		case *ast.Field:
    29  			fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
    30  
    31  			if fieldDefinition.Name == "__Schema" {
    32  				continue
    33  			}
    34  
    35  			var childComplexity int
    36  			switch fieldDefinition.Kind {
    37  			case ast.Object, ast.Interface, ast.Union:
    38  				childComplexity = cw.selectionSetComplexity(s.SelectionSet)
    39  			}
    40  
    41  			args := s.ArgumentMap(cw.vars)
    42  			var fieldComplexity int
    43  			if s.ObjectDefinition.Kind == ast.Interface {
    44  				fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
    45  			} else {
    46  				fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
    47  			}
    48  			complexity = safeAdd(complexity, fieldComplexity)
    49  
    50  		case *ast.FragmentSpread:
    51  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
    52  
    53  		case *ast.InlineFragment:
    54  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
    55  		}
    56  	}
    57  	return complexity
    58  }
    59  
    60  func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
    61  	// Interfaces don't have their own separate field costs, so they have to assume the worst case.
    62  	// We iterate over all implementors and choose the most expensive one.
    63  	maxComplexity := 0
    64  	implementors := cw.schema.GetPossibleTypes(def)
    65  	for _, t := range implementors {
    66  		fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
    67  		if fieldComplexity > maxComplexity {
    68  			maxComplexity = fieldComplexity
    69  		}
    70  	}
    71  	return maxComplexity
    72  }
    73  
    74  func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
    75  	if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
    76  		return customComplexity
    77  	}
    78  	// default complexity calculation
    79  	return safeAdd(1, childComplexity)
    80  }
    81  
    82  const maxInt = int(^uint(0) >> 1)
    83  
    84  // safeAdd is a saturating add of a and b that ignores negative operands.
    85  // If a + b would overflow through normal Go addition,
    86  // it returns the maximum integer value instead.
    87  //
    88  // Adding complexities with this function prevents attackers from intentionally
    89  // overflowing the complexity calculation to allow overly-complex queries.
    90  //
    91  // It also helps mitigate the impact of custom complexities that accidentally
    92  // return negative values.
    93  func safeAdd(a, b int) int {
    94  	// Ignore negative operands.
    95  	if a < 0 {
    96  		if b < 0 {
    97  			return 1
    98  		}
    99  		return b
   100  	} else if b < 0 {
   101  		return a
   102  	}
   103  
   104  	c := a + b
   105  	if c < a {
   106  		// Set c to maximum integer instead of overflowing.
   107  		c = maxInt
   108  	}
   109  	return c
   110  }