github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/codegen/directive.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/99designs/gqlgen/codegen/templates"
     9  	"github.com/pkg/errors"
    10  	"github.com/vektah/gqlparser/v2/ast"
    11  )
    12  
    13  type DirectiveList map[string]*Directive
    14  
    15  //LocationDirectives filter directives by location
    16  func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
    17  	return locationDirectives(dl, ast.DirectiveLocation(location))
    18  }
    19  
    20  type Directive struct {
    21  	*ast.DirectiveDefinition
    22  	Name    string
    23  	Args    []*FieldArgument
    24  	Builtin bool
    25  }
    26  
    27  //IsLocation check location directive
    28  func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
    29  	for _, l := range d.Locations {
    30  		for _, a := range location {
    31  			if l == a {
    32  				return true
    33  			}
    34  		}
    35  	}
    36  
    37  	return false
    38  }
    39  
    40  func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
    41  	mDirectives := make(map[string]*Directive)
    42  	for name, d := range directives {
    43  		if d.IsLocation(location...) {
    44  			mDirectives[name] = d
    45  		}
    46  	}
    47  	return mDirectives
    48  }
    49  
    50  func (b *builder) buildDirectives() (map[string]*Directive, error) {
    51  	directives := make(map[string]*Directive, len(b.Schema.Directives))
    52  
    53  	for name, dir := range b.Schema.Directives {
    54  		if _, ok := directives[name]; ok {
    55  			return nil, errors.Errorf("directive with name %s already exists", name)
    56  		}
    57  
    58  		var args []*FieldArgument
    59  		for _, arg := range dir.Arguments {
    60  			tr, err := b.Binder.TypeReference(arg.Type, nil)
    61  			if err != nil {
    62  				return nil, err
    63  			}
    64  
    65  			newArg := &FieldArgument{
    66  				ArgumentDefinition: arg,
    67  				TypeReference:      tr,
    68  				VarName:            templates.ToGoPrivate(arg.Name),
    69  			}
    70  
    71  			if arg.DefaultValue != nil {
    72  				var err error
    73  				newArg.Default, err = arg.DefaultValue.Value(nil)
    74  				if err != nil {
    75  					return nil, errors.Errorf("default value for directive argument %s(%s) is not valid: %s", dir.Name, arg.Name, err.Error())
    76  				}
    77  			}
    78  			args = append(args, newArg)
    79  		}
    80  
    81  		directives[name] = &Directive{
    82  			DirectiveDefinition: dir,
    83  			Name:                name,
    84  			Args:                args,
    85  			Builtin:             b.Config.Directives[name].SkipRuntime,
    86  		}
    87  	}
    88  
    89  	return directives, nil
    90  }
    91  
    92  func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
    93  	dirs := make([]*Directive, len(list))
    94  	for i, d := range list {
    95  		argValues := make(map[string]interface{}, len(d.Arguments))
    96  		for _, da := range d.Arguments {
    97  			val, err := da.Value.Value(nil)
    98  			if err != nil {
    99  				return nil, err
   100  			}
   101  			argValues[da.Name] = val
   102  		}
   103  		def, ok := b.Directives[d.Name]
   104  		if !ok {
   105  			return nil, fmt.Errorf("directive %s not found", d.Name)
   106  		}
   107  
   108  		var args []*FieldArgument
   109  		for _, a := range def.Args {
   110  			value := a.Default
   111  			if argValue, ok := argValues[a.Name]; ok {
   112  				value = argValue
   113  			}
   114  			args = append(args, &FieldArgument{
   115  				ArgumentDefinition: a.ArgumentDefinition,
   116  				Value:              value,
   117  				VarName:            a.VarName,
   118  				TypeReference:      a.TypeReference,
   119  			})
   120  		}
   121  		dirs[i] = &Directive{
   122  			Name:                d.Name,
   123  			Args:                args,
   124  			DirectiveDefinition: list[i].Definition,
   125  			Builtin:             b.Config.Directives[d.Name].SkipRuntime,
   126  		}
   127  
   128  	}
   129  
   130  	return dirs, nil
   131  }
   132  
   133  func (d *Directive) ArgsFunc() string {
   134  	if len(d.Args) == 0 {
   135  		return ""
   136  	}
   137  
   138  	return "dir_" + d.Name + "_args"
   139  }
   140  
   141  func (d *Directive) CallArgs() string {
   142  	args := []string{"ctx", "obj", "n"}
   143  
   144  	for _, arg := range d.Args {
   145  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   146  	}
   147  
   148  	return strings.Join(args, ", ")
   149  }
   150  
   151  func (d *Directive) ResolveArgs(obj string, next int) string {
   152  	args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)}
   153  
   154  	for _, arg := range d.Args {
   155  		dArg := arg.VarName
   156  		if arg.Value == nil && arg.Default == nil {
   157  			dArg = "nil"
   158  		}
   159  
   160  		args = append(args, dArg)
   161  	}
   162  
   163  	return strings.Join(args, ", ")
   164  }
   165  
   166  func (d *Directive) Declaration() string {
   167  	res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
   168  
   169  	for _, arg := range d.Args {
   170  		res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO))
   171  	}
   172  
   173  	res += ") (res interface{}, err error)"
   174  	return res
   175  }