github.com/apipluspower/gqlgen@v0.15.2/codegen/directive.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/apipluspower/gqlgen/codegen/templates"
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  )
    11  
    12  type DirectiveList map[string]*Directive
    13  
    14  // LocationDirectives filter directives by location
    15  func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
    16  	return locationDirectives(dl, ast.DirectiveLocation(location))
    17  }
    18  
    19  type Directive struct {
    20  	*ast.DirectiveDefinition
    21  	Name    string
    22  	Args    []*FieldArgument
    23  	Builtin bool
    24  }
    25  
    26  // IsLocation check location directive
    27  func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
    28  	for _, l := range d.Locations {
    29  		for _, a := range location {
    30  			if l == a {
    31  				return true
    32  			}
    33  		}
    34  	}
    35  
    36  	return false
    37  }
    38  
    39  func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
    40  	mDirectives := make(map[string]*Directive)
    41  	for name, d := range directives {
    42  		if d.IsLocation(location...) {
    43  			mDirectives[name] = d
    44  		}
    45  	}
    46  	return mDirectives
    47  }
    48  
    49  func (b *builder) buildDirectives() (map[string]*Directive, error) {
    50  	directives := make(map[string]*Directive, len(b.Schema.Directives))
    51  
    52  	for name, dir := range b.Schema.Directives {
    53  		if _, ok := directives[name]; ok {
    54  			return nil, fmt.Errorf("directive with name %s already exists", name)
    55  		}
    56  
    57  		var args []*FieldArgument
    58  		for _, arg := range dir.Arguments {
    59  			tr, err := b.Binder.TypeReference(arg.Type, nil)
    60  			if err != nil {
    61  				return nil, err
    62  			}
    63  
    64  			newArg := &FieldArgument{
    65  				ArgumentDefinition: arg,
    66  				TypeReference:      tr,
    67  				VarName:            templates.ToGoPrivate(arg.Name),
    68  			}
    69  
    70  			if arg.DefaultValue != nil {
    71  				var err error
    72  				newArg.Default, err = arg.DefaultValue.Value(nil)
    73  				if err != nil {
    74  					return nil, fmt.Errorf("default value for directive argument %s(%s) is not valid: %w", dir.Name, arg.Name, err)
    75  				}
    76  			}
    77  			args = append(args, newArg)
    78  		}
    79  
    80  		directives[name] = &Directive{
    81  			DirectiveDefinition: dir,
    82  			Name:                name,
    83  			Args:                args,
    84  			Builtin:             b.Config.Directives[name].SkipRuntime,
    85  		}
    86  	}
    87  
    88  	return directives, nil
    89  }
    90  
    91  func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
    92  	dirs := make([]*Directive, len(list))
    93  	for i, d := range list {
    94  		argValues := make(map[string]interface{}, len(d.Arguments))
    95  		for _, da := range d.Arguments {
    96  			val, err := da.Value.Value(nil)
    97  			if err != nil {
    98  				return nil, err
    99  			}
   100  			argValues[da.Name] = val
   101  		}
   102  		def, ok := b.Directives[d.Name]
   103  		if !ok {
   104  			return nil, fmt.Errorf("directive %s not found", d.Name)
   105  		}
   106  
   107  		var args []*FieldArgument
   108  		for _, a := range def.Args {
   109  			value := a.Default
   110  			if argValue, ok := argValues[a.Name]; ok {
   111  				value = argValue
   112  			}
   113  			args = append(args, &FieldArgument{
   114  				ArgumentDefinition: a.ArgumentDefinition,
   115  				Value:              value,
   116  				VarName:            a.VarName,
   117  				TypeReference:      a.TypeReference,
   118  			})
   119  		}
   120  		dirs[i] = &Directive{
   121  			Name:                d.Name,
   122  			Args:                args,
   123  			DirectiveDefinition: list[i].Definition,
   124  			Builtin:             b.Config.Directives[d.Name].SkipRuntime,
   125  		}
   126  
   127  	}
   128  
   129  	return dirs, nil
   130  }
   131  
   132  func (d *Directive) ArgsFunc() string {
   133  	if len(d.Args) == 0 {
   134  		return ""
   135  	}
   136  
   137  	return "dir_" + d.Name + "_args"
   138  }
   139  
   140  func (d *Directive) CallArgs() string {
   141  	args := []string{"ctx", "obj", "n"}
   142  
   143  	for _, arg := range d.Args {
   144  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   145  	}
   146  
   147  	return strings.Join(args, ", ")
   148  }
   149  
   150  func (d *Directive) ResolveArgs(obj string, next int) string {
   151  	args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)}
   152  
   153  	for _, arg := range d.Args {
   154  		dArg := arg.VarName
   155  		if arg.Value == nil && arg.Default == nil {
   156  			dArg = "nil"
   157  		}
   158  
   159  		args = append(args, dArg)
   160  	}
   161  
   162  	return strings.Join(args, ", ")
   163  }
   164  
   165  func (d *Directive) Declaration() string {
   166  	res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
   167  
   168  	for _, arg := range d.Args {
   169  		res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO))
   170  	}
   171  
   172  	res += ") (res interface{}, err error)"
   173  	return res
   174  }