github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/graphql/context.go (about) 1 package graphql 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 8 "github.com/vektah/gqlparser/ast" 9 "github.com/vektah/gqlparser/gqlerror" 10 ) 11 12 type Resolver func(ctx context.Context) (res interface{}, err error) 13 type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) 14 type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte 15 16 type RequestContext struct { 17 RawQuery string 18 Variables map[string]interface{} 19 Doc *ast.QueryDocument 20 21 ComplexityLimit int 22 OperationComplexity int 23 DisableIntrospection bool 24 25 // ErrorPresenter will be used to generate the error 26 // message from errors given to Error(). 27 ErrorPresenter ErrorPresenterFunc 28 Recover RecoverFunc 29 ResolverMiddleware FieldMiddleware 30 DirectiveMiddleware FieldMiddleware 31 RequestMiddleware RequestMiddleware 32 Tracer Tracer 33 34 errorsMu sync.Mutex 35 Errors gqlerror.List 36 extensionsMu sync.Mutex 37 Extensions map[string]interface{} 38 } 39 40 func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) { 41 return next(ctx) 42 } 43 44 func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) { 45 return next(ctx) 46 } 47 48 func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte { 49 return next(ctx) 50 } 51 52 func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext { 53 return &RequestContext{ 54 Doc: doc, 55 RawQuery: query, 56 Variables: variables, 57 ResolverMiddleware: DefaultResolverMiddleware, 58 DirectiveMiddleware: DefaultDirectiveMiddleware, 59 RequestMiddleware: DefaultRequestMiddleware, 60 Recover: DefaultRecover, 61 ErrorPresenter: DefaultErrorPresenter, 62 Tracer: &NopTracer{}, 63 } 64 } 65 66 type key string 67 68 const ( 69 request key = "request_context" 70 resolver key = "resolver_context" 71 ) 72 73 func GetRequestContext(ctx context.Context) *RequestContext { 74 if val, ok := ctx.Value(request).(*RequestContext); ok { 75 return val 76 } 77 return nil 78 } 79 80 func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context { 81 return context.WithValue(ctx, request, rc) 82 } 83 84 type ResolverContext struct { 85 Parent *ResolverContext 86 // The name of the type this field belongs to 87 Object string 88 // These are the args after processing, they can be mutated in middleware to change what the resolver will get. 89 Args map[string]interface{} 90 // The raw field 91 Field CollectedField 92 // The index of array in path. 93 Index *int 94 // The result object of resolver 95 Result interface{} 96 } 97 98 func (r *ResolverContext) Path() []interface{} { 99 var path []interface{} 100 for it := r; it != nil; it = it.Parent { 101 if it.Index != nil { 102 path = append(path, *it.Index) 103 } else if it.Field.Field != nil { 104 path = append(path, it.Field.Alias) 105 } 106 } 107 108 // because we are walking up the chain, all the elements are backwards, do an inplace flip. 109 for i := len(path)/2 - 1; i >= 0; i-- { 110 opp := len(path) - 1 - i 111 path[i], path[opp] = path[opp], path[i] 112 } 113 114 return path 115 } 116 117 func GetResolverContext(ctx context.Context) *ResolverContext { 118 if val, ok := ctx.Value(resolver).(*ResolverContext); ok { 119 return val 120 } 121 return nil 122 } 123 124 func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context { 125 rc.Parent = GetResolverContext(ctx) 126 return context.WithValue(ctx, resolver, rc) 127 } 128 129 // This is just a convenient wrapper method for CollectFields 130 func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField { 131 resctx := GetResolverContext(ctx) 132 return CollectFields(ctx, resctx.Field.Selections, satisfies) 133 } 134 135 // CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context. 136 // The slice will contain the unique set of all field names requested regardless of fragment type conditions. 137 func CollectAllFields(ctx context.Context) []string { 138 resctx := GetResolverContext(ctx) 139 collected := CollectFields(ctx, resctx.Field.Selections, nil) 140 uniq := make([]string, 0, len(collected)) 141 Next: 142 for _, f := range collected { 143 for _, name := range uniq { 144 if name == f.Name { 145 continue Next 146 } 147 } 148 uniq = append(uniq, f.Name) 149 } 150 return uniq 151 } 152 153 // Errorf sends an error string to the client, passing it through the formatter. 154 func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) { 155 c.errorsMu.Lock() 156 defer c.errorsMu.Unlock() 157 158 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...))) 159 } 160 161 // Error sends an error to the client, passing it through the formatter. 162 func (c *RequestContext) Error(ctx context.Context, err error) { 163 c.errorsMu.Lock() 164 defer c.errorsMu.Unlock() 165 166 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err)) 167 } 168 169 // HasError returns true if the current field has already errored 170 func (c *RequestContext) HasError(rctx *ResolverContext) bool { 171 c.errorsMu.Lock() 172 defer c.errorsMu.Unlock() 173 path := rctx.Path() 174 175 for _, err := range c.Errors { 176 if equalPath(err.Path, path) { 177 return true 178 } 179 } 180 return false 181 } 182 183 // GetErrors returns a list of errors that occurred in the current field 184 func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List { 185 c.errorsMu.Lock() 186 defer c.errorsMu.Unlock() 187 path := rctx.Path() 188 189 var errs gqlerror.List 190 for _, err := range c.Errors { 191 if equalPath(err.Path, path) { 192 errs = append(errs, err) 193 } 194 } 195 return errs 196 } 197 198 func equalPath(a []interface{}, b []interface{}) bool { 199 if len(a) != len(b) { 200 return false 201 } 202 203 for i := 0; i < len(a); i++ { 204 if a[i] != b[i] { 205 return false 206 } 207 } 208 209 return true 210 } 211 212 // AddError is a convenience method for adding an error to the current response 213 func AddError(ctx context.Context, err error) { 214 GetRequestContext(ctx).Error(ctx, err) 215 } 216 217 // AddErrorf is a convenience method for adding an error to the current response 218 func AddErrorf(ctx context.Context, format string, args ...interface{}) { 219 GetRequestContext(ctx).Errorf(ctx, format, args...) 220 } 221 222 // RegisterExtension registers an extension, returns error if extension has already been registered 223 func (c *RequestContext) RegisterExtension(key string, value interface{}) error { 224 c.extensionsMu.Lock() 225 defer c.extensionsMu.Unlock() 226 227 if c.Extensions == nil { 228 c.Extensions = make(map[string]interface{}) 229 } 230 231 if _, ok := c.Extensions[key]; ok { 232 return fmt.Errorf("extension already registered for key %s", key) 233 } 234 235 c.Extensions[key] = value 236 return nil 237 } 238 239 // ChainFieldMiddleware add chain by FieldMiddleware 240 func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware { 241 n := len(handleFunc) 242 243 if n > 1 { 244 lastI := n - 1 245 return func(ctx context.Context, next Resolver) (interface{}, error) { 246 var ( 247 chainHandler Resolver 248 curI int 249 ) 250 chainHandler = func(currentCtx context.Context) (interface{}, error) { 251 if curI == lastI { 252 return next(currentCtx) 253 } 254 curI++ 255 res, err := handleFunc[curI](currentCtx, chainHandler) 256 curI-- 257 return res, err 258 259 } 260 return handleFunc[0](ctx, chainHandler) 261 } 262 } 263 264 if n == 1 { 265 return handleFunc[0] 266 } 267 268 return func(ctx context.Context, next Resolver) (interface{}, error) { 269 return next(ctx) 270 } 271 }