github.com/metaworking/channeld@v0.7.3/pkg/channeld/connection_websocket.go (about)

     1  package channeld
     2  
     3  import (
     4  	"net"
     5  	"net/http"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/gorilla/websocket"
    10  	"github.com/metaworking/channeld/pkg/channeldpb"
    11  	"go.uber.org/zap"
    12  )
    13  
    14  type wsConn struct {
    15  	conn *websocket.Conn
    16  }
    17  
    18  func (c *wsConn) Read(b []byte) (n int, err error) {
    19  	_, body, err := c.conn.ReadMessage()
    20  	return copy(b, body), err
    21  }
    22  
    23  func (c *wsConn) Write(b []byte) (n int, err error) {
    24  	return len(b), c.conn.WriteMessage(websocket.BinaryMessage, b)
    25  }
    26  
    27  func (c *wsConn) Close() error {
    28  	return c.conn.Close()
    29  }
    30  
    31  func (c *wsConn) LocalAddr() net.Addr {
    32  	return c.conn.LocalAddr()
    33  }
    34  
    35  func (c *wsConn) RemoteAddr() net.Addr {
    36  	return c.conn.RemoteAddr()
    37  }
    38  
    39  func (c *wsConn) SetDeadline(t time.Time) error {
    40  	return c.conn.UnderlyingConn().SetDeadline(t)
    41  }
    42  
    43  func (c *wsConn) SetReadDeadline(t time.Time) error {
    44  	return c.conn.SetReadDeadline(t)
    45  }
    46  
    47  func (c *wsConn) SetWriteDeadline(t time.Time) error {
    48  	return c.conn.SetWriteDeadline(t)
    49  }
    50  
    51  var trustedOrigins []string
    52  
    53  func SetWebSocketTrustedOrigins(addrs []string) {
    54  	trustedOrigins = addrs
    55  }
    56  
    57  var upgrader websocket.Upgrader = websocket.Upgrader{
    58  	CheckOrigin: func(r *http.Request) bool {
    59  		if trustedOrigins == nil {
    60  			return true
    61  		} else {
    62  			for _, addr := range trustedOrigins {
    63  				if addr == r.RemoteAddr {
    64  					return true
    65  				}
    66  			}
    67  			return false
    68  		}
    69  	},
    70  }
    71  
    72  func startWebSocketServer(t channeldpb.ConnectionType, address string) {
    73  	if protocolIndex := strings.Index(address, "://"); protocolIndex >= 0 {
    74  		address = address[protocolIndex+3:]
    75  	}
    76  
    77  	pattern := "/"
    78  	if pathIndex := strings.Index(address, "/"); pathIndex >= 0 {
    79  		pattern = address[pathIndex:]
    80  		address = address[:pathIndex-1]
    81  	}
    82  
    83  	mux := http.NewServeMux()
    84  	connsToAdd := make(chan *websocket.Conn, 128)
    85  	mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
    86  		conn, err := upgrader.Upgrade(w, r, nil)
    87  		if err != nil {
    88  			rootLogger.Panic("Upgrade to websocket connection", zap.Error(err))
    89  		}
    90  		// Add the websocket connection to a blocking queue instead of calling AddConnection() immediately,
    91  		// as a new goroutines is created per request.
    92  		connsToAdd <- conn
    93  	})
    94  
    95  	serverClosed := false
    96  	// Call AddConnection() in a separate goroutine, to avoid the race condition.
    97  	go func() {
    98  		for !serverClosed {
    99  			conn := <-connsToAdd
   100  			c := AddConnection(&wsConn{conn}, t)
   101  			startGoroutines(c)
   102  		}
   103  	}()
   104  
   105  	server := http.Server{
   106  		Addr:    address,
   107  		Handler: mux,
   108  	}
   109  
   110  	defer server.Close()
   111  
   112  	rootLogger.Error("stopped listening", zap.Error(server.ListenAndServe()))
   113  	serverClosed = true
   114  }