github.com/deliveroo/gqlgen@v0.7.2/handler/graphql.go (about)

     1  package handler
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/99designs/gqlgen/complexity"
    13  	"github.com/99designs/gqlgen/graphql"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/hashicorp/golang-lru"
    16  	"github.com/vektah/gqlparser/ast"
    17  	"github.com/vektah/gqlparser/gqlerror"
    18  	"github.com/vektah/gqlparser/parser"
    19  	"github.com/vektah/gqlparser/validator"
    20  )
    21  
        // params holds the fields of a GraphQL-over-HTTP request. ServeHTTP fills
        // it from the URL query string for GET requests and from the JSON body
        // for POST requests.
    22  type params struct {
    23  	Query         string                 `json:"query"`
    24  	OperationName string                 `json:"operationName"`
    25  	Variables     map[string]interface{} `json:"variables"`
    26  }
    27  
        // Config holds the handler settings assembled from Option values by
        // GraphQL. ServeHTTP also hands the same Config to connectWs for
        // websocket upgrade requests.
    28  type Config struct {
        	// cacheSize is the capacity requested for the parsed-query cache; a
        	// value <= 0 disables the cache.
    29  	cacheSize                       int
    30  	upgrader                        websocket.Upgrader
    31  	recover                         graphql.RecoverFunc
    32  	errorPresenter                  graphql.ErrorPresenterFunc
    33  	resolverHook                    graphql.FieldMiddleware
    34  	requestHook                     graphql.RequestMiddleware
        	// tracer stays nil until an option sets it; GraphQL substitutes a
        	// graphql.NopTracer if it is still nil after all options are applied.
    35  	tracer                          graphql.Tracer
        	// complexityLimit <= 0 disables the operation complexity check.
    36  	complexityLimit                 int
    37  	disableIntrospection            bool
        	// NOTE(review): presumably read by the websocket code (connectWs is not
        	// in this file); zero unless WebsocketKeepAliveDuration is used.
    38  	connectionKeepAlivePingInterval time.Duration
    39  }
    40  
    41  func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
    42  	reqCtx := graphql.NewRequestContext(doc, query, variables)
    43  	reqCtx.DisableIntrospection = c.disableIntrospection
    44  
    45  	if hook := c.recover; hook != nil {
    46  		reqCtx.Recover = hook
    47  	}
    48  
    49  	if hook := c.errorPresenter; hook != nil {
    50  		reqCtx.ErrorPresenter = hook
    51  	}
    52  
    53  	if hook := c.resolverHook; hook != nil {
    54  		reqCtx.ResolverMiddleware = hook
    55  	}
    56  
    57  	if hook := c.requestHook; hook != nil {
    58  		reqCtx.RequestMiddleware = hook
    59  	}
    60  
    61  	if hook := c.tracer; hook != nil {
    62  		reqCtx.Tracer = hook
    63  	} else {
    64  		reqCtx.Tracer = &graphql.NopTracer{}
    65  	}
    66  
    67  	if c.complexityLimit > 0 {
    68  		reqCtx.ComplexityLimit = c.complexityLimit
    69  		operationComplexity := complexity.Calculate(es, op, variables)
    70  		reqCtx.OperationComplexity = operationComplexity
    71  	}
    72  
    73  	return reqCtx
    74  }
    75  
        // Option mutates a Config. GraphQL applies the options it is given in
        // order, on top of its defaults.
    76  type Option func(cfg *Config)
    77  
    78  func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
    79  	return func(cfg *Config) {
    80  		cfg.upgrader = upgrader
    81  	}
    82  }
    83  
    84  func RecoverFunc(recover graphql.RecoverFunc) Option {
    85  	return func(cfg *Config) {
    86  		cfg.recover = recover
    87  	}
    88  }
    89  
    90  // ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
    91  // a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
    92  // implementation in graphql.DefaultErrorPresenter for an example.
    93  func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
    94  	return func(cfg *Config) {
    95  		cfg.errorPresenter = f
    96  	}
    97  }
    98  
    99  // IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
   100  // want clients introspecting the full schema.
   101  func IntrospectionEnabled(enabled bool) Option {
   102  	return func(cfg *Config) {
   103  		cfg.disableIntrospection = !enabled
   104  	}
   105  }
   106  
   107  // ComplexityLimit sets a maximum query complexity that is allowed to be executed.
   108  // If a query is submitted that exceeds the limit, a 422 status code will be returned.
   109  func ComplexityLimit(limit int) Option {
   110  	return func(cfg *Config) {
   111  		cfg.complexityLimit = limit
   112  	}
   113  }
   114  
   115  // ResolverMiddleware allows you to define a function that will be called around every resolver,
   116  // useful for logging.
   117  func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
   118  	return func(cfg *Config) {
   119  		if cfg.resolverHook == nil {
   120  			cfg.resolverHook = middleware
   121  			return
   122  		}
   123  
   124  		lastResolve := cfg.resolverHook
   125  		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   126  			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
   127  				return middleware(ctx, next)
   128  			})
   129  		}
   130  	}
   131  }
   132  
   133  // RequestMiddleware allows you to define a function that will be called around the root request,
   134  // after the query has been parsed. This is useful for logging
   135  func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
   136  	return func(cfg *Config) {
   137  		if cfg.requestHook == nil {
   138  			cfg.requestHook = middleware
   139  			return
   140  		}
   141  
   142  		lastResolve := cfg.requestHook
   143  		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   144  			return lastResolve(ctx, func(ctx context.Context) []byte {
   145  				return middleware(ctx, next)
   146  			})
   147  		}
   148  	}
   149  }
   150  
   151  // Tracer allows you to add a request/resolver tracer that will be called around the root request,
   152  // calling resolver. This is useful for tracing
   153  func Tracer(tracer graphql.Tracer) Option {
   154  	return func(cfg *Config) {
   155  		if cfg.tracer == nil {
   156  			cfg.tracer = tracer
   157  
   158  		} else {
   159  			lastResolve := cfg.tracer
   160  			cfg.tracer = &tracerWrapper{
   161  				tracer1: lastResolve,
   162  				tracer2: tracer,
   163  			}
   164  		}
   165  
   166  		opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   167  			ctx = tracer.StartOperationExecution(ctx)
   168  			resp := next(ctx)
   169  			tracer.EndOperationExecution(ctx)
   170  
   171  			return resp
   172  		})
   173  		opt(cfg)
   174  	}
   175  }
   176  
   177  type tracerWrapper struct {
   178  	tracer1 graphql.Tracer
   179  	tracer2 graphql.Tracer
   180  }
   181  
   182  func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
   183  	ctx = tw.tracer1.StartOperationParsing(ctx)
   184  	ctx = tw.tracer2.StartOperationParsing(ctx)
   185  	return ctx
   186  }
   187  
   188  func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
   189  	tw.tracer2.EndOperationParsing(ctx)
   190  	tw.tracer1.EndOperationParsing(ctx)
   191  }
   192  
   193  func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
   194  	ctx = tw.tracer1.StartOperationValidation(ctx)
   195  	ctx = tw.tracer2.StartOperationValidation(ctx)
   196  	return ctx
   197  }
   198  
   199  func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
   200  	tw.tracer2.EndOperationValidation(ctx)
   201  	tw.tracer1.EndOperationValidation(ctx)
   202  }
   203  
   204  func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
   205  	ctx = tw.tracer1.StartOperationExecution(ctx)
   206  	ctx = tw.tracer2.StartOperationExecution(ctx)
   207  	return ctx
   208  }
   209  
   210  func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
   211  	ctx = tw.tracer1.StartFieldExecution(ctx, field)
   212  	ctx = tw.tracer2.StartFieldExecution(ctx, field)
   213  	return ctx
   214  }
   215  
   216  func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
   217  	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
   218  	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
   219  	return ctx
   220  }
   221  
   222  func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
   223  	ctx = tw.tracer1.StartFieldChildExecution(ctx)
   224  	ctx = tw.tracer2.StartFieldChildExecution(ctx)
   225  	return ctx
   226  }
   227  
   228  func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
   229  	tw.tracer2.EndFieldExecution(ctx)
   230  	tw.tracer1.EndFieldExecution(ctx)
   231  }
   232  
   233  func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
   234  	tw.tracer2.EndOperationExecution(ctx)
   235  	tw.tracer1.EndOperationExecution(ctx)
   236  }
   237  
   238  // CacheSize sets the maximum size of the query cache.
   239  // If size is less than or equal to 0, the cache is disabled.
   240  func CacheSize(size int) Option {
   241  	return func(cfg *Config) {
   242  		cfg.cacheSize = size
   243  	}
   244  }
   245  
        // DefaultCacheSize is the parsed-query cache size GraphQL starts from
        // before any CacheSize option is applied.
   246  const DefaultCacheSize = 1000
   247  
   248  // WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior.
   249  // By default, keep-alive is disabled.
   250  func WebsocketKeepAliveDuration(duration time.Duration) Option {
   251  	return func(cfg *Config) {
   252  		cfg.connectionKeepAlivePingInterval = duration
   253  	}
   254  }
   255  
   256  func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
   257  	cfg := &Config{
   258  		cacheSize: DefaultCacheSize,
   259  		upgrader: websocket.Upgrader{
   260  			ReadBufferSize:  1024,
   261  			WriteBufferSize: 1024,
   262  		},
   263  	}
   264  
   265  	for _, option := range options {
   266  		option(cfg)
   267  	}
   268  
   269  	var cache *lru.Cache
   270  	if cfg.cacheSize > 0 {
   271  		var err error
   272  		cache, err = lru.New(DefaultCacheSize)
   273  		if err != nil {
   274  			// An error is only returned for non-positive cache size
   275  			// and we already checked for that.
   276  			panic("unexpected error creating cache: " + err.Error())
   277  		}
   278  	}
   279  	if cfg.tracer == nil {
   280  		cfg.tracer = &graphql.NopTracer{}
   281  	}
   282  
   283  	handler := &graphqlHandler{
   284  		cfg:   cfg,
   285  		cache: cache,
   286  		exec:  exec,
   287  	}
   288  
   289  	return handler.ServeHTTP
   290  }
   291  
        // Compile-time check that *graphqlHandler implements http.Handler.
   292  var _ http.Handler = (*graphqlHandler)(nil)
   293  
        // graphqlHandler serves GraphQL requests for a single executable schema.
        // GraphQL returns its ServeHTTP method as the http.HandlerFunc.
   294  type graphqlHandler struct {
        	// cfg is the configuration assembled by GraphQL from its options.
   295  	cfg   *Config
        	// cache maps raw query strings to documents that already passed
        	// validation; nil when caching is disabled.
   296  	cache *lru.Cache
   297  	exec  graphql.ExecutableSchema
   298  }
   299  
   300  func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   301  	if r.Method == http.MethodOptions {
   302  		w.Header().Set("Allow", "OPTIONS, GET, POST")
   303  		w.WriteHeader(http.StatusOK)
   304  		return
   305  	}
   306  
   307  	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
   308  		connectWs(gh.exec, w, r, gh.cfg)
   309  		return
   310  	}
   311  
   312  	var reqParams params
   313  	switch r.Method {
   314  	case http.MethodGet:
   315  		reqParams.Query = r.URL.Query().Get("query")
   316  		reqParams.OperationName = r.URL.Query().Get("operationName")
   317  
   318  		if variables := r.URL.Query().Get("variables"); variables != "" {
   319  			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
   320  				sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
   321  				return
   322  			}
   323  		}
   324  	case http.MethodPost:
   325  		if err := jsonDecode(r.Body, &reqParams); err != nil {
   326  			sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
   327  			return
   328  		}
   329  	default:
   330  		w.WriteHeader(http.StatusMethodNotAllowed)
   331  		return
   332  	}
   333  	w.Header().Set("Content-Type", "application/json")
   334  
   335  	ctx := r.Context()
   336  
   337  	var doc *ast.QueryDocument
   338  	var cacheHit bool
   339  	if gh.cache != nil {
   340  		val, ok := gh.cache.Get(reqParams.Query)
   341  		if ok {
   342  			doc = val.(*ast.QueryDocument)
   343  			cacheHit = true
   344  		}
   345  	}
   346  
   347  	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
   348  		Query:     reqParams.Query,
   349  		CachedDoc: doc,
   350  	})
   351  	if gqlErr != nil {
   352  		sendError(w, http.StatusUnprocessableEntity, gqlErr)
   353  		return
   354  	}
   355  
   356  	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
   357  		Doc:           doc,
   358  		OperationName: reqParams.OperationName,
   359  		CacheHit:      cacheHit,
   360  		R:             r,
   361  		Variables:     reqParams.Variables,
   362  	})
   363  	if len(listErr) != 0 {
   364  		sendError(w, http.StatusUnprocessableEntity, listErr...)
   365  		return
   366  	}
   367  
   368  	if gh.cache != nil && !cacheHit {
   369  		gh.cache.Add(reqParams.Query, doc)
   370  	}
   371  
   372  	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
   373  	ctx = graphql.WithRequestContext(ctx, reqCtx)
   374  
   375  	defer func() {
   376  		if err := recover(); err != nil {
   377  			userErr := reqCtx.Recover(ctx, err)
   378  			sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
   379  		}
   380  	}()
   381  
   382  	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
   383  		sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
   384  		return
   385  	}
   386  
   387  	switch op.Operation {
   388  	case ast.Query:
   389  		b, err := json.Marshal(gh.exec.Query(ctx, op))
   390  		if err != nil {
   391  			panic(err)
   392  		}
   393  		w.Write(b)
   394  	case ast.Mutation:
   395  		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
   396  		if err != nil {
   397  			panic(err)
   398  		}
   399  		w.Write(b)
   400  	default:
   401  		sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
   402  	}
   403  }
   404  
        // parseOperationArgs carries the inputs of graphqlHandler.parseOperation.
   405  type parseOperationArgs struct {
   406  	Query     string
        	// CachedDoc, when non-nil, is returned as-is instead of parsing Query.
   407  	CachedDoc *ast.QueryDocument
   408  }
   409  
   410  func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
   411  	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
   412  	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
   413  
   414  	if args.CachedDoc != nil {
   415  		return ctx, args.CachedDoc, nil
   416  	}
   417  
   418  	doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
   419  	if gqlErr != nil {
   420  		return ctx, nil, gqlErr
   421  	}
   422  
   423  	return ctx, doc, nil
   424  }
   425  
        // validateOperationArgs carries the inputs of graphqlHandler.validateOperation.
   426  type validateOperationArgs struct {
   427  	Doc           *ast.QueryDocument
        	// OperationName selects which operation in Doc to run.
   428  	OperationName string
        	// CacheHit skips schema validation for a document taken from the cache.
   429  	CacheHit      bool
        	// R is the incoming request; its method is used to reject non-query
        	// operations sent with GET.
   430  	R             *http.Request
   431  	Variables     map[string]interface{}
   432  }
   433  
   434  func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
   435  	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
   436  	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
   437  
   438  	if !args.CacheHit {
   439  		listErr := validator.Validate(gh.exec.Schema(), args.Doc)
   440  		if len(listErr) != 0 {
   441  			return ctx, nil, nil, listErr
   442  		}
   443  	}
   444  
   445  	op := args.Doc.Operations.ForName(args.OperationName)
   446  	if op == nil {
   447  		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
   448  	}
   449  
   450  	if op.Operation != ast.Query && args.R.Method == http.MethodGet {
   451  		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
   452  	}
   453  
   454  	vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
   455  	if err != nil {
   456  		return ctx, nil, nil, gqlerror.List{err}
   457  	}
   458  
   459  	return ctx, op, vars, nil
   460  }
   461  
   462  func jsonDecode(r io.Reader, val interface{}) error {
   463  	dec := json.NewDecoder(r)
   464  	dec.UseNumber()
   465  	return dec.Decode(val)
   466  }
   467  
   468  func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
   469  	w.WriteHeader(code)
   470  	b, err := json.Marshal(&graphql.Response{Errors: errors})
   471  	if err != nil {
   472  		panic(err)
   473  	}
   474  	w.Write(b)
   475  }
   476  
   477  func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
   478  	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   479  }