github.com/zikaeroh/gqlgen@v0.7.2/graphql/exec.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/vektah/gqlparser/ast"
     8  )
     9  
    10  type ExecutableSchema interface {
    11  	Schema() *ast.Schema
    12  
    13  	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
    14  	Query(ctx context.Context, op *ast.OperationDefinition) *Response
    15  	Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
    16  	Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
    17  }
    18  
    19  func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
    20  	return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
    21  }
    22  
    23  func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
    24  	var groupedFields []CollectedField
    25  
    26  	for _, sel := range selSet {
    27  		switch sel := sel.(type) {
    28  		case *ast.Field:
    29  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    30  				continue
    31  			}
    32  			f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
    33  				return CollectedField{Field: sel}
    34  			})
    35  
    36  			f.Selections = append(f.Selections, sel.SelectionSet...)
    37  		case *ast.InlineFragment:
    38  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) || !instanceOf(sel.TypeCondition, satisfies) {
    39  				continue
    40  			}
    41  			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
    42  				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
    43  				f.Selections = append(f.Selections, childField.Selections...)
    44  			}
    45  
    46  		case *ast.FragmentSpread:
    47  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    48  				continue
    49  			}
    50  			fragmentName := sel.Name
    51  			if _, seen := visited[fragmentName]; seen {
    52  				continue
    53  			}
    54  			visited[fragmentName] = true
    55  
    56  			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
    57  			if fragment == nil {
    58  				// should never happen, validator has already run
    59  				panic(fmt.Errorf("missing fragment %s", fragmentName))
    60  			}
    61  
    62  			if !instanceOf(fragment.TypeCondition, satisfies) {
    63  				continue
    64  			}
    65  
    66  			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
    67  				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
    68  				f.Selections = append(f.Selections, childField.Selections...)
    69  			}
    70  
    71  		default:
    72  			panic(fmt.Errorf("unsupported %T", sel))
    73  		}
    74  	}
    75  
    76  	return groupedFields
    77  }
    78  
    79  type CollectedField struct {
    80  	*ast.Field
    81  
    82  	Selections ast.SelectionSet
    83  }
    84  
    85  func instanceOf(val string, satisfies []string) bool {
    86  	for _, s := range satisfies {
    87  		if val == s {
    88  			return true
    89  		}
    90  	}
    91  	return false
    92  }
    93  
    94  func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
    95  	for i, cf := range *c {
    96  		if cf.Alias == name {
    97  			return &(*c)[i]
    98  		}
    99  	}
   100  
   101  	f := creator()
   102  
   103  	*c = append(*c, f)
   104  	return &(*c)[len(*c)-1]
   105  }
   106  
   107  func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
   108  	skip, include := false, true
   109  
   110  	if d := directives.ForName("skip"); d != nil {
   111  		skip = resolveIfArgument(d, variables)
   112  	}
   113  
   114  	if d := directives.ForName("include"); d != nil {
   115  		include = resolveIfArgument(d, variables)
   116  	}
   117  
   118  	return !skip && include
   119  }
   120  
   121  func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
   122  	arg := d.Arguments.ForName("if")
   123  	if arg == nil {
   124  		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
   125  	}
   126  	value, err := arg.Value.Value(variables)
   127  	if err != nil {
   128  		panic(err)
   129  	}
   130  	ret, ok := value.(bool)
   131  	if !ok {
   132  		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
   133  	}
   134  	return ret
   135  }