github.com/wilhelmeek/gqlgen@v0.7.2/graphql/context.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/vektah/gqlparser/ast"
     9  	"github.com/vektah/gqlparser/gqlerror"
    10  )
    11  
    12  type Resolver func(ctx context.Context) (res interface{}, err error)
    13  type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
    14  type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
    15  
// RequestContext holds state associated with a single GraphQL request: the
// query document and variables, configurable hooks (error presentation,
// panic recovery, middleware, tracing), and the errors and extensions
// accumulated while the request is processed.
//
// Build it with NewRequestContext, which fills every hook with a package
// default. It embeds sync.Mutex values, so it must not be copied after first
// use; always pass it around as *RequestContext.
type RequestContext struct {
	// RawQuery is the query string passed to NewRequestContext.
	RawQuery  string
	// Variables are the request variables passed to NewRequestContext
	// (presumably keyed by variable name — confirm with callers).
	Variables map[string]interface{}
	// Doc is the parsed query document passed to NewRequestContext.
	Doc       *ast.QueryDocument

	// NOTE(review): these three fields are not read in this file. Presumably
	// ComplexityLimit is the maximum permitted complexity, OperationComplexity
	// the computed complexity of the current operation, and
	// DisableIntrospection switches off introspection — confirm against users.
	ComplexityLimit      int
	OperationComplexity  int
	DisableIntrospection bool

	// ErrorPresenter converts every error handed to Error or Errorf (and so
	// to AddError/AddErrorf) into the entry that is appended to Errors.
	// It is called while errorsMu is held.
	ErrorPresenter      ErrorPresenterFunc
	// Recover defaults to DefaultRecover; it is not invoked in this file.
	Recover             RecoverFunc
	// ResolverMiddleware defaults to DefaultResolverMiddleware (pass-through).
	ResolverMiddleware  FieldMiddleware
	// DirectiveMiddleware defaults to DefaultDirectiveMiddleware (pass-through).
	DirectiveMiddleware FieldMiddleware
	// RequestMiddleware defaults to DefaultRequestMiddleware (pass-through).
	RequestMiddleware   RequestMiddleware
	// Tracer defaults to a *NopTracer.
	Tracer              Tracer

	// errorsMu guards Errors for Error, Errorf, HasError and GetErrors.
	// NOTE(review): Errors is exported, so direct reads/writes bypass this
	// lock; that is only safe once no other goroutine can record errors.
	errorsMu     sync.Mutex
	Errors       gqlerror.List
	// extensionsMu guards Extensions for RegisterExtension. Extensions stays
	// nil until the first successful RegisterExtension call allocates it.
	extensionsMu sync.Mutex
	Extensions   map[string]interface{}
}
    39  
    40  func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
    41  	return next(ctx)
    42  }
    43  
    44  func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
    45  	return next(ctx)
    46  }
    47  
    48  func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
    49  	return next(ctx)
    50  }
    51  
    52  func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
    53  	return &RequestContext{
    54  		Doc:                 doc,
    55  		RawQuery:            query,
    56  		Variables:           variables,
    57  		ResolverMiddleware:  DefaultResolverMiddleware,
    58  		DirectiveMiddleware: DefaultDirectiveMiddleware,
    59  		RequestMiddleware:   DefaultRequestMiddleware,
    60  		Recover:             DefaultRecover,
    61  		ErrorPresenter:      DefaultErrorPresenter,
    62  		Tracer:              &NopTracer{},
    63  	}
    64  }
    65  
    66  type key string
    67  
    68  const (
    69  	request  key = "request_context"
    70  	resolver key = "resolver_context"
    71  )
    72  
    73  func GetRequestContext(ctx context.Context) *RequestContext {
    74  	val := ctx.Value(request)
    75  	if val == nil {
    76  		return nil
    77  	}
    78  
    79  	return val.(*RequestContext)
    80  }
    81  
    82  func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
    83  	return context.WithValue(ctx, request, rc)
    84  }
    85  
// ResolverContext describes one position in query execution, either a field
// or (when Index is set) an element of a list. It links to the enclosing
// position through Parent, forming a chain that Path walks. Instances are
// attached to a context.Context with WithResolverContext and retrieved with
// GetResolverContext.
type ResolverContext struct {
	// Parent is the enclosing ResolverContext, or nil at the top of the
	// chain. WithResolverContext overwrites it with the ResolverContext
	// already present in the context being extended.
	Parent *ResolverContext
	// Object is the name of the type this field belongs to.
	Object string
	// Args are the field arguments after processing; middleware may mutate
	// them to change what the resolver receives.
	Args map[string]interface{}
	// Field is the collected field; Path uses its Alias as the path element
	// when Index is nil and Field.Field is non-nil.
	Field CollectedField
	// Index is the position within a list when this context represents a
	// list element; when non-nil, Path uses it instead of the field alias.
	Index *int
	// Result is the resolver's result object.
	// NOTE(review): not populated in this file — confirm where it is set.
	Result interface{}
}
    99  
   100  func (r *ResolverContext) Path() []interface{} {
   101  	var path []interface{}
   102  	for it := r; it != nil; it = it.Parent {
   103  		if it.Index != nil {
   104  			path = append(path, *it.Index)
   105  		} else if it.Field.Field != nil {
   106  			path = append(path, it.Field.Alias)
   107  		}
   108  	}
   109  
   110  	// because we are walking up the chain, all the elements are backwards, do an inplace flip.
   111  	for i := len(path)/2 - 1; i >= 0; i-- {
   112  		opp := len(path) - 1 - i
   113  		path[i], path[opp] = path[opp], path[i]
   114  	}
   115  
   116  	return path
   117  }
   118  
   119  func GetResolverContext(ctx context.Context) *ResolverContext {
   120  	val, _ := ctx.Value(resolver).(*ResolverContext)
   121  	return val
   122  }
   123  
   124  func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
   125  	rc.Parent = GetResolverContext(ctx)
   126  	return context.WithValue(ctx, resolver, rc)
   127  }
   128  
   129  // This is just a convenient wrapper method for CollectFields
   130  func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
   131  	resctx := GetResolverContext(ctx)
   132  	return CollectFields(ctx, resctx.Field.Selections, satisfies)
   133  }
   134  
   135  // Errorf sends an error string to the client, passing it through the formatter.
   136  func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
   137  	c.errorsMu.Lock()
   138  	defer c.errorsMu.Unlock()
   139  
   140  	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
   141  }
   142  
   143  // Error sends an error to the client, passing it through the formatter.
   144  func (c *RequestContext) Error(ctx context.Context, err error) {
   145  	c.errorsMu.Lock()
   146  	defer c.errorsMu.Unlock()
   147  
   148  	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
   149  }
   150  
   151  // HasError returns true if the current field has already errored
   152  func (c *RequestContext) HasError(rctx *ResolverContext) bool {
   153  	c.errorsMu.Lock()
   154  	defer c.errorsMu.Unlock()
   155  	path := rctx.Path()
   156  
   157  	for _, err := range c.Errors {
   158  		if equalPath(err.Path, path) {
   159  			return true
   160  		}
   161  	}
   162  	return false
   163  }
   164  
   165  // GetErrors returns a list of errors that occurred in the current field
   166  func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
   167  	c.errorsMu.Lock()
   168  	defer c.errorsMu.Unlock()
   169  	path := rctx.Path()
   170  
   171  	var errs gqlerror.List
   172  	for _, err := range c.Errors {
   173  		if equalPath(err.Path, path) {
   174  			errs = append(errs, err)
   175  		}
   176  	}
   177  	return errs
   178  }
   179  
   180  func equalPath(a []interface{}, b []interface{}) bool {
   181  	if len(a) != len(b) {
   182  		return false
   183  	}
   184  
   185  	for i := 0; i < len(a); i++ {
   186  		if a[i] != b[i] {
   187  			return false
   188  		}
   189  	}
   190  
   191  	return true
   192  }
   193  
   194  // AddError is a convenience method for adding an error to the current response
   195  func AddError(ctx context.Context, err error) {
   196  	GetRequestContext(ctx).Error(ctx, err)
   197  }
   198  
   199  // AddErrorf is a convenience method for adding an error to the current response
   200  func AddErrorf(ctx context.Context, format string, args ...interface{}) {
   201  	GetRequestContext(ctx).Errorf(ctx, format, args...)
   202  }
   203  
   204  // RegisterExtension registers an extension, returns error if extension has already been registered
   205  func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
   206  	c.extensionsMu.Lock()
   207  	defer c.extensionsMu.Unlock()
   208  
   209  	if c.Extensions == nil {
   210  		c.Extensions = make(map[string]interface{})
   211  	}
   212  
   213  	if _, ok := c.Extensions[key]; ok {
   214  		return fmt.Errorf("extension already registered for key %s", key)
   215  	}
   216  
   217  	c.Extensions[key] = value
   218  	return nil
   219  }