github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/complexity/complexity.go (about)

     1  package complexity
     2  
     3  import (
     4  	"github.com/mstephano/gqlgen-schemagen/graphql"
     5  	"github.com/vektah/gqlparser/v2/ast"
     6  )
     7  
     8  func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
     9  	walker := complexityWalker{
    10  		es:     es,
    11  		schema: es.Schema(),
    12  		vars:   vars,
    13  	}
    14  	return walker.selectionSetComplexity(op.SelectionSet)
    15  }
    16  
    17  type complexityWalker struct {
    18  	es     graphql.ExecutableSchema
    19  	schema *ast.Schema
    20  	vars   map[string]interface{}
    21  }
    22  
    23  func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
    24  	var complexity int
    25  	for _, selection := range selectionSet {
    26  		switch s := selection.(type) {
    27  		case *ast.Field:
    28  			fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
    29  
    30  			if fieldDefinition.Name == "__Schema" {
    31  				continue
    32  			}
    33  
    34  			var childComplexity int
    35  			switch fieldDefinition.Kind {
    36  			case ast.Object, ast.Interface, ast.Union:
    37  				childComplexity = cw.selectionSetComplexity(s.SelectionSet)
    38  			}
    39  
    40  			args := s.ArgumentMap(cw.vars)
    41  			var fieldComplexity int
    42  			if s.ObjectDefinition.Kind == ast.Interface {
    43  				fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
    44  			} else {
    45  				fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
    46  			}
    47  			complexity = safeAdd(complexity, fieldComplexity)
    48  
    49  		case *ast.FragmentSpread:
    50  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
    51  
    52  		case *ast.InlineFragment:
    53  			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
    54  		}
    55  	}
    56  	return complexity
    57  }
    58  
    59  func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
    60  	// Interfaces don't have their own separate field costs, so they have to assume the worst case.
    61  	// We iterate over all implementors and choose the most expensive one.
    62  	maxComplexity := 0
    63  	implementors := cw.schema.GetPossibleTypes(def)
    64  	for _, t := range implementors {
    65  		fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
    66  		if fieldComplexity > maxComplexity {
    67  			maxComplexity = fieldComplexity
    68  		}
    69  	}
    70  	return maxComplexity
    71  }
    72  
    73  func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
    74  	if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
    75  		return customComplexity
    76  	}
    77  	// default complexity calculation
    78  	return safeAdd(1, childComplexity)
    79  }
    80  
    81  const maxInt = int(^uint(0) >> 1)
    82  
    83  // safeAdd is a saturating add of a and b that ignores negative operands.
    84  // If a + b would overflow through normal Go addition,
    85  // it returns the maximum integer value instead.
    86  //
    87  // Adding complexities with this function prevents attackers from intentionally
    88  // overflowing the complexity calculation to allow overly-complex queries.
    89  //
    90  // It also helps mitigate the impact of custom complexities that accidentally
    91  // return negative values.
    92  func safeAdd(a, b int) int {
    93  	// Ignore negative operands.
    94  	if a < 0 {
    95  		if b < 0 {
    96  			return 1
    97  		}
    98  		return b
    99  	} else if b < 0 {
   100  		return a
   101  	}
   102  
   103  	c := a + b
   104  	if c < a {
   105  		// Set c to maximum integer instead of overflowing.
   106  		c = maxInt
   107  	}
   108  	return c
   109  }