goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/websocket/websocket.go (about)

     1  package websocket
     2  
     3  import (
     4  	"net/http"
     5  	"time"
     6  
     7  	stderrors "errors"
     8  
     9  	"goyave.dev/goyave/v5"
    10  	"goyave.dev/goyave/v5/util/errors"
    11  
    12  	ws "github.com/gorilla/websocket"
    13  )
    14  
    15  const (
    16  	// NormalClosureMessage the message sent with the close frame
    17  	// during the close handshake.
    18  	NormalClosureMessage = "Server closed connection"
    19  )
    20  
    21  // Controller component for websockets.
    22  type Controller interface {
    23  	goyave.Composable
    24  
    25  	// Serve is a handler for websocket connections.
    26  	// The request parameter contains the original upgraded HTTP request.
    27  	//
    28  	// To keep the connection alive, these handlers should run an infinite for loop that
    29  	// can return on error or exit in a predictable manner.
    30  	//
    31  	// They also can start goroutines for reads and writes, but shouldn't return before
    32  	// both of them do. The handler is responsible of synchronizing the goroutines it started,
    33  	// and ensure no reader nor writer are still active when it returns.
    34  	//
    35  	// When the websocket handler returns, the closing handshake is performed (if not already done
    36  	// using "conn.Close()") and the connection is closed.
    37  	//
    38  	// If the websocket handler returns nil, it means that everything went fine and the
    39  	// connection can be closed normally. On the other hand, the websocket handler
    40  	// can return an error, such as a write error, to indicate that the connection should not
    41  	// be closed normally. The behavior used when this happens depend on the implementation
    42  	// of the HTTP handler that upgraded the connection.
    43  	//
    44  	// By default, the server shutdown doesn't wait for hijacked connections to be closed gracefully.
    45  	// It is advised to register a shutdown hook blocking until all the connections are gracefully
    46  	// closed using `*websocket.Conn.CloseNormal()`.
    47  	//
    48  	// The following websocket Handler is a simple example of an "echo" feature using websockets:
    49  	//
    50  	//	func (c *EchoController) Serve(c *websocket.Conn, request *goyave.Request) error {
    51  	//		for {
    52  	//			mt, message, err := c.ReadMessage()
    53  	//			if err != nil {
    54  	//				return errors.New(err)
    55  	//			}
    56  	//			c.Logger().Debug("recv", "message", string(message))
    57  	//			err = c.WriteMessage(mt, message)
    58  	//			if err != nil {
    59  	//				return errors.Errof("write: %w", err)
    60  	//			}
    61  	//		}
    62  	//	}
    63  	Serve(*Conn, *goyave.Request) error
    64  }
    65  
    66  // Registrer qualifies a `websocket.Controller` that registers its route itself, allowing
    67  // to define validation rules, middleware, route meta, etc.
    68  //
    69  // If the `websocket.Controller` doesn't implement this interface, the route is registered
    70  // for the GET method and an empty path.
    71  type Registrer interface {
    72  	// RegisterRoute registers the route for the websocket upgrade. The route must only match the
    73  	// GET HTTP method and use the `goyave.Handler` received as a parameter.
    74  	RegisterRoute(*goyave.Router, goyave.Handler)
    75  }
    76  
    77  type upgradeErrorHandlerFunc func(response *goyave.Response, request *goyave.Request, status int, reason error)
    78  
    79  // UpgradeErrorHandler allows a `websocket.Controller` to define a custom behavior when
    80  // the protocol switching process fails.
    81  //
    82  // If the `websocket.Controller` doesn't implement this interface, the default
    83  // error handler returns a JSON response containing the status text
    84  // corresponding to the status code returned. If debugging is enabled, the reason error
    85  // message is returned instead.
    86  //
    87  //	{"error": "message"}
    88  type UpgradeErrorHandler interface {
    89  	// OnUpgradeError specifies the function for generating HTTP error responses if the
    90  	// protocol switching process fails. The error can be a user error or server error.
    91  	OnUpgradeError(response *goyave.Response, request *goyave.Request, status int, reason error)
    92  }
    93  
    94  // ErrorHandler allows a `websocket.Controller` to define a custom behavior in case of error
    95  // occurring in the controller's `Serve` function. This custom error handler is called for both
    96  // handled and unhandled errors (panics).
    97  //
    98  // If the `websocket.Controller` doesn't implement this interface, the error is logged at the error level.
    99  type ErrorHandler interface {
   100  	// ErrorHandler specifies the function handling errors returned by the controller's `Serve` function
   101  	// or if this function panics.
   102  	OnError(request *goyave.Request, err error)
   103  }
   104  
   105  // OriginChecker allows a `websocket.Controller` to define custom origin header checking behavior.
   106  //
   107  // If the `websocket.Controller` doesn't implement this interface, a safe default is used:
   108  // return false if the Origin request header is present and the origin host is not equal to
   109  // request Host header.
   110  type OriginChecker interface {
   111  	// CheckOrigin returns true if the request Origin header is acceptable.
   112  	//
   113  	// A CheckOrigin function should carefully validate the request origin to
   114  	// prevent cross-site request forgery.
   115  	CheckOrigin(r *goyave.Request) bool
   116  }
   117  
   118  // HeaderUpgrader allows a `websocket.Controller` to define custom HTTP headers in the
   119  // protocol switching response.
   120  type HeaderUpgrader interface {
   121  	// UpgradeHeaders function generating headers to be sent with the protocol switching response.
   122  	UpgradeHeaders(r *goyave.Request) http.Header
   123  }
   124  
   125  // Upgrader is responsible for the upgrade of HTTP connections to
   126  // websocket connections.
   127  type Upgrader struct {
   128  	goyave.Component
   129  
   130  	Controller Controller
   131  
   132  	// Settings the parameters for upgrading the connection. "Error" and "CheckOrigin" are
   133  	// ignored: use implementations of the interfaces `UpgradeErrorHandler` and `ErrorHandler`.
   134  	Settings ws.Upgrader
   135  }
   136  
   137  // New create a new Upgrader with default settings.
   138  func New(controller Controller) *Upgrader {
   139  	return &Upgrader{
   140  		Controller: controller,
   141  	}
   142  }
   143  
   144  // RegisterRoutes implementation of `goyave.Registrer`.
   145  //
   146  // If the `websocket.Controller` implements `websocket.Registrer`, uses its implementation
   147  // to register the route. Otherwise registers the route for the GET method and an empty path.
   148  func (u *Upgrader) RegisterRoutes(router *goyave.Router) {
   149  	if registrer, ok := u.Controller.(Registrer); ok {
   150  		registrer.RegisterRoute(router, u.Handler())
   151  		return
   152  	}
   153  	router.Get("", u.Handler())
   154  }
   155  
   156  func (u *Upgrader) defaultUpgradeErrorHandler(response *goyave.Response, _ *goyave.Request, status int, reason error) {
   157  	text := http.StatusText(status)
   158  	if u.Config().GetBool("app.debug") && reason != nil {
   159  		text = reason.Error()
   160  	}
   161  	message := map[string]string{
   162  		"error": text,
   163  	}
   164  	response.JSON(status, message)
   165  }
   166  
   167  func (u *Upgrader) makeUpgrader(request *goyave.Request) *ws.Upgrader {
   168  	upgradeErrorHandlerFunc := u.defaultUpgradeErrorHandler
   169  	if upgradeErrorHandler, ok := u.Controller.(UpgradeErrorHandler); ok {
   170  		upgradeErrorHandlerFunc = upgradeErrorHandler.OnUpgradeError
   171  	}
   172  
   173  	var checkOrigin func(r *goyave.Request) bool
   174  	if originChecker, ok := u.Controller.(OriginChecker); ok {
   175  		checkOrigin = originChecker.CheckOrigin
   176  	}
   177  
   178  	a := adapter{
   179  		upgradeErrorHandler: upgradeErrorHandlerFunc,
   180  		checkOrigin:         checkOrigin,
   181  		request:             request,
   182  	}
   183  
   184  	upgrader := u.Settings
   185  	upgrader.Error = a.onError
   186  	upgrader.CheckOrigin = a.getCheckOriginFunc()
   187  	return &upgrader
   188  }
   189  
   190  // Handler create an HTTP handler upgrading the HTTP connection before passing it
   191  // to the given websocket Handler.
   192  //
   193  // HTTP response's status is set to "101 Switching Protocols".
   194  //
   195  // The connection is closed automatically after the websocket Handler returns, using the
   196  // closing handshake defined by RFC 6455 Section 1.4 if possible and if not already
   197  // performed using "conn.Close()".
   198  //
   199  // If the websocket Handler returns an error that is not a CloseError, the Upgrader's error
   200  // handler will be executed and the close frame sent to the client will have status code
   201  // 1011 (internal server error) and "Internal server error" as message.
   202  // If debug is enabled, the message will be the error message returned by the
   203  // websocket handler. Otherwise the close frame will have status code 1000 (normal closure)
   204  // and "Server closed connection" as a message.
   205  //
   206  // This HTTP handler features a recovery mechanism. If the websocket Handler panics,
   207  // the connection will be gracefully closed just like if the websocket Handler returned
   208  // an error without panicking.
   209  //
   210  // This HTTP Handler returns once the connection has been successfully upgraded. That means
   211  // that, for example, logging middleware will log the request right away instead of waiting
   212  // for the websocket connection to be closed.
   213  func (u *Upgrader) Handler() goyave.Handler {
   214  	u.Controller.Init(u.Server())
   215  	return func(response *goyave.Response, request *goyave.Request) {
   216  		var headers http.Header
   217  		if headerUpgrader, ok := u.Controller.(HeaderUpgrader); ok {
   218  			headers = headerUpgrader.UpgradeHeaders(request)
   219  		}
   220  
   221  		c, err := u.makeUpgrader(request).Upgrade(response, request.Request(), headers)
   222  		if err != nil {
   223  			return
   224  		}
   225  		response.Status(http.StatusSwitchingProtocols)
   226  
   227  		go u.serve(c, request, u.Controller.Serve)
   228  	}
   229  }
   230  
// serve runs the websocket handler on an upgraded connection and always closes
// the connection once the handler returns or panics.
//
// The closing behavior is decided in the deferred function:
//   - the error (returned or recovered) is a close error as defined by IsCloseError:
//     the connection is closed with CloseNormal and the error handler is not called;
//   - any other error: the controller's ErrorHandler is called if implemented,
//     otherwise the error is logged, then the connection is closed with CloseWithError;
//   - no error: the connection is closed with ws.CloseNormalClosure and NormalClosureMessage.
//
// The duration given to newConn comes from "server.websocketCloseTimeout", in seconds
// (presumably the close handshake timeout — confirm in newConn).
func (u *Upgrader) serve(c *ws.Conn, request *goyave.Request, handler func(*Conn, *goyave.Request) error) {
	conn := newConn(c, time.Duration(u.Config().GetInt("server.websocketCloseTimeout"))*time.Second)
	// panicked stays true unless handler returns normally, so the deferred function
	// can detect an abnormal exit of handler even when recover() yields nil.
	panicked := true
	var err error
	defer func() { // Panic recovery
		// The skip count depends on this call staying directly inside the deferred closure.
		if panicReason := recover(); panicReason != nil || panicked {
			err = errors.NewSkip(panicReason, 4) // Skipped: runtime.Callers, NewSkip, this func, runtime.panic
		}

		if IsCloseError(err) {
			// Errors from closing are deliberately discarded: the connection is being torn down.
			_ = conn.CloseNormal()
			return
		}
		if err != nil {
			if errorHandler, ok := u.Controller.(ErrorHandler); ok {
				errorHandler.OnError(request, err)
			} else {
				u.Logger().Error(err)
			}
			_ = conn.CloseWithError(err)
		} else {
			_ = conn.internalClose(ws.CloseNormalClosure, NormalClosureMessage)
		}
	}()

	err = handler(conn, request)
	if err != nil {
		err = errors.New(err)
	}
	// Only reached when handler returned without panicking.
	panicked = false
}
   262  
// adapter binds a goyave request and the controller's optional callbacks to the
// function signatures expected by gorilla/websocket's Upgrader (Error and CheckOrigin).
type adapter struct {
	// upgradeErrorHandler writes the HTTP error response when the upgrade fails.
	upgradeErrorHandler upgradeErrorHandlerFunc
	// checkOrigin is the controller's CheckOrigin implementation, or nil if none was provided.
	checkOrigin         func(r *goyave.Request) bool
	// request is the goyave request being upgraded.
	request             *goyave.Request
}
   268  
   269  func (a *adapter) onError(w http.ResponseWriter, _ *http.Request, status int, reason error) {
   270  	if status == http.StatusInternalServerError {
   271  		panic(errors.New(reason))
   272  	}
   273  	w.Header().Set("Sec-Websocket-Version", "13")
   274  	a.upgradeErrorHandler(w.(*goyave.Response), a.request, status, reason)
   275  }
   276  
   277  func (a *adapter) getCheckOriginFunc() func(r *http.Request) bool {
   278  	if a.checkOrigin != nil {
   279  		return func(_ *http.Request) bool {
   280  			return a.checkOrigin(a.request)
   281  		}
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  // IsCloseError returns true if the error is one of the following close errors:
   288  // CloseNormalClosure (1000), CloseGoingAway (1001) or CloseNoStatusReceived (1005)
   289  func IsCloseError(err error) bool {
   290  	var closeError *ws.CloseError
   291  	if stderrors.As(err, &closeError) {
   292  		err = closeError
   293  	}
   294  	return ws.IsCloseError(err,
   295  		ws.CloseNormalClosure,
   296  		ws.CloseGoingAway,
   297  		ws.CloseNoStatusReceived,
   298  	)
   299  }