github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/handler/server.go (about)

     1  package handler
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/mstephano/gqlgen-schemagen/graphql"
    11  	"github.com/mstephano/gqlgen-schemagen/graphql/executor"
    12  	"github.com/mstephano/gqlgen-schemagen/graphql/handler/extension"
    13  	"github.com/mstephano/gqlgen-schemagen/graphql/handler/lru"
    14  	"github.com/mstephano/gqlgen-schemagen/graphql/handler/transport"
    15  	"github.com/vektah/gqlparser/v2/gqlerror"
    16  )
    17  
    18  type (
    19  	Server struct {
    20  		transports []graphql.Transport
    21  		exec       *executor.Executor
    22  	}
    23  )
    24  
    25  func New(es graphql.ExecutableSchema) *Server {
    26  	return &Server{
    27  		exec: executor.New(es),
    28  	}
    29  }
    30  
    31  func NewDefaultServer(es graphql.ExecutableSchema) *Server {
    32  	srv := New(es)
    33  
    34  	srv.AddTransport(transport.Websocket{
    35  		KeepAlivePingInterval: 10 * time.Second,
    36  	})
    37  	srv.AddTransport(transport.Options{})
    38  	srv.AddTransport(transport.GET{})
    39  	srv.AddTransport(transport.POST{})
    40  	srv.AddTransport(transport.MultipartForm{})
    41  
    42  	srv.SetQueryCache(lru.New(1000))
    43  
    44  	srv.Use(extension.Introspection{})
    45  	srv.Use(extension.AutomaticPersistedQuery{
    46  		Cache: lru.New(100),
    47  	})
    48  
    49  	return srv
    50  }
    51  
    52  func (s *Server) AddTransport(transport graphql.Transport) {
    53  	s.transports = append(s.transports, transport)
    54  }
    55  
    56  func (s *Server) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
    57  	s.exec.SetErrorPresenter(f)
    58  }
    59  
    60  func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
    61  	s.exec.SetRecoverFunc(f)
    62  }
    63  
    64  func (s *Server) SetQueryCache(cache graphql.Cache) {
    65  	s.exec.SetQueryCache(cache)
    66  }
    67  
    68  func (s *Server) Use(extension graphql.HandlerExtension) {
    69  	s.exec.Use(extension)
    70  }
    71  
    72  // AroundFields is a convenience method for creating an extension that only implements field middleware
    73  func (s *Server) AroundFields(f graphql.FieldMiddleware) {
    74  	s.exec.AroundFields(f)
    75  }
    76  
    77  // AroundRootFields is a convenience method for creating an extension that only implements field middleware
    78  func (s *Server) AroundRootFields(f graphql.RootFieldMiddleware) {
    79  	s.exec.AroundRootFields(f)
    80  }
    81  
    82  // AroundOperations is a convenience method for creating an extension that only implements operation middleware
    83  func (s *Server) AroundOperations(f graphql.OperationMiddleware) {
    84  	s.exec.AroundOperations(f)
    85  }
    86  
    87  // AroundResponses is a convenience method for creating an extension that only implements response middleware
    88  func (s *Server) AroundResponses(f graphql.ResponseMiddleware) {
    89  	s.exec.AroundResponses(f)
    90  }
    91  
    92  func (s *Server) getTransport(r *http.Request) graphql.Transport {
    93  	for _, t := range s.transports {
    94  		if t.Supports(r) {
    95  			return t
    96  		}
    97  	}
    98  	return nil
    99  }
   100  
   101  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   102  	defer func() {
   103  		if err := recover(); err != nil {
   104  			err := s.exec.PresentRecoveredError(r.Context(), err)
   105  			gqlErr, _ := err.(*gqlerror.Error)
   106  			resp := &graphql.Response{Errors: []*gqlerror.Error{gqlErr}}
   107  			b, _ := json.Marshal(resp)
   108  			w.WriteHeader(http.StatusUnprocessableEntity)
   109  			w.Write(b)
   110  		}
   111  	}()
   112  
   113  	r = r.WithContext(graphql.StartOperationTrace(r.Context()))
   114  
   115  	transport := s.getTransport(r)
   116  	if transport == nil {
   117  		sendErrorf(w, http.StatusBadRequest, "transport not supported")
   118  		return
   119  	}
   120  
   121  	transport.Do(w, r, s.exec)
   122  }
   123  
   124  func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
   125  	w.WriteHeader(code)
   126  	b, err := json.Marshal(&graphql.Response{Errors: errors})
   127  	if err != nil {
   128  		panic(err)
   129  	}
   130  	w.Write(b)
   131  }
   132  
   133  func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
   134  	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   135  }
   136  
   137  type OperationFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
   138  
   139  func (r OperationFunc) ExtensionName() string {
   140  	return "InlineOperationFunc"
   141  }
   142  
   143  func (r OperationFunc) Validate(schema graphql.ExecutableSchema) error {
   144  	if r == nil {
   145  		return fmt.Errorf("OperationFunc can not be nil")
   146  	}
   147  	return nil
   148  }
   149  
   150  func (r OperationFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
   151  	return r(ctx, next)
   152  }
   153  
   154  type ResponseFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
   155  
   156  func (r ResponseFunc) ExtensionName() string {
   157  	return "InlineResponseFunc"
   158  }
   159  
   160  func (r ResponseFunc) Validate(schema graphql.ExecutableSchema) error {
   161  	if r == nil {
   162  		return fmt.Errorf("ResponseFunc can not be nil")
   163  	}
   164  	return nil
   165  }
   166  
   167  func (r ResponseFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   168  	return r(ctx, next)
   169  }
   170  
   171  type FieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
   172  
   173  func (f FieldFunc) ExtensionName() string {
   174  	return "InlineFieldFunc"
   175  }
   176  
   177  func (f FieldFunc) Validate(schema graphql.ExecutableSchema) error {
   178  	if f == nil {
   179  		return fmt.Errorf("FieldFunc can not be nil")
   180  	}
   181  	return nil
   182  }
   183  
   184  func (f FieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
   185  	return f(ctx, next)
   186  }