github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/executable_schema.go (about) 1 //go:generate go run github.com/matryer/moq -out executable_schema_mock.go . ExecutableSchema 2 3 package graphql 4 5 import ( 6 "context" 7 "fmt" 8 9 "github.com/vektah/gqlparser/v2/ast" 10 ) 11 12 type ExecutableSchema interface { 13 Schema() *ast.Schema 14 15 Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) 16 Exec(ctx context.Context) ResponseHandler 17 } 18 19 // CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types 20 // passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment 21 // type conditions. 22 func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField { 23 return collectFields(reqCtx, selSet, satisfies, map[string]bool{}) 24 } 25 26 func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField { 27 groupedFields := make([]CollectedField, 0, len(selSet)) 28 29 for _, sel := range selSet { 30 switch sel := sel.(type) { 31 case *ast.Field: 32 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { 33 continue 34 } 35 f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField { 36 return CollectedField{Field: sel} 37 }) 38 39 f.Selections = append(f.Selections, sel.SelectionSet...) 40 41 case *ast.InlineFragment: 42 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { 43 continue 44 } 45 if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) { 46 continue 47 } 48 for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) { 49 f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField }) 50 f.Selections = append(f.Selections, childField.Selections...) 51 } 52 53 case *ast.FragmentSpread: 54 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { 55 continue 56 } 57 fragmentName := sel.Name 58 if _, seen := visited[fragmentName]; seen { 59 continue 60 } 61 visited[fragmentName] = true 62 63 fragment := reqCtx.Doc.Fragments.ForName(fragmentName) 64 if fragment == nil { 65 // should never happen, validator has already run 66 panic(fmt.Errorf("missing fragment %s", fragmentName)) 67 } 68 69 if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) { 70 continue 71 } 72 73 for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) { 74 f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField }) 75 f.Selections = append(f.Selections, childField.Selections...) 76 } 77 78 default: 79 panic(fmt.Errorf("unsupported %T", sel)) 80 } 81 } 82 83 return groupedFields 84 } 85 86 type CollectedField struct { 87 *ast.Field 88 89 Selections ast.SelectionSet 90 } 91 92 func instanceOf(val string, satisfies []string) bool { 93 for _, s := range satisfies { 94 if val == s { 95 return true 96 } 97 } 98 return false 99 } 100 101 func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField { 102 for i, cf := range *c { 103 if cf.Name == name && cf.Alias == alias { 104 if cf.ObjectDefinition == objectDefinition { 105 return &(*c)[i] 106 } 107 108 if cf.ObjectDefinition == nil || objectDefinition == nil { 109 continue 110 } 111 112 if cf.ObjectDefinition.Name == objectDefinition.Name { 113 return &(*c)[i] 114 } 115 116 for _, ifc := range objectDefinition.Interfaces { 117 if ifc == cf.ObjectDefinition.Name { 118 return &(*c)[i] 119 } 120 } 121 for _, ifc := range cf.ObjectDefinition.Interfaces { 122 if ifc == objectDefinition.Name { 123 return &(*c)[i] 124 } 125 } 126 } 127 } 128 129 f := creator() 130 131 *c = append(*c, f) 132 return &(*c)[len(*c)-1] 133 } 134 135 func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool { 136 if len(directives) == 0 { 137 return true 138 } 139 140 skip, include := false, true 141 142 if d := directives.ForName("skip"); d != nil { 143 skip = resolveIfArgument(d, variables) 144 } 145 146 if d := directives.ForName("include"); d != nil { 147 include = resolveIfArgument(d, variables) 148 } 149 150 return !skip && include 151 } 152 153 func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool { 154 arg := d.Arguments.ForName("if") 155 if arg == nil { 156 panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name)) 157 } 158 value, err := arg.Value.Value(variables) 159 if err != nil { 160 panic(err) 161 } 162 ret, ok := value.(bool) 163 if !ok { 164 panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name)) 165 } 166 return ret 167 }