github.com/renderinc/gqlgen@v0.7.2/complexity/complexity.go (about) 1 package complexity 2 3 import ( 4 "github.com/99designs/gqlgen/graphql" 5 "github.com/vektah/gqlparser/ast" 6 ) 7 8 func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int { 9 walker := complexityWalker{ 10 es: es, 11 schema: es.Schema(), 12 vars: vars, 13 } 14 return walker.selectionSetComplexity(op.SelectionSet) 15 } 16 17 type complexityWalker struct { 18 es graphql.ExecutableSchema 19 schema *ast.Schema 20 vars map[string]interface{} 21 } 22 23 func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int { 24 var complexity int 25 for _, selection := range selectionSet { 26 switch s := selection.(type) { 27 case *ast.Field: 28 fieldDefinition := cw.schema.Types[s.Definition.Type.Name()] 29 var childComplexity int 30 switch fieldDefinition.Kind { 31 case ast.Object, ast.Interface, ast.Union: 32 childComplexity = cw.selectionSetComplexity(s.SelectionSet) 33 } 34 35 args := s.ArgumentMap(cw.vars) 36 var fieldComplexity int 37 if s.ObjectDefinition.Kind == ast.Interface { 38 fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args) 39 } else { 40 fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args) 41 } 42 complexity = safeAdd(complexity, fieldComplexity) 43 44 case *ast.FragmentSpread: 45 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet)) 46 47 case *ast.InlineFragment: 48 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet)) 49 } 50 } 51 return complexity 52 } 53 54 func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int { 55 // Interfaces don't have their own separate field costs, so they have to assume the worst case. 56 // We iterate over all implementors and choose the most expensive one. 57 maxComplexity := 0 58 implementors := cw.schema.GetPossibleTypes(def) 59 for _, t := range implementors { 60 fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args) 61 if fieldComplexity > maxComplexity { 62 maxComplexity = fieldComplexity 63 } 64 } 65 return maxComplexity 66 } 67 68 func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int { 69 if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity { 70 return customComplexity 71 } 72 // default complexity calculation 73 return safeAdd(1, childComplexity) 74 } 75 76 const maxInt = int(^uint(0) >> 1) 77 78 // safeAdd is a saturating add of a and b that ignores negative operands. 79 // If a + b would overflow through normal Go addition, 80 // it returns the maximum integer value instead. 81 // 82 // Adding complexities with this function prevents attackers from intentionally 83 // overflowing the complexity calculation to allow overly-complex queries. 84 // 85 // It also helps mitigate the impact of custom complexities that accidentally 86 // return negative values. 87 func safeAdd(a, b int) int { 88 // Ignore negative operands. 89 if a < 0 { 90 if b < 0 { 91 return 1 92 } 93 return b 94 } else if b < 0 { 95 return a 96 } 97 98 c := a + b 99 if c < a { 100 // Set c to maximum integer instead of overflowing. 101 c = maxInt 102 } 103 return c 104 }