github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/handler/graphql.go (about) 1 package handler 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "time" 11 12 "github.com/HaswinVidanage/gqlgen/complexity" 13 "github.com/HaswinVidanage/gqlgen/graphql" 14 "github.com/gorilla/websocket" 15 "github.com/hashicorp/golang-lru" 16 "github.com/vektah/gqlparser/ast" 17 "github.com/vektah/gqlparser/gqlerror" 18 "github.com/vektah/gqlparser/parser" 19 "github.com/vektah/gqlparser/validator" 20 ) 21 22 type params struct { 23 Query string `json:"query"` 24 OperationName string `json:"operationName"` 25 Variables map[string]interface{} `json:"variables"` 26 } 27 28 type Config struct { 29 cacheSize int 30 upgrader websocket.Upgrader 31 recover graphql.RecoverFunc 32 errorPresenter graphql.ErrorPresenterFunc 33 resolverHook graphql.FieldMiddleware 34 requestHook graphql.RequestMiddleware 35 tracer graphql.Tracer 36 complexityLimit int 37 disableIntrospection bool 38 connectionKeepAlivePingInterval time.Duration 39 } 40 41 func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { 42 reqCtx := graphql.NewRequestContext(doc, query, variables) 43 reqCtx.DisableIntrospection = c.disableIntrospection 44 45 if hook := c.recover; hook != nil { 46 reqCtx.Recover = hook 47 } 48 49 if hook := c.errorPresenter; hook != nil { 50 reqCtx.ErrorPresenter = hook 51 } 52 53 if hook := c.resolverHook; hook != nil { 54 reqCtx.ResolverMiddleware = hook 55 } 56 57 if hook := c.requestHook; hook != nil { 58 reqCtx.RequestMiddleware = hook 59 } 60 61 if hook := c.tracer; hook != nil { 62 reqCtx.Tracer = hook 63 } 64 65 if c.complexityLimit > 0 { 66 reqCtx.ComplexityLimit = c.complexityLimit 67 operationComplexity := complexity.Calculate(es, op, variables) 68 reqCtx.OperationComplexity = operationComplexity 69 } 70 71 return reqCtx 72 } 73 74 type Option func(cfg *Config) 75 76 func WebsocketUpgrader(upgrader websocket.Upgrader) Option { 77 return func(cfg *Config) { 78 cfg.upgrader = upgrader 79 } 80 } 81 82 func RecoverFunc(recover graphql.RecoverFunc) Option { 83 return func(cfg *Config) { 84 cfg.recover = recover 85 } 86 } 87 88 // ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides 89 // a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default 90 // implementation in graphql.DefaultErrorPresenter for an example. 91 func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { 92 return func(cfg *Config) { 93 cfg.errorPresenter = f 94 } 95 } 96 97 // IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont 98 // want clients introspecting the full schema. 99 func IntrospectionEnabled(enabled bool) Option { 100 return func(cfg *Config) { 101 cfg.disableIntrospection = !enabled 102 } 103 } 104 105 // ComplexityLimit sets a maximum query complexity that is allowed to be executed. 106 // If a query is submitted that exceeds the limit, a 422 status code will be returned. 107 func ComplexityLimit(limit int) Option { 108 return func(cfg *Config) { 109 cfg.complexityLimit = limit 110 } 111 } 112 113 // ResolverMiddleware allows you to define a function that will be called around every resolver, 114 // useful for logging. 115 func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { 116 return func(cfg *Config) { 117 if cfg.resolverHook == nil { 118 cfg.resolverHook = middleware 119 return 120 } 121 122 lastResolve := cfg.resolverHook 123 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 124 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) { 125 return middleware(ctx, next) 126 }) 127 } 128 } 129 } 130 131 // RequestMiddleware allows you to define a function that will be called around the root request, 132 // after the query has been parsed. This is useful for logging 133 func RequestMiddleware(middleware graphql.RequestMiddleware) Option { 134 return func(cfg *Config) { 135 if cfg.requestHook == nil { 136 cfg.requestHook = middleware 137 return 138 } 139 140 lastResolve := cfg.requestHook 141 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { 142 return lastResolve(ctx, func(ctx context.Context) []byte { 143 return middleware(ctx, next) 144 }) 145 } 146 } 147 } 148 149 // Tracer allows you to add a request/resolver tracer that will be called around the root request, 150 // calling resolver. This is useful for tracing 151 func Tracer(tracer graphql.Tracer) Option { 152 return func(cfg *Config) { 153 if cfg.tracer == nil { 154 cfg.tracer = tracer 155 156 } else { 157 lastResolve := cfg.tracer 158 cfg.tracer = &tracerWrapper{ 159 tracer1: lastResolve, 160 tracer2: tracer, 161 } 162 } 163 164 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { 165 ctx = tracer.StartOperationExecution(ctx) 166 resp := next(ctx) 167 tracer.EndOperationExecution(ctx) 168 169 return resp 170 }) 171 opt(cfg) 172 } 173 } 174 175 type tracerWrapper struct { 176 tracer1 graphql.Tracer 177 tracer2 graphql.Tracer 178 } 179 180 func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context { 181 ctx = tw.tracer1.StartOperationParsing(ctx) 182 ctx = tw.tracer2.StartOperationParsing(ctx) 183 return ctx 184 } 185 186 func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) { 187 tw.tracer2.EndOperationParsing(ctx) 188 tw.tracer1.EndOperationParsing(ctx) 189 } 190 191 func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context { 192 ctx = tw.tracer1.StartOperationValidation(ctx) 193 ctx = tw.tracer2.StartOperationValidation(ctx) 194 return ctx 195 } 196 197 func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) { 198 tw.tracer2.EndOperationValidation(ctx) 199 tw.tracer1.EndOperationValidation(ctx) 200 } 201 202 func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context { 203 ctx = tw.tracer1.StartOperationExecution(ctx) 204 ctx = tw.tracer2.StartOperationExecution(ctx) 205 return ctx 206 } 207 208 func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { 209 ctx = tw.tracer1.StartFieldExecution(ctx, field) 210 ctx = tw.tracer2.StartFieldExecution(ctx, field) 211 return ctx 212 } 213 214 func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { 215 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc) 216 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc) 217 return ctx 218 } 219 220 func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context { 221 ctx = tw.tracer1.StartFieldChildExecution(ctx) 222 ctx = tw.tracer2.StartFieldChildExecution(ctx) 223 return ctx 224 } 225 226 func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) { 227 tw.tracer2.EndFieldExecution(ctx) 228 tw.tracer1.EndFieldExecution(ctx) 229 } 230 231 func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { 232 tw.tracer2.EndOperationExecution(ctx) 233 tw.tracer1.EndOperationExecution(ctx) 234 } 235 236 // CacheSize sets the maximum size of the query cache. 237 // If size is less than or equal to 0, the cache is disabled. 238 func CacheSize(size int) Option { 239 return func(cfg *Config) { 240 cfg.cacheSize = size 241 } 242 } 243 244 // WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior. 245 // By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval 246 // duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive 247 // altogether. 248 func WebsocketKeepAliveDuration(duration time.Duration) Option { 249 return func(cfg *Config) { 250 cfg.connectionKeepAlivePingInterval = duration 251 } 252 } 253 254 const DefaultCacheSize = 1000 255 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second 256 257 func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { 258 cfg := &Config{ 259 cacheSize: DefaultCacheSize, 260 connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval, 261 upgrader: websocket.Upgrader{ 262 ReadBufferSize: 1024, 263 WriteBufferSize: 1024, 264 }, 265 } 266 267 for _, option := range options { 268 option(cfg) 269 } 270 271 var cache *lru.Cache 272 if cfg.cacheSize > 0 { 273 var err error 274 cache, err = lru.New(cfg.cacheSize) 275 if err != nil { 276 // An error is only returned for non-positive cache size 277 // and we already checked for that. 278 panic("unexpected error creating cache: " + err.Error()) 279 } 280 } 281 if cfg.tracer == nil { 282 cfg.tracer = &graphql.NopTracer{} 283 } 284 285 handler := &graphqlHandler{ 286 cfg: cfg, 287 cache: cache, 288 exec: exec, 289 } 290 291 return handler.ServeHTTP 292 } 293 294 var _ http.Handler = (*graphqlHandler)(nil) 295 296 type graphqlHandler struct { 297 cfg *Config 298 cache *lru.Cache 299 exec graphql.ExecutableSchema 300 } 301 302 func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 303 if r.Method == http.MethodOptions { 304 w.Header().Set("Allow", "OPTIONS, GET, POST") 305 w.WriteHeader(http.StatusOK) 306 return 307 } 308 309 if strings.Contains(r.Header.Get("Upgrade"), "websocket") { 310 connectWs(gh.exec, w, r, gh.cfg, gh.cache) 311 return 312 } 313 314 ctx := r.Context() 315 316 w.Header().Set("Content-Type", "application/json") 317 var reqParams params 318 switch r.Method { 319 case http.MethodGet: 320 reqParams.Query = r.URL.Query().Get("query") 321 reqParams.OperationName = r.URL.Query().Get("operationName") 322 323 if variables := r.URL.Query().Get("variables"); variables != "" { 324 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { 325 gh.sendErrorf(ctx, w, http.StatusBadRequest, "variables could not be decoded") 326 return 327 } 328 } 329 case http.MethodPost: 330 if err := jsonDecode(r.Body, &reqParams); err != nil { 331 gh.sendErrorf(ctx, w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) 332 return 333 } 334 default: 335 w.WriteHeader(http.StatusMethodNotAllowed) 336 return 337 } 338 339 var doc *ast.QueryDocument 340 var cacheHit bool 341 if gh.cache != nil { 342 val, ok := gh.cache.Get(reqParams.Query) 343 if ok { 344 doc = val.(*ast.QueryDocument) 345 cacheHit = true 346 } 347 } 348 349 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{ 350 Query: reqParams.Query, 351 CachedDoc: doc, 352 }) 353 if gqlErr != nil { 354 gh.sendError(ctx, w, http.StatusUnprocessableEntity, gqlErr) 355 return 356 } 357 358 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{ 359 Doc: doc, 360 OperationName: reqParams.OperationName, 361 CacheHit: cacheHit, 362 R: r, 363 Variables: reqParams.Variables, 364 }) 365 if len(listErr) != 0 { 366 gh.sendError(ctx, w, http.StatusUnprocessableEntity, listErr...) 367 return 368 } 369 370 if gh.cache != nil && !cacheHit { 371 gh.cache.Add(reqParams.Query, doc) 372 } 373 374 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars) 375 ctx = graphql.WithRequestContext(ctx, reqCtx) 376 377 defer func() { 378 if err := recover(); err != nil { 379 userErr := reqCtx.Recover(ctx, err) 380 gh.sendErrorf(ctx, w, http.StatusUnprocessableEntity, userErr.Error()) 381 } 382 }() 383 384 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { 385 gh.sendErrorf(ctx, w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) 386 return 387 } 388 389 switch op.Operation { 390 case ast.Query: 391 b, err := json.Marshal(gh.exec.Query(ctx, op)) 392 if err != nil { 393 panic(err) 394 } 395 w.Write(b) 396 case ast.Mutation: 397 b, err := json.Marshal(gh.exec.Mutation(ctx, op)) 398 if err != nil { 399 panic(err) 400 } 401 w.Write(b) 402 default: 403 gh.sendErrorf(ctx, w, http.StatusBadRequest, "unsupported operation type") 404 } 405 } 406 407 type parseOperationArgs struct { 408 Query string 409 CachedDoc *ast.QueryDocument 410 } 411 412 func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) { 413 ctx = gh.cfg.tracer.StartOperationParsing(ctx) 414 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }() 415 416 if args.CachedDoc != nil { 417 return ctx, args.CachedDoc, nil 418 } 419 420 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query}) 421 if gqlErr != nil { 422 return ctx, nil, gqlErr 423 } 424 425 return ctx, doc, nil 426 } 427 428 type validateOperationArgs struct { 429 Doc *ast.QueryDocument 430 OperationName string 431 CacheHit bool 432 R *http.Request 433 Variables map[string]interface{} 434 } 435 436 func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) { 437 ctx = gh.cfg.tracer.StartOperationValidation(ctx) 438 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }() 439 440 if !args.CacheHit { 441 listErr := validator.Validate(gh.exec.Schema(), args.Doc) 442 if len(listErr) != 0 { 443 return ctx, nil, nil, listErr 444 } 445 } 446 447 op := args.Doc.Operations.ForName(args.OperationName) 448 if op == nil { 449 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)} 450 } 451 452 if op.Operation != ast.Query && args.R.Method == http.MethodGet { 453 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} 454 } 455 456 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables) 457 if err != nil { 458 return ctx, nil, nil, gqlerror.List{err} 459 } 460 461 return ctx, op, vars, nil 462 } 463 464 func jsonDecode(r io.Reader, val interface{}) error { 465 dec := json.NewDecoder(r) 466 dec.UseNumber() 467 return dec.Decode(val) 468 } 469 470 func (gh *graphqlHandler) sendError(ctx context.Context, w http.ResponseWriter, code int, errors ...*gqlerror.Error) { 471 if gh.cfg.errorPresenter != nil { 472 var newErrors []*gqlerror.Error 473 for _, err := range errors { 474 newErr := gh.cfg.errorPresenter(ctx, err) 475 newErrors = append(newErrors, newErr) 476 } 477 errors = newErrors 478 } 479 w.WriteHeader(code) 480 b, err := json.Marshal(&graphql.Response{Errors: errors}) 481 if err != nil { 482 panic(err) 483 } 484 w.Write(b) 485 } 486 487 func (gh *graphqlHandler) sendErrorf(ctx context.Context, w http.ResponseWriter, code int, format string, args ...interface{}) { 488 gh.sendError(ctx, w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 489 } 490 491 func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { 492 w.WriteHeader(code) 493 b, err := json.Marshal(&graphql.Response{Errors: errors}) 494 if err != nil { 495 panic(err) 496 } 497 w.Write(b) 498 } 499 500 func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { 501 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 502 }