github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/complexity/complexity.go (about) 1 package complexity 2 3 import ( 4 "github.com/mstephano/gqlgen-schemagen/graphql" 5 "github.com/vektah/gqlparser/v2/ast" 6 ) 7 8 func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int { 9 walker := complexityWalker{ 10 es: es, 11 schema: es.Schema(), 12 vars: vars, 13 } 14 return walker.selectionSetComplexity(op.SelectionSet) 15 } 16 17 type complexityWalker struct { 18 es graphql.ExecutableSchema 19 schema *ast.Schema 20 vars map[string]interface{} 21 } 22 23 func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int { 24 var complexity int 25 for _, selection := range selectionSet { 26 switch s := selection.(type) { 27 case *ast.Field: 28 fieldDefinition := cw.schema.Types[s.Definition.Type.Name()] 29 30 if fieldDefinition.Name == "__Schema" { 31 continue 32 } 33 34 var childComplexity int 35 switch fieldDefinition.Kind { 36 case ast.Object, ast.Interface, ast.Union: 37 childComplexity = cw.selectionSetComplexity(s.SelectionSet) 38 } 39 40 args := s.ArgumentMap(cw.vars) 41 var fieldComplexity int 42 if s.ObjectDefinition.Kind == ast.Interface { 43 fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args) 44 } else { 45 fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args) 46 } 47 complexity = safeAdd(complexity, fieldComplexity) 48 49 case *ast.FragmentSpread: 50 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet)) 51 52 case *ast.InlineFragment: 53 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet)) 54 } 55 } 56 return complexity 57 } 58 59 func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int { 60 // Interfaces don't have their own separate field costs, so they have to assume the worst case. 61 // We iterate over all implementors and choose the most expensive one. 62 maxComplexity := 0 63 implementors := cw.schema.GetPossibleTypes(def) 64 for _, t := range implementors { 65 fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args) 66 if fieldComplexity > maxComplexity { 67 maxComplexity = fieldComplexity 68 } 69 } 70 return maxComplexity 71 } 72 73 func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int { 74 if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity { 75 return customComplexity 76 } 77 // default complexity calculation 78 return safeAdd(1, childComplexity) 79 } 80 81 const maxInt = int(^uint(0) >> 1) 82 83 // safeAdd is a saturating add of a and b that ignores negative operands. 84 // If a + b would overflow through normal Go addition, 85 // it returns the maximum integer value instead. 86 // 87 // Adding complexities with this function prevents attackers from intentionally 88 // overflowing the complexity calculation to allow overly-complex queries. 89 // 90 // It also helps mitigate the impact of custom complexities that accidentally 91 // return negative values. 92 func safeAdd(a, b int) int { 93 // Ignore negative operands. 94 if a < 0 { 95 if b < 0 { 96 return 1 97 } 98 return b 99 } else if b < 0 { 100 return a 101 } 102 103 c := a + b 104 if c < a { 105 // Set c to maximum integer instead of overflowing. 106 c = maxInt 107 } 108 return c 109 }