github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/graphql/executor/executor.go (about) 1 package executor 2 3 import ( 4 "context" 5 6 "github.com/spread-ai/gqlgen/graphql" 7 "github.com/spread-ai/gqlgen/graphql/errcode" 8 "github.com/vektah/gqlparser/v2/ast" 9 "github.com/vektah/gqlparser/v2/gqlerror" 10 "github.com/vektah/gqlparser/v2/parser" 11 "github.com/vektah/gqlparser/v2/validator" 12 ) 13 14 // Executor executes graphql queries against a schema. 15 type Executor struct { 16 es graphql.ExecutableSchema 17 extensions []graphql.HandlerExtension 18 ext extensions 19 20 errorPresenter graphql.ErrorPresenterFunc 21 recoverFunc graphql.RecoverFunc 22 queryCache graphql.Cache 23 } 24 25 var _ graphql.GraphExecutor = &Executor{} 26 27 // New creates a new Executor with the given schema, and a default error and 28 // recovery callbacks, and no query cache or extensions. 29 func New(es graphql.ExecutableSchema) *Executor { 30 e := &Executor{ 31 es: es, 32 errorPresenter: graphql.DefaultErrorPresenter, 33 recoverFunc: graphql.DefaultRecover, 34 queryCache: graphql.NoCache{}, 35 ext: processExtensions(nil), 36 } 37 return e 38 } 39 40 func (e *Executor) CreateOperationContext( 41 ctx context.Context, 42 params *graphql.RawParams, 43 ) (*graphql.OperationContext, gqlerror.List) { 44 rc := &graphql.OperationContext{ 45 DisableIntrospection: true, 46 RecoverFunc: e.recoverFunc, 47 ResolverMiddleware: e.ext.fieldMiddleware, 48 RootResolverMiddleware: e.ext.rootFieldMiddleware, 49 Stats: graphql.Stats{ 50 Read: params.ReadTime, 51 OperationStart: graphql.GetStartTime(ctx), 52 }, 53 } 54 ctx = graphql.WithOperationContext(ctx, rc) 55 56 for _, p := range e.ext.operationParameterMutators { 57 if err := p.MutateOperationParameters(ctx, params); err != nil { 58 return rc, gqlerror.List{err} 59 } 60 } 61 62 rc.RawQuery = params.Query 63 rc.OperationName = params.OperationName 64 rc.Headers = params.Headers 65 66 var listErr gqlerror.List 67 rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query) 68 if len(listErr) != 0 { 69 return rc, listErr 70 } 71 72 rc.Operation = rc.Doc.Operations.ForName(params.OperationName) 73 if rc.Operation == nil { 74 err := gqlerror.Errorf("operation %s not found", params.OperationName) 75 errcode.Set(err, errcode.ValidationFailed) 76 return rc, gqlerror.List{err} 77 } 78 79 var err error 80 rc.Variables, err = validator.VariableValues(e.es.Schema(), rc.Operation, params.Variables) 81 82 if err != nil { 83 gqlErr, ok := err.(*gqlerror.Error) 84 if ok { 85 errcode.Set(gqlErr, errcode.ValidationFailed) 86 return rc, gqlerror.List{gqlErr} 87 } 88 } 89 rc.Stats.Validation.End = graphql.Now() 90 91 for _, p := range e.ext.operationContextMutators { 92 if err := p.MutateOperationContext(ctx, rc); err != nil { 93 return rc, gqlerror.List{err} 94 } 95 } 96 97 return rc, nil 98 } 99 100 func (e *Executor) DispatchOperation( 101 ctx context.Context, 102 rc *graphql.OperationContext, 103 ) (graphql.ResponseHandler, context.Context) { 104 ctx = graphql.WithOperationContext(ctx, rc) 105 106 var innerCtx context.Context 107 res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler { 108 innerCtx = ctx 109 110 tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc) 111 responses := e.es.Exec(tmpResponseContext) 112 if errs := graphql.GetErrors(tmpResponseContext); errs != nil { 113 return graphql.OneShot(&graphql.Response{Errors: errs}) 114 } 115 116 return func(ctx context.Context) *graphql.Response { 117 ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc) 118 resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { 119 resp := responses(ctx) 120 if resp == nil { 121 return nil 122 } 123 resp.Errors = append(resp.Errors, graphql.GetErrors(ctx)...) 124 resp.Extensions = graphql.GetExtensions(ctx) 125 return resp 126 }) 127 if resp == nil { 128 return nil 129 } 130 131 return resp 132 } 133 }) 134 135 return res, innerCtx 136 } 137 138 func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graphql.Response { 139 ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc) 140 for _, gErr := range list { 141 graphql.AddError(ctx, gErr) 142 } 143 144 resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { 145 resp := &graphql.Response{ 146 Errors: graphql.GetErrors(ctx), 147 } 148 resp.Extensions = graphql.GetExtensions(ctx) 149 return resp 150 }) 151 152 return resp 153 } 154 155 func (e *Executor) PresentRecoveredError(ctx context.Context, err interface{}) error { 156 return e.errorPresenter(ctx, e.recoverFunc(ctx, err)) 157 } 158 159 func (e *Executor) SetQueryCache(cache graphql.Cache) { 160 e.queryCache = cache 161 } 162 163 func (e *Executor) SetErrorPresenter(f graphql.ErrorPresenterFunc) { 164 e.errorPresenter = f 165 } 166 167 func (e *Executor) SetRecoverFunc(f graphql.RecoverFunc) { 168 e.recoverFunc = f 169 } 170 171 // parseQuery decodes the incoming query and validates it, pulling from cache if present. 172 // 173 // NOTE: This should NOT look at variables, they will change per request. It should only parse and 174 // validate 175 // the raw query string. 176 func (e *Executor) parseQuery( 177 ctx context.Context, 178 stats *graphql.Stats, 179 query string, 180 ) (*ast.QueryDocument, gqlerror.List) { 181 stats.Parsing.Start = graphql.Now() 182 183 if doc, ok := e.queryCache.Get(ctx, query); ok { 184 now := graphql.Now() 185 186 stats.Parsing.End = now 187 stats.Validation.Start = now 188 return doc.(*ast.QueryDocument), nil 189 } 190 191 doc, err := parser.ParseQuery(&ast.Source{Input: query}) 192 if err != nil { 193 gqlErr, ok := err.(*gqlerror.Error) 194 if ok { 195 errcode.Set(gqlErr, errcode.ParseFailed) 196 return nil, gqlerror.List{gqlErr} 197 } 198 } 199 stats.Parsing.End = graphql.Now() 200 201 stats.Validation.Start = graphql.Now() 202 203 if len(doc.Operations) == 0 { 204 err = gqlerror.Errorf("no operation provided") 205 gqlErr, _ := err.(*gqlerror.Error) 206 errcode.Set(err, errcode.ValidationFailed) 207 return nil, gqlerror.List{gqlErr} 208 } 209 210 listErr := validator.Validate(e.es.Schema(), doc) 211 if len(listErr) != 0 { 212 for _, e := range listErr { 213 errcode.Set(e, errcode.ValidationFailed) 214 } 215 return nil, listErr 216 } 217 218 e.queryCache.Add(ctx, query, doc) 219 220 return doc, nil 221 }