github.com/zikaeroh/gqlgen@v0.7.2/handler/graphql.go (about) 1 package handler 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "time" 11 12 "github.com/99designs/gqlgen/complexity" 13 "github.com/99designs/gqlgen/graphql" 14 "github.com/gorilla/websocket" 15 "github.com/hashicorp/golang-lru" 16 "github.com/vektah/gqlparser/ast" 17 "github.com/vektah/gqlparser/gqlerror" 18 "github.com/vektah/gqlparser/parser" 19 "github.com/vektah/gqlparser/validator" 20 ) 21 22 type params struct { 23 Query string `json:"query"` 24 OperationName string `json:"operationName"` 25 Variables map[string]interface{} `json:"variables"` 26 } 27 28 type Config struct { 29 cacheSize int 30 upgrader websocket.Upgrader 31 recover graphql.RecoverFunc 32 errorPresenter graphql.ErrorPresenterFunc 33 resolverHook graphql.FieldMiddleware 34 requestHook graphql.RequestMiddleware 35 tracer graphql.Tracer 36 complexityLimit int 37 disableIntrospection bool 38 connectionKeepAlivePingInterval time.Duration 39 } 40 41 func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { 42 reqCtx := graphql.NewRequestContext(doc, query, variables) 43 reqCtx.DisableIntrospection = c.disableIntrospection 44 45 if hook := c.recover; hook != nil { 46 reqCtx.Recover = hook 47 } 48 49 if hook := c.errorPresenter; hook != nil { 50 reqCtx.ErrorPresenter = hook 51 } 52 53 if hook := c.resolverHook; hook != nil { 54 reqCtx.ResolverMiddleware = hook 55 } 56 57 if hook := c.requestHook; hook != nil { 58 reqCtx.RequestMiddleware = hook 59 } 60 61 if hook := c.tracer; hook != nil { 62 reqCtx.Tracer = hook 63 } else { 64 reqCtx.Tracer = &graphql.NopTracer{} 65 } 66 67 if c.complexityLimit > 0 { 68 reqCtx.ComplexityLimit = c.complexityLimit 69 operationComplexity := complexity.Calculate(es, op, variables) 70 reqCtx.OperationComplexity = operationComplexity 71 } 72 
73 return reqCtx 74 } 75 76 type Option func(cfg *Config) 77 78 func WebsocketUpgrader(upgrader websocket.Upgrader) Option { 79 return func(cfg *Config) { 80 cfg.upgrader = upgrader 81 } 82 } 83 84 func RecoverFunc(recover graphql.RecoverFunc) Option { 85 return func(cfg *Config) { 86 cfg.recover = recover 87 } 88 } 89 90 // ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides 91 // a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default 92 // implementation in graphql.DefaultErrorPresenter for an example. 93 func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { 94 return func(cfg *Config) { 95 cfg.errorPresenter = f 96 } 97 } 98 99 // IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont 100 // want clients introspecting the full schema. 101 func IntrospectionEnabled(enabled bool) Option { 102 return func(cfg *Config) { 103 cfg.disableIntrospection = !enabled 104 } 105 } 106 107 // ComplexityLimit sets a maximum query complexity that is allowed to be executed. 108 // If a query is submitted that exceeds the limit, a 422 status code will be returned. 109 func ComplexityLimit(limit int) Option { 110 return func(cfg *Config) { 111 cfg.complexityLimit = limit 112 } 113 } 114 115 // ResolverMiddleware allows you to define a function that will be called around every resolver, 116 // useful for logging. 
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.resolverHook
		if outer == nil {
			// First registration: the middleware is the whole chain.
			cfg.resolverHook = middleware
			return
		}

		// Nest the new middleware inside the existing chain, so hooks
		// registered earlier run first (outermost).
		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
			inner := func(innerCtx context.Context) (interface{}, error) {
				return middleware(innerCtx, next)
			}
			return outer(ctx, inner)
		}
	}
}

// RequestMiddleware registers a function that is called around the root request, after the
// query has been parsed. This is useful for logging. When it is registered several times,
// earlier middlewares wrap later ones.
func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.requestHook
		if outer == nil {
			// First registration: the middleware is the whole chain.
			cfg.requestHook = middleware
			return
		}

		// Nest the new middleware inside the existing chain, so hooks
		// registered earlier run first (outermost).
		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			inner := func(innerCtx context.Context) []byte {
				return middleware(innerCtx, next)
			}
			return outer(ctx, inner)
		}
	}
}

// Tracer allows you to add a request/resolver tracer that will be called around the root request,
// calling resolver.
This is useful for tracing 153 func Tracer(tracer graphql.Tracer) Option { 154 return func(cfg *Config) { 155 if cfg.tracer == nil { 156 cfg.tracer = tracer 157 158 } else { 159 lastResolve := cfg.tracer 160 cfg.tracer = &tracerWrapper{ 161 tracer1: lastResolve, 162 tracer2: tracer, 163 } 164 } 165 166 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { 167 ctx = tracer.StartOperationExecution(ctx) 168 resp := next(ctx) 169 tracer.EndOperationExecution(ctx) 170 171 return resp 172 }) 173 opt(cfg) 174 } 175 } 176 177 type tracerWrapper struct { 178 tracer1 graphql.Tracer 179 tracer2 graphql.Tracer 180 } 181 182 func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context { 183 ctx = tw.tracer1.StartOperationParsing(ctx) 184 ctx = tw.tracer2.StartOperationParsing(ctx) 185 return ctx 186 } 187 188 func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) { 189 tw.tracer2.EndOperationParsing(ctx) 190 tw.tracer1.EndOperationParsing(ctx) 191 } 192 193 func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context { 194 ctx = tw.tracer1.StartOperationValidation(ctx) 195 ctx = tw.tracer2.StartOperationValidation(ctx) 196 return ctx 197 } 198 199 func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) { 200 tw.tracer2.EndOperationValidation(ctx) 201 tw.tracer1.EndOperationValidation(ctx) 202 } 203 204 func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context { 205 ctx = tw.tracer1.StartOperationExecution(ctx) 206 ctx = tw.tracer2.StartOperationExecution(ctx) 207 return ctx 208 } 209 210 func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { 211 ctx = tw.tracer1.StartFieldExecution(ctx, field) 212 ctx = tw.tracer2.StartFieldExecution(ctx, field) 213 return ctx 214 } 215 216 func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) 
context.Context { 217 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc) 218 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc) 219 return ctx 220 } 221 222 func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context { 223 ctx = tw.tracer1.StartFieldChildExecution(ctx) 224 ctx = tw.tracer2.StartFieldChildExecution(ctx) 225 return ctx 226 } 227 228 func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) { 229 tw.tracer2.EndFieldExecution(ctx) 230 tw.tracer1.EndFieldExecution(ctx) 231 } 232 233 func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { 234 tw.tracer2.EndOperationExecution(ctx) 235 tw.tracer1.EndOperationExecution(ctx) 236 } 237 238 // CacheSize sets the maximum size of the query cache. 239 // If size is less than or equal to 0, the cache is disabled. 240 func CacheSize(size int) Option { 241 return func(cfg *Config) { 242 cfg.cacheSize = size 243 } 244 } 245 246 const DefaultCacheSize = 1000 247 248 // WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior. 249 // By default, keep-alive is disabled. 250 func WebsocketKeepAliveDuration(duration time.Duration) Option { 251 return func(cfg *Config) { 252 cfg.connectionKeepAlivePingInterval = duration 253 } 254 } 255 256 func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { 257 cfg := &Config{ 258 cacheSize: DefaultCacheSize, 259 upgrader: websocket.Upgrader{ 260 ReadBufferSize: 1024, 261 WriteBufferSize: 1024, 262 }, 263 } 264 265 for _, option := range options { 266 option(cfg) 267 } 268 269 var cache *lru.Cache 270 if cfg.cacheSize > 0 { 271 var err error 272 cache, err = lru.New(DefaultCacheSize) 273 if err != nil { 274 // An error is only returned for non-positive cache size 275 // and we already checked for that. 
276 panic("unexpected error creating cache: " + err.Error()) 277 } 278 } 279 if cfg.tracer == nil { 280 cfg.tracer = &graphql.NopTracer{} 281 } 282 283 handler := &graphqlHandler{ 284 cfg: cfg, 285 cache: cache, 286 exec: exec, 287 } 288 289 return handler.ServeHTTP 290 } 291 292 var _ http.Handler = (*graphqlHandler)(nil) 293 294 type graphqlHandler struct { 295 cfg *Config 296 cache *lru.Cache 297 exec graphql.ExecutableSchema 298 } 299 300 func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 301 if r.Method == http.MethodOptions { 302 w.Header().Set("Allow", "OPTIONS, GET, POST") 303 w.WriteHeader(http.StatusOK) 304 return 305 } 306 307 if strings.Contains(r.Header.Get("Upgrade"), "websocket") { 308 connectWs(gh.exec, w, r, gh.cfg) 309 return 310 } 311 312 var reqParams params 313 switch r.Method { 314 case http.MethodGet: 315 reqParams.Query = r.URL.Query().Get("query") 316 reqParams.OperationName = r.URL.Query().Get("operationName") 317 318 if variables := r.URL.Query().Get("variables"); variables != "" { 319 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { 320 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") 321 return 322 } 323 } 324 case http.MethodPost: 325 if err := jsonDecode(r.Body, &reqParams); err != nil { 326 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) 327 return 328 } 329 default: 330 w.WriteHeader(http.StatusMethodNotAllowed) 331 return 332 } 333 w.Header().Set("Content-Type", "application/json") 334 335 ctx := r.Context() 336 337 var doc *ast.QueryDocument 338 var cacheHit bool 339 if gh.cache != nil { 340 val, ok := gh.cache.Get(reqParams.Query) 341 if ok { 342 doc = val.(*ast.QueryDocument) 343 cacheHit = true 344 } 345 } 346 347 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{ 348 Query: reqParams.Query, 349 CachedDoc: doc, 350 }) 351 if gqlErr != nil { 352 sendError(w, http.StatusUnprocessableEntity, 
gqlErr) 353 return 354 } 355 356 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{ 357 Doc: doc, 358 OperationName: reqParams.OperationName, 359 CacheHit: cacheHit, 360 R: r, 361 Variables: reqParams.Variables, 362 }) 363 if len(listErr) != 0 { 364 sendError(w, http.StatusUnprocessableEntity, listErr...) 365 return 366 } 367 368 if gh.cache != nil && !cacheHit { 369 gh.cache.Add(reqParams.Query, doc) 370 } 371 372 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars) 373 ctx = graphql.WithRequestContext(ctx, reqCtx) 374 375 defer func() { 376 if err := recover(); err != nil { 377 userErr := reqCtx.Recover(ctx, err) 378 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) 379 } 380 }() 381 382 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { 383 sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) 384 return 385 } 386 387 switch op.Operation { 388 case ast.Query: 389 b, err := json.Marshal(gh.exec.Query(ctx, op)) 390 if err != nil { 391 panic(err) 392 } 393 w.Write(b) 394 case ast.Mutation: 395 b, err := json.Marshal(gh.exec.Mutation(ctx, op)) 396 if err != nil { 397 panic(err) 398 } 399 w.Write(b) 400 default: 401 sendErrorf(w, http.StatusBadRequest, "unsupported operation type") 402 } 403 } 404 405 type parseOperationArgs struct { 406 Query string 407 CachedDoc *ast.QueryDocument 408 } 409 410 func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) { 411 ctx = gh.cfg.tracer.StartOperationParsing(ctx) 412 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }() 413 414 if args.CachedDoc != nil { 415 return ctx, args.CachedDoc, nil 416 } 417 418 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query}) 419 if gqlErr != nil { 420 return ctx, nil, gqlErr 421 } 422 423 return 
ctx, doc, nil 424 } 425 426 type validateOperationArgs struct { 427 Doc *ast.QueryDocument 428 OperationName string 429 CacheHit bool 430 R *http.Request 431 Variables map[string]interface{} 432 } 433 434 func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) { 435 ctx = gh.cfg.tracer.StartOperationValidation(ctx) 436 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }() 437 438 if !args.CacheHit { 439 listErr := validator.Validate(gh.exec.Schema(), args.Doc) 440 if len(listErr) != 0 { 441 return ctx, nil, nil, listErr 442 } 443 } 444 445 op := args.Doc.Operations.ForName(args.OperationName) 446 if op == nil { 447 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)} 448 } 449 450 if op.Operation != ast.Query && args.R.Method == http.MethodGet { 451 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} 452 } 453 454 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables) 455 if err != nil { 456 return ctx, nil, nil, gqlerror.List{err} 457 } 458 459 return ctx, op, vars, nil 460 } 461 462 func jsonDecode(r io.Reader, val interface{}) error { 463 dec := json.NewDecoder(r) 464 dec.UseNumber() 465 return dec.Decode(val) 466 } 467 468 func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { 469 w.WriteHeader(code) 470 b, err := json.Marshal(&graphql.Response{Errors: errors}) 471 if err != nil { 472 panic(err) 473 } 474 w.Write(b) 475 } 476 477 func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { 478 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 479 }