github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/handler/graphql.go (about)

     1  package handler
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/HaswinVidanage/gqlgen/complexity"
    13  	"github.com/HaswinVidanage/gqlgen/graphql"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/hashicorp/golang-lru"
    16  	"github.com/vektah/gqlparser/ast"
    17  	"github.com/vektah/gqlparser/gqlerror"
    18  	"github.com/vektah/gqlparser/parser"
    19  	"github.com/vektah/gqlparser/validator"
    20  )
    21  
    22  type params struct {
    23  	Query         string                 `json:"query"`
    24  	OperationName string                 `json:"operationName"`
    25  	Variables     map[string]interface{} `json:"variables"`
    26  }
    27  
    28  type Config struct {
    29  	cacheSize                       int
    30  	upgrader                        websocket.Upgrader
    31  	recover                         graphql.RecoverFunc
    32  	errorPresenter                  graphql.ErrorPresenterFunc
    33  	resolverHook                    graphql.FieldMiddleware
    34  	requestHook                     graphql.RequestMiddleware
    35  	tracer                          graphql.Tracer
    36  	complexityLimit                 int
    37  	disableIntrospection            bool
    38  	connectionKeepAlivePingInterval time.Duration
    39  }
    40  
    41  func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
    42  	reqCtx := graphql.NewRequestContext(doc, query, variables)
    43  	reqCtx.DisableIntrospection = c.disableIntrospection
    44  
    45  	if hook := c.recover; hook != nil {
    46  		reqCtx.Recover = hook
    47  	}
    48  
    49  	if hook := c.errorPresenter; hook != nil {
    50  		reqCtx.ErrorPresenter = hook
    51  	}
    52  
    53  	if hook := c.resolverHook; hook != nil {
    54  		reqCtx.ResolverMiddleware = hook
    55  	}
    56  
    57  	if hook := c.requestHook; hook != nil {
    58  		reqCtx.RequestMiddleware = hook
    59  	}
    60  
    61  	if hook := c.tracer; hook != nil {
    62  		reqCtx.Tracer = hook
    63  	}
    64  
    65  	if c.complexityLimit > 0 {
    66  		reqCtx.ComplexityLimit = c.complexityLimit
    67  		operationComplexity := complexity.Calculate(es, op, variables)
    68  		reqCtx.OperationComplexity = operationComplexity
    69  	}
    70  
    71  	return reqCtx
    72  }
    73  
    74  type Option func(cfg *Config)
    75  
    76  func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
    77  	return func(cfg *Config) {
    78  		cfg.upgrader = upgrader
    79  	}
    80  }
    81  
    82  func RecoverFunc(recover graphql.RecoverFunc) Option {
    83  	return func(cfg *Config) {
    84  		cfg.recover = recover
    85  	}
    86  }
    87  
    88  // ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
    89  // a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
    90  // implementation in graphql.DefaultErrorPresenter for an example.
    91  func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
    92  	return func(cfg *Config) {
    93  		cfg.errorPresenter = f
    94  	}
    95  }
    96  
    97  // IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
    98  // want clients introspecting the full schema.
    99  func IntrospectionEnabled(enabled bool) Option {
   100  	return func(cfg *Config) {
   101  		cfg.disableIntrospection = !enabled
   102  	}
   103  }
   104  
   105  // ComplexityLimit sets a maximum query complexity that is allowed to be executed.
   106  // If a query is submitted that exceeds the limit, a 422 status code will be returned.
   107  func ComplexityLimit(limit int) Option {
   108  	return func(cfg *Config) {
   109  		cfg.complexityLimit = limit
   110  	}
   111  }
   112  
   113  // ResolverMiddleware allows you to define a function that will be called around every resolver,
   114  // useful for logging.
   115  func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
   116  	return func(cfg *Config) {
   117  		if cfg.resolverHook == nil {
   118  			cfg.resolverHook = middleware
   119  			return
   120  		}
   121  
   122  		lastResolve := cfg.resolverHook
   123  		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   124  			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
   125  				return middleware(ctx, next)
   126  			})
   127  		}
   128  	}
   129  }
   130  
   131  // RequestMiddleware allows you to define a function that will be called around the root request,
   132  // after the query has been parsed. This is useful for logging
   133  func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
   134  	return func(cfg *Config) {
   135  		if cfg.requestHook == nil {
   136  			cfg.requestHook = middleware
   137  			return
   138  		}
   139  
   140  		lastResolve := cfg.requestHook
   141  		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   142  			return lastResolve(ctx, func(ctx context.Context) []byte {
   143  				return middleware(ctx, next)
   144  			})
   145  		}
   146  	}
   147  }
   148  
   149  // Tracer allows you to add a request/resolver tracer that will be called around the root request,
   150  // calling resolver. This is useful for tracing
   151  func Tracer(tracer graphql.Tracer) Option {
   152  	return func(cfg *Config) {
   153  		if cfg.tracer == nil {
   154  			cfg.tracer = tracer
   155  
   156  		} else {
   157  			lastResolve := cfg.tracer
   158  			cfg.tracer = &tracerWrapper{
   159  				tracer1: lastResolve,
   160  				tracer2: tracer,
   161  			}
   162  		}
   163  
   164  		opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
   165  			ctx = tracer.StartOperationExecution(ctx)
   166  			resp := next(ctx)
   167  			tracer.EndOperationExecution(ctx)
   168  
   169  			return resp
   170  		})
   171  		opt(cfg)
   172  	}
   173  }
   174  
   175  type tracerWrapper struct {
   176  	tracer1 graphql.Tracer
   177  	tracer2 graphql.Tracer
   178  }
   179  
   180  func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
   181  	ctx = tw.tracer1.StartOperationParsing(ctx)
   182  	ctx = tw.tracer2.StartOperationParsing(ctx)
   183  	return ctx
   184  }
   185  
   186  func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
   187  	tw.tracer2.EndOperationParsing(ctx)
   188  	tw.tracer1.EndOperationParsing(ctx)
   189  }
   190  
   191  func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
   192  	ctx = tw.tracer1.StartOperationValidation(ctx)
   193  	ctx = tw.tracer2.StartOperationValidation(ctx)
   194  	return ctx
   195  }
   196  
   197  func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
   198  	tw.tracer2.EndOperationValidation(ctx)
   199  	tw.tracer1.EndOperationValidation(ctx)
   200  }
   201  
   202  func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
   203  	ctx = tw.tracer1.StartOperationExecution(ctx)
   204  	ctx = tw.tracer2.StartOperationExecution(ctx)
   205  	return ctx
   206  }
   207  
   208  func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
   209  	ctx = tw.tracer1.StartFieldExecution(ctx, field)
   210  	ctx = tw.tracer2.StartFieldExecution(ctx, field)
   211  	return ctx
   212  }
   213  
   214  func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
   215  	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
   216  	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
   217  	return ctx
   218  }
   219  
   220  func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
   221  	ctx = tw.tracer1.StartFieldChildExecution(ctx)
   222  	ctx = tw.tracer2.StartFieldChildExecution(ctx)
   223  	return ctx
   224  }
   225  
   226  func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
   227  	tw.tracer2.EndFieldExecution(ctx)
   228  	tw.tracer1.EndFieldExecution(ctx)
   229  }
   230  
   231  func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
   232  	tw.tracer2.EndOperationExecution(ctx)
   233  	tw.tracer1.EndOperationExecution(ctx)
   234  }
   235  
   236  // CacheSize sets the maximum size of the query cache.
   237  // If size is less than or equal to 0, the cache is disabled.
   238  func CacheSize(size int) Option {
   239  	return func(cfg *Config) {
   240  		cfg.cacheSize = size
   241  	}
   242  }
   243  
   244  // WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior.
   245  // By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval
   246  // duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive
   247  // altogether.
   248  func WebsocketKeepAliveDuration(duration time.Duration) Option {
   249  	return func(cfg *Config) {
   250  		cfg.connectionKeepAlivePingInterval = duration
   251  	}
   252  }
   253  
   254  const DefaultCacheSize = 1000
   255  const DefaultConnectionKeepAlivePingInterval = 25 * time.Second
   256  
   257  func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
   258  	cfg := &Config{
   259  		cacheSize:                       DefaultCacheSize,
   260  		connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval,
   261  		upgrader: websocket.Upgrader{
   262  			ReadBufferSize:  1024,
   263  			WriteBufferSize: 1024,
   264  		},
   265  	}
   266  
   267  	for _, option := range options {
   268  		option(cfg)
   269  	}
   270  
   271  	var cache *lru.Cache
   272  	if cfg.cacheSize > 0 {
   273  		var err error
   274  		cache, err = lru.New(cfg.cacheSize)
   275  		if err != nil {
   276  			// An error is only returned for non-positive cache size
   277  			// and we already checked for that.
   278  			panic("unexpected error creating cache: " + err.Error())
   279  		}
   280  	}
   281  	if cfg.tracer == nil {
   282  		cfg.tracer = &graphql.NopTracer{}
   283  	}
   284  
   285  	handler := &graphqlHandler{
   286  		cfg:   cfg,
   287  		cache: cache,
   288  		exec:  exec,
   289  	}
   290  
   291  	return handler.ServeHTTP
   292  }
   293  
   294  var _ http.Handler = (*graphqlHandler)(nil)
   295  
   296  type graphqlHandler struct {
   297  	cfg   *Config
   298  	cache *lru.Cache
   299  	exec  graphql.ExecutableSchema
   300  }
   301  
   302  func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   303  	if r.Method == http.MethodOptions {
   304  		w.Header().Set("Allow", "OPTIONS, GET, POST")
   305  		w.WriteHeader(http.StatusOK)
   306  		return
   307  	}
   308  
   309  	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
   310  		connectWs(gh.exec, w, r, gh.cfg, gh.cache)
   311  		return
   312  	}
   313  
   314  	ctx := r.Context()
   315  
   316  	w.Header().Set("Content-Type", "application/json")
   317  	var reqParams params
   318  	switch r.Method {
   319  	case http.MethodGet:
   320  		reqParams.Query = r.URL.Query().Get("query")
   321  		reqParams.OperationName = r.URL.Query().Get("operationName")
   322  
   323  		if variables := r.URL.Query().Get("variables"); variables != "" {
   324  			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
   325  				gh.sendErrorf(ctx, w, http.StatusBadRequest, "variables could not be decoded")
   326  				return
   327  			}
   328  		}
   329  	case http.MethodPost:
   330  		if err := jsonDecode(r.Body, &reqParams); err != nil {
   331  			gh.sendErrorf(ctx, w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
   332  			return
   333  		}
   334  	default:
   335  		w.WriteHeader(http.StatusMethodNotAllowed)
   336  		return
   337  	}
   338  
   339  	var doc *ast.QueryDocument
   340  	var cacheHit bool
   341  	if gh.cache != nil {
   342  		val, ok := gh.cache.Get(reqParams.Query)
   343  		if ok {
   344  			doc = val.(*ast.QueryDocument)
   345  			cacheHit = true
   346  		}
   347  	}
   348  
   349  	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
   350  		Query:     reqParams.Query,
   351  		CachedDoc: doc,
   352  	})
   353  	if gqlErr != nil {
   354  		gh.sendError(ctx, w, http.StatusUnprocessableEntity, gqlErr)
   355  		return
   356  	}
   357  
   358  	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
   359  		Doc:           doc,
   360  		OperationName: reqParams.OperationName,
   361  		CacheHit:      cacheHit,
   362  		R:             r,
   363  		Variables:     reqParams.Variables,
   364  	})
   365  	if len(listErr) != 0 {
   366  		gh.sendError(ctx, w, http.StatusUnprocessableEntity, listErr...)
   367  		return
   368  	}
   369  
   370  	if gh.cache != nil && !cacheHit {
   371  		gh.cache.Add(reqParams.Query, doc)
   372  	}
   373  
   374  	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
   375  	ctx = graphql.WithRequestContext(ctx, reqCtx)
   376  
   377  	defer func() {
   378  		if err := recover(); err != nil {
   379  			userErr := reqCtx.Recover(ctx, err)
   380  			gh.sendErrorf(ctx, w, http.StatusUnprocessableEntity, userErr.Error())
   381  		}
   382  	}()
   383  
   384  	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
   385  		gh.sendErrorf(ctx, w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
   386  		return
   387  	}
   388  
   389  	switch op.Operation {
   390  	case ast.Query:
   391  		b, err := json.Marshal(gh.exec.Query(ctx, op))
   392  		if err != nil {
   393  			panic(err)
   394  		}
   395  		w.Write(b)
   396  	case ast.Mutation:
   397  		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
   398  		if err != nil {
   399  			panic(err)
   400  		}
   401  		w.Write(b)
   402  	default:
   403  		gh.sendErrorf(ctx, w, http.StatusBadRequest, "unsupported operation type")
   404  	}
   405  }
   406  
   407  type parseOperationArgs struct {
   408  	Query     string
   409  	CachedDoc *ast.QueryDocument
   410  }
   411  
   412  func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
   413  	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
   414  	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
   415  
   416  	if args.CachedDoc != nil {
   417  		return ctx, args.CachedDoc, nil
   418  	}
   419  
   420  	doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
   421  	if gqlErr != nil {
   422  		return ctx, nil, gqlErr
   423  	}
   424  
   425  	return ctx, doc, nil
   426  }
   427  
   428  type validateOperationArgs struct {
   429  	Doc           *ast.QueryDocument
   430  	OperationName string
   431  	CacheHit      bool
   432  	R             *http.Request
   433  	Variables     map[string]interface{}
   434  }
   435  
   436  func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
   437  	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
   438  	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
   439  
   440  	if !args.CacheHit {
   441  		listErr := validator.Validate(gh.exec.Schema(), args.Doc)
   442  		if len(listErr) != 0 {
   443  			return ctx, nil, nil, listErr
   444  		}
   445  	}
   446  
   447  	op := args.Doc.Operations.ForName(args.OperationName)
   448  	if op == nil {
   449  		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
   450  	}
   451  
   452  	if op.Operation != ast.Query && args.R.Method == http.MethodGet {
   453  		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
   454  	}
   455  
   456  	vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
   457  	if err != nil {
   458  		return ctx, nil, nil, gqlerror.List{err}
   459  	}
   460  
   461  	return ctx, op, vars, nil
   462  }
   463  
   464  func jsonDecode(r io.Reader, val interface{}) error {
   465  	dec := json.NewDecoder(r)
   466  	dec.UseNumber()
   467  	return dec.Decode(val)
   468  }
   469  
   470  func (gh *graphqlHandler) sendError(ctx context.Context, w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
   471  	if gh.cfg.errorPresenter != nil {
   472  		var newErrors []*gqlerror.Error
   473  		for _, err := range errors {
   474  			newErr := gh.cfg.errorPresenter(ctx, err)
   475  			newErrors = append(newErrors, newErr)
   476  		}
   477  		errors = newErrors
   478  	}
   479  	w.WriteHeader(code)
   480  	b, err := json.Marshal(&graphql.Response{Errors: errors})
   481  	if err != nil {
   482  		panic(err)
   483  	}
   484  	w.Write(b)
   485  }
   486  
   487  func (gh *graphqlHandler) sendErrorf(ctx context.Context, w http.ResponseWriter, code int, format string, args ...interface{}) {
   488  	gh.sendError(ctx, w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   489  }
   490  
   491  func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
   492  	w.WriteHeader(code)
   493  	b, err := json.Marshal(&graphql.Response{Errors: errors})
   494  	if err != nil {
   495  		panic(err)
   496  	}
   497  	w.Write(b)
   498  }
   499  
   500  func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
   501  	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   502  }