github.com/metaworking/channeld@v0.7.3/pkg/channeld/connection_websocket.go (about) 1 package channeld 2 3 import ( 4 "net" 5 "net/http" 6 "strings" 7 "time" 8 9 "github.com/gorilla/websocket" 10 "github.com/metaworking/channeld/pkg/channeldpb" 11 "go.uber.org/zap" 12 ) 13 14 type wsConn struct { 15 conn *websocket.Conn 16 } 17 18 func (c *wsConn) Read(b []byte) (n int, err error) { 19 _, body, err := c.conn.ReadMessage() 20 return copy(b, body), err 21 } 22 23 func (c *wsConn) Write(b []byte) (n int, err error) { 24 return len(b), c.conn.WriteMessage(websocket.BinaryMessage, b) 25 } 26 27 func (c *wsConn) Close() error { 28 return c.conn.Close() 29 } 30 31 func (c *wsConn) LocalAddr() net.Addr { 32 return c.conn.LocalAddr() 33 } 34 35 func (c *wsConn) RemoteAddr() net.Addr { 36 return c.conn.RemoteAddr() 37 } 38 39 func (c *wsConn) SetDeadline(t time.Time) error { 40 return c.conn.UnderlyingConn().SetDeadline(t) 41 } 42 43 func (c *wsConn) SetReadDeadline(t time.Time) error { 44 return c.conn.SetReadDeadline(t) 45 } 46 47 func (c *wsConn) SetWriteDeadline(t time.Time) error { 48 return c.conn.SetWriteDeadline(t) 49 } 50 51 var trustedOrigins []string 52 53 func SetWebSocketTrustedOrigins(addrs []string) { 54 trustedOrigins = addrs 55 } 56 57 var upgrader websocket.Upgrader = websocket.Upgrader{ 58 CheckOrigin: func(r *http.Request) bool { 59 if trustedOrigins == nil { 60 return true 61 } else { 62 for _, addr := range trustedOrigins { 63 if addr == r.RemoteAddr { 64 return true 65 } 66 } 67 return false 68 } 69 }, 70 } 71 72 func startWebSocketServer(t channeldpb.ConnectionType, address string) { 73 if protocolIndex := strings.Index(address, "://"); protocolIndex >= 0 { 74 address = address[protocolIndex+3:] 75 } 76 77 pattern := "/" 78 if pathIndex := strings.Index(address, "/"); pathIndex >= 0 { 79 pattern = address[pathIndex:] 80 address = address[:pathIndex-1] 81 } 82 83 mux := http.NewServeMux() 84 connsToAdd := make(chan *websocket.Conn, 128) 85 mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { 86 conn, err := upgrader.Upgrade(w, r, nil) 87 if err != nil { 88 rootLogger.Panic("Upgrade to websocket connection", zap.Error(err)) 89 } 90 // Add the websocket connection to a blocking queue instead of calling AddConnection() immediately, 91 // as a new goroutines is created per request. 92 connsToAdd <- conn 93 }) 94 95 serverClosed := false 96 // Call AddConnection() in a separate goroutine, to avoid the race condition. 97 go func() { 98 for !serverClosed { 99 conn := <-connsToAdd 100 c := AddConnection(&wsConn{conn}, t) 101 startGoroutines(c) 102 } 103 }() 104 105 server := http.Server{ 106 Addr: address, 107 Handler: mux, 108 } 109 110 defer server.Close() 111 112 rootLogger.Error("stopped listening", zap.Error(server.ListenAndServe())) 113 serverClosed = true 114 }