github.com/anycable/anycable-go@v1.5.1/ws/handler.go (about)

     1  package ws
     2  
     3  import (
     4  	"log/slog"
     5  	"net/http"
     6  	"net/url"
     7  	"strings"
     8  
     9  	"github.com/anycable/anycable-go/server"
    10  	"github.com/anycable/anycable-go/version"
    11  	"github.com/gorilla/websocket"
    12  )
    13  
    14  type sessionHandler = func(conn *websocket.Conn, info *server.RequestInfo, callback func()) error
    15  
    16  // WebsocketHandler generate a new http handler for WebSocket connections
    17  func WebsocketHandler(subprotocols []string, headersExtractor server.HeadersExtractor, config *Config, l *slog.Logger, sessionHandler sessionHandler) http.Handler {
    18  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    19  		ctx := l.With("context", "ws")
    20  
    21  		upgrader := websocket.Upgrader{
    22  			CheckOrigin:       CheckOrigin(config.AllowedOrigins),
    23  			Subprotocols:      subprotocols,
    24  			ReadBufferSize:    config.ReadBufferSize,
    25  			WriteBufferSize:   config.WriteBufferSize,
    26  			EnableCompression: config.EnableCompression,
    27  		}
    28  
    29  		rheader := map[string][]string{"X-AnyCable-Version": {version.Version()}}
    30  		wsc, err := upgrader.Upgrade(w, r, rheader)
    31  		if err != nil {
    32  			ctx.Debug("WebSocket connection upgrade failed", "error", err)
    33  			return
    34  		}
    35  
    36  		info, err := server.NewRequestInfo(r, headersExtractor)
    37  		if err != nil {
    38  			CloseWithReason(wsc, websocket.CloseAbnormalClosure, err.Error())
    39  			return
    40  		}
    41  
    42  		wsc.SetReadLimit(config.MaxMessageSize)
    43  
    44  		if config.EnableCompression {
    45  			wsc.EnableWriteCompression(true)
    46  		}
    47  
    48  		sessionCtx := l.With("sid", info.UID)
    49  
    50  		clientSubprotocol := r.Header.Get("Sec-Websocket-Protocol")
    51  
    52  		if wsc.Subprotocol() == "" && clientSubprotocol != "" {
    53  			sessionCtx.Debug("no subprotocol negotiated", "client", clientSubprotocol, "server", subprotocols)
    54  		}
    55  
    56  		// Separate goroutine for better GC of caller's data.
    57  		go func() {
    58  			sessionCtx.Debug("WebSocket session established")
    59  			serr := sessionHandler(wsc, info, func() {
    60  				sessionCtx.Debug("WebSocket session completed")
    61  			})
    62  
    63  			if serr != nil {
    64  				sessionCtx.Error("WebSocket session failed", "error", serr)
    65  				return
    66  			}
    67  		}()
    68  	})
    69  }
    70  
    // CheckOrigin builds the Origin check used by websocket.Upgrader.
//
// origins is a comma-separated allow-list. Entries are matched
// case-insensitively against the host (including any port) of the request's
// Origin header. An entry starting with "*" is a suffix match on the rest of
// the entry: "*.example.com" accepts "ws.example.com" but not "example.com",
// and a bare "*" accepts everything, including requests with no Origin header.
// Whitespace around entries is ignored and empty entries (e.g. from a trailing
// comma) are skipped; if no entries remain, every request is rejected.
//
// An empty origins string disables the check and accepts every request.
func CheckOrigin(origins string) func(r *http.Request) bool {
	if origins == "" {
		return func(r *http.Request) bool { return true }
	}

	// Normalize the allow-list once, at construction time. Empty entries must be
	// dropped: the wildcard test below indexes host[0], which would panic on
	// configurations such as "a.com," or "a.com,,b.com".
	var hosts []string
	for _, entry := range strings.Split(strings.ToLower(origins), ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			hosts = append(hosts, entry)
		}
	}

	return func(r *http.Request) bool {
		u, err := url.Parse(strings.ToLower(r.Header.Get("Origin")))
		if err != nil {
			return false
		}

		for _, host := range hosts {
			if host[0] == '*' && strings.HasSuffix(u.Host, host[1:]) {
				return true
			}
			if u.Host == host {
				return true
			}
		}
		return false
	}
}