github.com/anycable/anycable-go@v1.5.1/ws/handler.go (about) 1 package ws 2 3 import ( 4 "log/slog" 5 "net/http" 6 "net/url" 7 "strings" 8 9 "github.com/anycable/anycable-go/server" 10 "github.com/anycable/anycable-go/version" 11 "github.com/gorilla/websocket" 12 ) 13 14 type sessionHandler = func(conn *websocket.Conn, info *server.RequestInfo, callback func()) error 15 16 // WebsocketHandler generate a new http handler for WebSocket connections 17 func WebsocketHandler(subprotocols []string, headersExtractor server.HeadersExtractor, config *Config, l *slog.Logger, sessionHandler sessionHandler) http.Handler { 18 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 19 ctx := l.With("context", "ws") 20 21 upgrader := websocket.Upgrader{ 22 CheckOrigin: CheckOrigin(config.AllowedOrigins), 23 Subprotocols: subprotocols, 24 ReadBufferSize: config.ReadBufferSize, 25 WriteBufferSize: config.WriteBufferSize, 26 EnableCompression: config.EnableCompression, 27 } 28 29 rheader := map[string][]string{"X-AnyCable-Version": {version.Version()}} 30 wsc, err := upgrader.Upgrade(w, r, rheader) 31 if err != nil { 32 ctx.Debug("WebSocket connection upgrade failed", "error", err) 33 return 34 } 35 36 info, err := server.NewRequestInfo(r, headersExtractor) 37 if err != nil { 38 CloseWithReason(wsc, websocket.CloseAbnormalClosure, err.Error()) 39 return 40 } 41 42 wsc.SetReadLimit(config.MaxMessageSize) 43 44 if config.EnableCompression { 45 wsc.EnableWriteCompression(true) 46 } 47 48 sessionCtx := l.With("sid", info.UID) 49 50 clientSubprotocol := r.Header.Get("Sec-Websocket-Protocol") 51 52 if wsc.Subprotocol() == "" && clientSubprotocol != "" { 53 sessionCtx.Debug("no subprotocol negotiated", "client", clientSubprotocol, "server", subprotocols) 54 } 55 56 // Separate goroutine for better GC of caller's data. 57 go func() { 58 sessionCtx.Debug("WebSocket session established") 59 serr := sessionHandler(wsc, info, func() { 60 sessionCtx.Debug("WebSocket session completed") 61 }) 62 63 if serr != nil { 64 sessionCtx.Error("WebSocket session failed", "error", serr) 65 return 66 } 67 }() 68 }) 69 } 70 71 func CheckOrigin(origins string) func(r *http.Request) bool { 72 if origins == "" { 73 return func(r *http.Request) bool { return true } 74 } 75 76 hosts := strings.Split(strings.ToLower(origins), ",") 77 78 return func(r *http.Request) bool { 79 origin := strings.ToLower(r.Header.Get("Origin")) 80 u, err := url.Parse(origin) 81 if err != nil { 82 return false 83 } 84 85 for _, host := range hosts { 86 if host[0] == '*' && strings.HasSuffix(u.Host, host[1:]) { 87 return true 88 } 89 if u.Host == host { 90 return true 91 } 92 } 93 return false 94 } 95 }