github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/server.go (about)

     1  package handler
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/99designs/gqlgen/graphql"
    11  	"github.com/99designs/gqlgen/graphql/executor"
    12  	"github.com/99designs/gqlgen/graphql/handler/extension"
    13  	"github.com/99designs/gqlgen/graphql/handler/lru"
    14  	"github.com/99designs/gqlgen/graphql/handler/transport"
    15  	"github.com/vektah/gqlparser/v2/gqlerror"
    16  )
    17  
    18  type (
    19  	Server struct {
    20  		transports []graphql.Transport
    21  		exec       *executor.Executor
    22  	}
    23  )
    24  
    25  func New(es graphql.ExecutableSchema) *Server {
    26  	return &Server{
    27  		exec: executor.New(es),
    28  	}
    29  }
    30  
    31  func NewDefaultServer(es graphql.ExecutableSchema) *Server {
    32  	srv := New(es)
    33  
    34  	srv.AddTransport(transport.Websocket{
    35  		KeepAlivePingInterval: 10 * time.Second,
    36  	})
    37  	srv.AddTransport(transport.Options{})
    38  	srv.AddTransport(transport.GET{})
    39  	srv.AddTransport(transport.POST{})
    40  	srv.AddTransport(transport.MultipartForm{})
    41  
    42  	srv.SetQueryCache(lru.New(1000))
    43  
    44  	srv.Use(extension.Introspection{})
    45  	srv.Use(extension.AutomaticPersistedQuery{
    46  		Cache: lru.New(100),
    47  	})
    48  
    49  	return srv
    50  }
    51  
    52  func (s *Server) AddTransport(transport graphql.Transport) {
    53  	s.transports = append(s.transports, transport)
    54  }
    55  
    56  func (s *Server) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
    57  	s.exec.SetErrorPresenter(f)
    58  }
    59  
    60  func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
    61  	s.exec.SetRecoverFunc(f)
    62  }
    63  
    64  func (s *Server) SetQueryCache(cache graphql.Cache) {
    65  	s.exec.SetQueryCache(cache)
    66  }
    67  
    68  func (s *Server) Use(extension graphql.HandlerExtension) {
    69  	s.exec.Use(extension)
    70  }
    71  
    72  // AroundFields is a convenience method for creating an extension that only implements field middleware
    73  func (s *Server) AroundFields(f graphql.FieldMiddleware) {
    74  	s.exec.AroundFields(f)
    75  }
    76  
    77  // AroundOperations is a convenience method for creating an extension that only implements operation middleware
    78  func (s *Server) AroundOperations(f graphql.OperationMiddleware) {
    79  	s.exec.AroundOperations(f)
    80  }
    81  
    82  // AroundResponses is a convenience method for creating an extension that only implements response middleware
    83  func (s *Server) AroundResponses(f graphql.ResponseMiddleware) {
    84  	s.exec.AroundResponses(f)
    85  }
    86  
    87  func (s *Server) getTransport(r *http.Request) graphql.Transport {
    88  	for _, t := range s.transports {
    89  		if t.Supports(r) {
    90  			return t
    91  		}
    92  	}
    93  	return nil
    94  }
    95  
    96  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    97  	defer func() {
    98  		if err := recover(); err != nil {
    99  			err := s.exec.PresentRecoveredError(r.Context(), err)
   100  			resp := &graphql.Response{Errors: []*gqlerror.Error{err}}
   101  			b, _ := json.Marshal(resp)
   102  			w.WriteHeader(http.StatusUnprocessableEntity)
   103  			w.Write(b)
   104  		}
   105  	}()
   106  
   107  	r = r.WithContext(graphql.StartOperationTrace(r.Context()))
   108  
   109  	transport := s.getTransport(r)
   110  	if transport == nil {
   111  		sendErrorf(w, http.StatusBadRequest, "transport not supported")
   112  		return
   113  	}
   114  
   115  	transport.Do(w, r, s.exec)
   116  }
   117  
   118  func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
   119  	w.WriteHeader(code)
   120  	b, err := json.Marshal(&graphql.Response{Errors: errors})
   121  	if err != nil {
   122  		panic(err)
   123  	}
   124  	w.Write(b)
   125  }
   126  
   127  func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
   128  	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   129  }
   130  
   131  type OperationFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
   132  
   133  func (r OperationFunc) ExtensionName() string {
   134  	return "InlineOperationFunc"
   135  }
   136  
   137  func (r OperationFunc) Validate(schema graphql.ExecutableSchema) error {
   138  	if r == nil {
   139  		return fmt.Errorf("OperationFunc can not be nil")
   140  	}
   141  	return nil
   142  }
   143  
   144  func (r OperationFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
   145  	return r(ctx, next)
   146  }
   147  
   148  type ResponseFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
   149  
   150  func (r ResponseFunc) ExtensionName() string {
   151  	return "InlineResponseFunc"
   152  }
   153  
   154  func (r ResponseFunc) Validate(schema graphql.ExecutableSchema) error {
   155  	if r == nil {
   156  		return fmt.Errorf("ResponseFunc can not be nil")
   157  	}
   158  	return nil
   159  }
   160  
   161  func (r ResponseFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   162  	return r(ctx, next)
   163  }
   164  
   165  type FieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
   166  
   167  func (f FieldFunc) ExtensionName() string {
   168  	return "InlineFieldFunc"
   169  }
   170  
   171  func (f FieldFunc) Validate(schema graphql.ExecutableSchema) error {
   172  	if f == nil {
   173  		return fmt.Errorf("FieldFunc can not be nil")
   174  	}
   175  	return nil
   176  }
   177  
   178  func (f FieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   179  	return f(ctx, next)
   180  }