github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/executable_schema.go (about)

     1  //go:generate go run github.com/matryer/moq -out executable_schema_mock.go . ExecutableSchema
     2  
     3  package graphql
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  )
    11  
// ExecutableSchema is the contract this package uses to talk to a concrete
// GraphQL schema implementation. A moq-based mock for it is produced by the
// go:generate directive at the top of this file.
type ExecutableSchema interface {
	// Schema returns the parsed schema definition as a gqlparser *ast.Schema.
	Schema() *ast.Schema

	// Complexity returns an integer cost for the field fieldName on type
	// typeName, given the cost of its children and the field's arguments.
	// NOTE(review): the bool result presumably reports whether a complexity
	// function is defined for the field — confirm against the implementations.
	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
	// Exec returns the ResponseHandler for the operation associated with ctx.
	// NOTE(review): how the operation is carried in ctx is not visible in
	// this file — confirm against the implementations.
	Exec(ctx context.Context) ResponseHandler
}
    18  
    19  // CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
    20  // passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
    21  // type conditions.
    22  func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
    23  	return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
    24  }
    25  
    26  func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
    27  	groupedFields := make([]CollectedField, 0, len(selSet))
    28  
    29  	for _, sel := range selSet {
    30  		switch sel := sel.(type) {
    31  		case *ast.Field:
    32  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    33  				continue
    34  			}
    35  			f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField {
    36  				return CollectedField{Field: sel}
    37  			})
    38  
    39  			f.Selections = append(f.Selections, sel.SelectionSet...)
    40  
    41  		case *ast.InlineFragment:
    42  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    43  				continue
    44  			}
    45  			if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
    46  				continue
    47  			}
    48  			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
    49  				f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
    50  				f.Selections = append(f.Selections, childField.Selections...)
    51  			}
    52  
    53  		case *ast.FragmentSpread:
    54  			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
    55  				continue
    56  			}
    57  			fragmentName := sel.Name
    58  			if _, seen := visited[fragmentName]; seen {
    59  				continue
    60  			}
    61  			visited[fragmentName] = true
    62  
    63  			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
    64  			if fragment == nil {
    65  				// should never happen, validator has already run
    66  				panic(fmt.Errorf("missing fragment %s", fragmentName))
    67  			}
    68  
    69  			if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
    70  				continue
    71  			}
    72  
    73  			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
    74  				f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
    75  				f.Selections = append(f.Selections, childField.Selections...)
    76  			}
    77  
    78  		default:
    79  			panic(fmt.Errorf("unsupported %T", sel))
    80  		}
    81  	}
    82  
    83  	return groupedFields
    84  }
    85  
// CollectedField is one entry in the result of CollectFields: a field from
// the query together with the sub-selections gathered for it.
type CollectedField struct {
	// Field is the AST node of the selection that first introduced this
	// entry; its members (Name, Alias, ObjectDefinition, ...) are promoted.
	*ast.Field

	// Selections accumulates the selection sets of every occurrence of the
	// field that was merged into this entry during collection.
	Selections ast.SelectionSet
}
    91  
    92  func instanceOf(val string, satisfies []string) bool {
    93  	for _, s := range satisfies {
    94  		if val == s {
    95  			return true
    96  		}
    97  	}
    98  	return false
    99  }
   100  
   101  func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField {
   102  	for i, cf := range *c {
   103  		if cf.Name == name && cf.Alias == alias {
   104  			if cf.ObjectDefinition == objectDefinition {
   105  				return &(*c)[i]
   106  			}
   107  
   108  			if cf.ObjectDefinition == nil || objectDefinition == nil {
   109  				continue
   110  			}
   111  
   112  			if cf.ObjectDefinition.Name == objectDefinition.Name {
   113  				return &(*c)[i]
   114  			}
   115  
   116  			for _, ifc := range objectDefinition.Interfaces {
   117  				if ifc == cf.ObjectDefinition.Name {
   118  					return &(*c)[i]
   119  				}
   120  			}
   121  			for _, ifc := range cf.ObjectDefinition.Interfaces {
   122  				if ifc == objectDefinition.Name {
   123  					return &(*c)[i]
   124  				}
   125  			}
   126  		}
   127  	}
   128  
   129  	f := creator()
   130  
   131  	*c = append(*c, f)
   132  	return &(*c)[len(*c)-1]
   133  }
   134  
   135  func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
   136  	if len(directives) == 0 {
   137  		return true
   138  	}
   139  
   140  	skip, include := false, true
   141  
   142  	if d := directives.ForName("skip"); d != nil {
   143  		skip = resolveIfArgument(d, variables)
   144  	}
   145  
   146  	if d := directives.ForName("include"); d != nil {
   147  		include = resolveIfArgument(d, variables)
   148  	}
   149  
   150  	return !skip && include
   151  }
   152  
   153  func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
   154  	arg := d.Arguments.ForName("if")
   155  	if arg == nil {
   156  		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
   157  	}
   158  	value, err := arg.Value.Value(variables)
   159  	if err != nil {
   160  		panic(err)
   161  	}
   162  	ret, ok := value.(bool)
   163  	if !ok {
   164  		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
   165  	}
   166  	return ret
   167  }