github.com/operandinc/gqlgen@v0.16.1/codegen/directive.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "strconv" 6 "strings" 7 8 "github.com/operandinc/gqlgen/codegen/templates" 9 "github.com/vektah/gqlparser/v2/ast" 10 ) 11 12 type DirectiveList map[string]*Directive 13 14 // LocationDirectives filter directives by location 15 func (dl DirectiveList) LocationDirectives(location string) DirectiveList { 16 return locationDirectives(dl, ast.DirectiveLocation(location)) 17 } 18 19 type Directive struct { 20 *ast.DirectiveDefinition 21 Name string 22 Args []*FieldArgument 23 Builtin bool 24 } 25 26 // IsLocation check location directive 27 func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool { 28 for _, l := range d.Locations { 29 for _, a := range location { 30 if l == a { 31 return true 32 } 33 } 34 } 35 36 return false 37 } 38 39 func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive { 40 mDirectives := make(map[string]*Directive) 41 for name, d := range directives { 42 if d.IsLocation(location...) { 43 mDirectives[name] = d 44 } 45 } 46 return mDirectives 47 } 48 49 func (b *builder) buildDirectives() (map[string]*Directive, error) { 50 directives := make(map[string]*Directive, len(b.Schema.Directives)) 51 52 for name, dir := range b.Schema.Directives { 53 if _, ok := directives[name]; ok { 54 return nil, fmt.Errorf("directive with name %s already exists", name) 55 } 56 57 var args []*FieldArgument 58 for _, arg := range dir.Arguments { 59 tr, err := b.Binder.TypeReference(arg.Type, nil) 60 if err != nil { 61 return nil, err 62 } 63 64 newArg := &FieldArgument{ 65 ArgumentDefinition: arg, 66 TypeReference: tr, 67 VarName: templates.ToGoPrivate(arg.Name), 68 } 69 70 if arg.DefaultValue != nil { 71 var err error 72 newArg.Default, err = arg.DefaultValue.Value(nil) 73 if err != nil { 74 return nil, fmt.Errorf("default value for directive argument %s(%s) is not valid: %w", dir.Name, arg.Name, err) 75 } 76 } 77 args = append(args, newArg) 78 } 79 80 directives[name] = &Directive{ 81 DirectiveDefinition: dir, 82 Name: name, 83 Args: args, 84 Builtin: b.Config.Directives[name].SkipRuntime, 85 } 86 } 87 88 return directives, nil 89 } 90 91 func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) { 92 dirs := make([]*Directive, len(list)) 93 for i, d := range list { 94 argValues := make(map[string]interface{}, len(d.Arguments)) 95 for _, da := range d.Arguments { 96 val, err := da.Value.Value(nil) 97 if err != nil { 98 return nil, err 99 } 100 argValues[da.Name] = val 101 } 102 def, ok := b.Directives[d.Name] 103 if !ok { 104 return nil, fmt.Errorf("directive %s not found", d.Name) 105 } 106 107 var args []*FieldArgument 108 for _, a := range def.Args { 109 value := a.Default 110 if argValue, ok := argValues[a.Name]; ok { 111 value = argValue 112 } 113 args = append(args, &FieldArgument{ 114 ArgumentDefinition: a.ArgumentDefinition, 115 Value: value, 116 VarName: a.VarName, 117 TypeReference: a.TypeReference, 118 }) 119 } 120 dirs[i] = &Directive{ 121 Name: d.Name, 122 Args: args, 123 DirectiveDefinition: list[i].Definition, 124 Builtin: b.Config.Directives[d.Name].SkipRuntime, 125 } 126 127 } 128 129 return dirs, nil 130 } 131 132 func (d *Directive) ArgsFunc() string { 133 if len(d.Args) == 0 { 134 return "" 135 } 136 137 return "dir_" + d.Name + "_args" 138 } 139 140 func (d *Directive) CallArgs() string { 141 args := []string{"ctx", "obj", "n"} 142 143 for _, arg := range d.Args { 144 args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")") 145 } 146 147 return strings.Join(args, ", ") 148 } 149 150 func (d *Directive) ResolveArgs(obj string, next int) string { 151 args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)} 152 153 for _, arg := range d.Args { 154 dArg := arg.VarName 155 if arg.Value == nil && arg.Default == nil { 156 dArg = "nil" 157 } 158 159 args = append(args, dArg) 160 } 161 162 return strings.Join(args, ", ") 163 } 164 165 func (d *Directive) Declaration() string { 166 res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver" 167 168 for _, arg := range d.Args { 169 res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO)) 170 } 171 172 res += ") (res interface{}, err error)" 173 return res 174 }