github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/graphql/context.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/vektah/gqlparser/ast"
     9  	"github.com/vektah/gqlparser/gqlerror"
    10  )
    11  
// Resolver is a function that receives a context and produces a result value
// or an error.
type Resolver func(ctx context.Context) (res interface{}, err error)
// FieldMiddleware wraps a Resolver: it is handed the context together with the
// next Resolver in line and returns a result/error pair of the same shape.
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
// RequestMiddleware wraps a request-level handler: it is handed the context
// together with the next handler and, like that handler, returns a byte slice.
type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
    15  
// RequestContext holds the state of a single request: the query and its
// variables, the configured hooks, and the errors and extensions collected
// while the request is processed. It contains mutexes, so it must be shared
// by pointer and never copied.
type RequestContext struct {
	// The query string, its variables and the parsed document, as passed to
	// NewRequestContext.
	RawQuery  string
	Variables map[string]interface{}
	Doc       *ast.QueryDocument

	// Complexity and introspection settings.
	// NOTE(review): nothing in this file reads these; confirm where they are enforced.
	ComplexityLimit      int
	OperationComplexity  int
	DisableIntrospection bool

	// ErrorPresenter will be used to generate the error
	// message from errors given to Error().
	ErrorPresenter      ErrorPresenterFunc
	Recover             RecoverFunc
	ResolverMiddleware  FieldMiddleware
	DirectiveMiddleware FieldMiddleware
	RequestMiddleware   RequestMiddleware
	Tracer              Tracer

	// errorsMu is taken by Errorf, Error, HasError and GetErrors around their
	// access to Errors; extensionsMu is taken by RegisterExtension around its
	// access to Extensions. Extensions stays nil until the first registration.
	errorsMu     sync.Mutex
	Errors       gqlerror.List
	extensionsMu sync.Mutex
	Extensions   map[string]interface{}
}
    39  
    40  func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
    41  	return next(ctx)
    42  }
    43  
    44  func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
    45  	return next(ctx)
    46  }
    47  
    48  func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
    49  	return next(ctx)
    50  }
    51  
    52  func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
    53  	return &RequestContext{
    54  		Doc:                 doc,
    55  		RawQuery:            query,
    56  		Variables:           variables,
    57  		ResolverMiddleware:  DefaultResolverMiddleware,
    58  		DirectiveMiddleware: DefaultDirectiveMiddleware,
    59  		RequestMiddleware:   DefaultRequestMiddleware,
    60  		Recover:             DefaultRecover,
    61  		ErrorPresenter:      DefaultErrorPresenter,
    62  		Tracer:              &NopTracer{},
    63  	}
    64  }
    65  
// key is the unexported type of the context keys used in this file, so the
// stored values can only be reached through the accessors defined here.
type key string

// request is the key under which WithRequestContext stores the
// *RequestContext; resolver is the key under which WithResolverContext stores
// the *ResolverContext.
const (
	request  key = "request_context"
	resolver key = "resolver_context"
)
    72  
    73  func GetRequestContext(ctx context.Context) *RequestContext {
    74  	if val, ok := ctx.Value(request).(*RequestContext); ok {
    75  		return val
    76  	}
    77  	return nil
    78  }
    79  
    80  func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
    81  	return context.WithValue(ctx, request, rc)
    82  }
    83  
// ResolverContext describes one step of field resolution. The contexts form a
// chain through Parent, which Path walks to build the location of the field.
type ResolverContext struct {
	// The enclosing resolver context; set by WithResolverContext from the
	// context it is given, nil at the root.
	Parent *ResolverContext
	// The name of the type this field belongs to
	Object string
	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
	Args map[string]interface{}
	// The raw field
	Field CollectedField
	// The index of array in path.
	Index *int
	// The result object of resolver
	Result interface{}
}
    97  
    98  func (r *ResolverContext) Path() []interface{} {
    99  	var path []interface{}
   100  	for it := r; it != nil; it = it.Parent {
   101  		if it.Index != nil {
   102  			path = append(path, *it.Index)
   103  		} else if it.Field.Field != nil {
   104  			path = append(path, it.Field.Alias)
   105  		}
   106  	}
   107  
   108  	// because we are walking up the chain, all the elements are backwards, do an inplace flip.
   109  	for i := len(path)/2 - 1; i >= 0; i-- {
   110  		opp := len(path) - 1 - i
   111  		path[i], path[opp] = path[opp], path[i]
   112  	}
   113  
   114  	return path
   115  }
   116  
   117  func GetResolverContext(ctx context.Context) *ResolverContext {
   118  	if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
   119  		return val
   120  	}
   121  	return nil
   122  }
   123  
   124  func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
   125  	rc.Parent = GetResolverContext(ctx)
   126  	return context.WithValue(ctx, resolver, rc)
   127  }
   128  
   129  // This is just a convenient wrapper method for CollectFields
   130  func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
   131  	resctx := GetResolverContext(ctx)
   132  	return CollectFields(ctx, resctx.Field.Selections, satisfies)
   133  }
   134  
   135  // CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
   136  // The slice will contain the unique set of all field names requested regardless of fragment type conditions.
   137  func CollectAllFields(ctx context.Context) []string {
   138  	resctx := GetResolverContext(ctx)
   139  	collected := CollectFields(ctx, resctx.Field.Selections, nil)
   140  	uniq := make([]string, 0, len(collected))
   141  Next:
   142  	for _, f := range collected {
   143  		for _, name := range uniq {
   144  			if name == f.Name {
   145  				continue Next
   146  			}
   147  		}
   148  		uniq = append(uniq, f.Name)
   149  	}
   150  	return uniq
   151  }
   152  
   153  // Errorf sends an error string to the client, passing it through the formatter.
   154  func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
   155  	c.errorsMu.Lock()
   156  	defer c.errorsMu.Unlock()
   157  
   158  	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
   159  }
   160  
   161  // Error sends an error to the client, passing it through the formatter.
   162  func (c *RequestContext) Error(ctx context.Context, err error) {
   163  	c.errorsMu.Lock()
   164  	defer c.errorsMu.Unlock()
   165  
   166  	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
   167  }
   168  
   169  // HasError returns true if the current field has already errored
   170  func (c *RequestContext) HasError(rctx *ResolverContext) bool {
   171  	c.errorsMu.Lock()
   172  	defer c.errorsMu.Unlock()
   173  	path := rctx.Path()
   174  
   175  	for _, err := range c.Errors {
   176  		if equalPath(err.Path, path) {
   177  			return true
   178  		}
   179  	}
   180  	return false
   181  }
   182  
   183  // GetErrors returns a list of errors that occurred in the current field
   184  func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
   185  	c.errorsMu.Lock()
   186  	defer c.errorsMu.Unlock()
   187  	path := rctx.Path()
   188  
   189  	var errs gqlerror.List
   190  	for _, err := range c.Errors {
   191  		if equalPath(err.Path, path) {
   192  			errs = append(errs, err)
   193  		}
   194  	}
   195  	return errs
   196  }
   197  
   198  func equalPath(a []interface{}, b []interface{}) bool {
   199  	if len(a) != len(b) {
   200  		return false
   201  	}
   202  
   203  	for i := 0; i < len(a); i++ {
   204  		if a[i] != b[i] {
   205  			return false
   206  		}
   207  	}
   208  
   209  	return true
   210  }
   211  
   212  // AddError is a convenience method for adding an error to the current response
   213  func AddError(ctx context.Context, err error) {
   214  	GetRequestContext(ctx).Error(ctx, err)
   215  }
   216  
   217  // AddErrorf is a convenience method for adding an error to the current response
   218  func AddErrorf(ctx context.Context, format string, args ...interface{}) {
   219  	GetRequestContext(ctx).Errorf(ctx, format, args...)
   220  }
   221  
   222  // RegisterExtension registers an extension, returns error if extension has already been registered
   223  func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
   224  	c.extensionsMu.Lock()
   225  	defer c.extensionsMu.Unlock()
   226  
   227  	if c.Extensions == nil {
   228  		c.Extensions = make(map[string]interface{})
   229  	}
   230  
   231  	if _, ok := c.Extensions[key]; ok {
   232  		return fmt.Errorf("extension already registered for key %s", key)
   233  	}
   234  
   235  	c.Extensions[key] = value
   236  	return nil
   237  }
   238  
   239  // ChainFieldMiddleware add chain by FieldMiddleware
   240  func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
   241  	n := len(handleFunc)
   242  
   243  	if n > 1 {
   244  		lastI := n - 1
   245  		return func(ctx context.Context, next Resolver) (interface{}, error) {
   246  			var (
   247  				chainHandler Resolver
   248  				curI         int
   249  			)
   250  			chainHandler = func(currentCtx context.Context) (interface{}, error) {
   251  				if curI == lastI {
   252  					return next(currentCtx)
   253  				}
   254  				curI++
   255  				res, err := handleFunc[curI](currentCtx, chainHandler)
   256  				curI--
   257  				return res, err
   258  
   259  			}
   260  			return handleFunc[0](ctx, chainHandler)
   261  		}
   262  	}
   263  
   264  	if n == 1 {
   265  		return handleFunc[0]
   266  	}
   267  
   268  	return func(ctx context.Context, next Resolver) (interface{}, error) {
   269  		return next(ctx)
   270  	}
   271  }