github.com/robertoortis/gqlgenm@v0.7.2/graphql/context.go (about) 1 package graphql 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 8 "github.com/vektah/gqlparser/ast" 9 "github.com/vektah/gqlparser/gqlerror" 10 ) 11 12 type Resolver func(ctx context.Context) (res interface{}, err error) 13 type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error) 14 type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte 15 16 type RequestContext struct { 17 RawQuery string 18 Variables map[string]interface{} 19 Doc *ast.QueryDocument 20 21 ComplexityLimit int 22 OperationComplexity int 23 DisableIntrospection bool 24 25 // ErrorPresenter will be used to generate the error 26 // message from errors given to Error(). 27 ErrorPresenter ErrorPresenterFunc 28 Recover RecoverFunc 29 ResolverMiddleware FieldMiddleware 30 DirectiveMiddleware FieldMiddleware 31 RequestMiddleware RequestMiddleware 32 Tracer Tracer 33 34 errorsMu sync.Mutex 35 Errors gqlerror.List 36 extensionsMu sync.Mutex 37 Extensions map[string]interface{} 38 } 39 40 func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) { 41 return next(ctx) 42 } 43 44 func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) { 45 return next(ctx) 46 } 47 48 func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte { 49 return next(ctx) 50 } 51 52 func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext { 53 return &RequestContext{ 54 Doc: doc, 55 RawQuery: query, 56 Variables: variables, 57 ResolverMiddleware: DefaultResolverMiddleware, 58 DirectiveMiddleware: DefaultDirectiveMiddleware, 59 RequestMiddleware: DefaultRequestMiddleware, 60 Recover: DefaultRecover, 61 ErrorPresenter: DefaultErrorPresenter, 62 Tracer: &NopTracer{}, 63 } 64 } 65 66 type key string 67 68 const ( 69 request key = "request_context" 70 resolver key = "resolver_context" 71 ) 72 73 func GetRequestContext(ctx context.Context) *RequestContext { 74 val := ctx.Value(request) 75 if val == nil { 76 return nil 77 } 78 79 return val.(*RequestContext) 80 } 81 82 func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context { 83 return context.WithValue(ctx, request, rc) 84 } 85 86 type ResolverContext struct { 87 Parent *ResolverContext 88 // The name of the type this field belongs to 89 Object string 90 // These are the args after processing, they can be mutated in middleware to change what the resolver will get. 91 Args map[string]interface{} 92 // The raw field 93 Field CollectedField 94 // The index of array in path. 95 Index *int 96 // The result object of resolver 97 Result interface{} 98 } 99 100 func (r *ResolverContext) Path() []interface{} { 101 var path []interface{} 102 for it := r; it != nil; it = it.Parent { 103 if it.Index != nil { 104 path = append(path, *it.Index) 105 } else if it.Field.Field != nil { 106 path = append(path, it.Field.Alias) 107 } 108 } 109 110 // because we are walking up the chain, all the elements are backwards, do an inplace flip. 111 for i := len(path)/2 - 1; i >= 0; i-- { 112 opp := len(path) - 1 - i 113 path[i], path[opp] = path[opp], path[i] 114 } 115 116 return path 117 } 118 119 func GetResolverContext(ctx context.Context) *ResolverContext { 120 val, _ := ctx.Value(resolver).(*ResolverContext) 121 return val 122 } 123 124 func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context { 125 rc.Parent = GetResolverContext(ctx) 126 return context.WithValue(ctx, resolver, rc) 127 } 128 129 // This is just a convenient wrapper method for CollectFields 130 func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField { 131 resctx := GetResolverContext(ctx) 132 return CollectFields(ctx, resctx.Field.Selections, satisfies) 133 } 134 135 // Errorf sends an error string to the client, passing it through the formatter. 136 func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) { 137 c.errorsMu.Lock() 138 defer c.errorsMu.Unlock() 139 140 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...))) 141 } 142 143 // Error sends an error to the client, passing it through the formatter. 144 func (c *RequestContext) Error(ctx context.Context, err error) { 145 c.errorsMu.Lock() 146 defer c.errorsMu.Unlock() 147 148 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err)) 149 } 150 151 // HasError returns true if the current field has already errored 152 func (c *RequestContext) HasError(rctx *ResolverContext) bool { 153 c.errorsMu.Lock() 154 defer c.errorsMu.Unlock() 155 path := rctx.Path() 156 157 for _, err := range c.Errors { 158 if equalPath(err.Path, path) { 159 return true 160 } 161 } 162 return false 163 } 164 165 // GetErrors returns a list of errors that occurred in the current field 166 func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List { 167 c.errorsMu.Lock() 168 defer c.errorsMu.Unlock() 169 path := rctx.Path() 170 171 var errs gqlerror.List 172 for _, err := range c.Errors { 173 if equalPath(err.Path, path) { 174 errs = append(errs, err) 175 } 176 } 177 return errs 178 } 179 180 func equalPath(a []interface{}, b []interface{}) bool { 181 if len(a) != len(b) { 182 return false 183 } 184 185 for i := 0; i < len(a); i++ { 186 if a[i] != b[i] { 187 return false 188 } 189 } 190 191 return true 192 } 193 194 // AddError is a convenience method for adding an error to the current response 195 func AddError(ctx context.Context, err error) { 196 GetRequestContext(ctx).Error(ctx, err) 197 } 198 199 // AddErrorf is a convenience method for adding an error to the current response 200 func AddErrorf(ctx context.Context, format string, args ...interface{}) { 201 GetRequestContext(ctx).Errorf(ctx, format, args...) 202 } 203 204 // RegisterExtension registers an extension, returns error if extension has already been registered 205 func (c *RequestContext) RegisterExtension(key string, value interface{}) error { 206 c.extensionsMu.Lock() 207 defer c.extensionsMu.Unlock() 208 209 if c.Extensions == nil { 210 c.Extensions = make(map[string]interface{}) 211 } 212 213 if _, ok := c.Extensions[key]; ok { 214 return fmt.Errorf("extension already registered for key %s", key) 215 } 216 217 c.Extensions[key] = value 218 return nil 219 }