github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/graphql/exec.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/vektah/gqlparser/ast"
     8  )
     9  
    10  type ExecutableSchema interface {
    11  	Schema() *ast.Schema
    12  
    13  	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
    14  	Query(ctx context.Context, op *ast.OperationDefinition) *Response
    15  	Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
    16  	Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
    17  }
    18  
    19  // CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
    20  // passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
    21  // type conditions.
    22  func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
    23  	return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
    24  }
    25  
    26  func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
    27  	var groupedFields []CollectedField
    28  
    29  	for _, sel := range selSet {
    30  		switch sel := sel.(type) {
    31  		case *ast.Field:
    32  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    33  				continue
    34  			}
    35  			f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
    36  				return CollectedField{Field: sel}
    37  			})
    38  
    39  			f.Selections = append(f.Selections, sel.SelectionSet...)
    40  		case *ast.InlineFragment:
    41  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    42  				continue
    43  			}
    44  			if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
    45  				continue
    46  			}
    47  			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
    48  				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
    49  				f.Selections = append(f.Selections, childField.Selections...)
    50  			}
    51  
    52  		case *ast.FragmentSpread:
    53  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    54  				continue
    55  			}
    56  			fragmentName := sel.Name
    57  			if _, seen := visited[fragmentName]; seen {
    58  				continue
    59  			}
    60  			visited[fragmentName] = true
    61  
    62  			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
    63  			if fragment == nil {
    64  				// should never happen, validator has already run
    65  				panic(fmt.Errorf("missing fragment %s", fragmentName))
    66  			}
    67  
    68  			if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
    69  				continue
    70  			}
    71  
    72  			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
    73  				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
    74  				f.Selections = append(f.Selections, childField.Selections...)
    75  			}
    76  
    77  		default:
    78  			panic(fmt.Errorf("unsupported %T", sel))
    79  		}
    80  	}
    81  
    82  	return groupedFields
    83  }
    84  
    85  type CollectedField struct {
    86  	*ast.Field
    87  
    88  	Selections ast.SelectionSet
    89  }
    90  
    91  func instanceOf(val string, satisfies []string) bool {
    92  	for _, s := range satisfies {
    93  		if val == s {
    94  			return true
    95  		}
    96  	}
    97  	return false
    98  }
    99  
   100  func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
   101  	for i, cf := range *c {
   102  		if cf.Alias == name {
   103  			return &(*c)[i]
   104  		}
   105  	}
   106  
   107  	f := creator()
   108  
   109  	*c = append(*c, f)
   110  	return &(*c)[len(*c)-1]
   111  }
   112  
   113  func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
   114  	skip, include := false, true
   115  
   116  	if d := directives.ForName("skip"); d != nil {
   117  		skip = resolveIfArgument(d, variables)
   118  	}
   119  
   120  	if d := directives.ForName("include"); d != nil {
   121  		include = resolveIfArgument(d, variables)
   122  	}
   123  
   124  	return !skip && include
   125  }
   126  
   127  func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
   128  	arg := d.Arguments.ForName("if")
   129  	if arg == nil {
   130  		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
   131  	}
   132  	value, err := arg.Value.Value(variables)
   133  	if err != nil {
   134  		panic(err)
   135  	}
   136  	ret, ok := value.(bool)
   137  	if !ok {
   138  		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
   139  	}
   140  	return ret
   141  }