github.com/glide-im/glide@v1.6.0/pkg/conn/ws_conn.go (about) 1 package conn 2 3 import ( 4 "github.com/gorilla/websocket" 5 "net" 6 "strings" 7 "time" 8 ) 9 10 type WsConnection struct { 11 options *WsServerOptions 12 conn *websocket.Conn 13 } 14 15 func NewWsConnection(conn *websocket.Conn, options *WsServerOptions) *WsConnection { 16 c := new(WsConnection) 17 c.conn = conn 18 c.options = options 19 c.conn.SetCloseHandler(func(code int, text string) error { 20 return ErrClosed 21 }) 22 return c 23 } 24 25 func (c *WsConnection) Write(data []byte) error { 26 deadLine := time.Now().Add(c.options.WriteTimeout) 27 _ = c.conn.SetWriteDeadline(deadLine) 28 29 err := c.conn.WriteMessage(websocket.TextMessage, data) 30 return c.wrapError(err) 31 } 32 33 func (c *WsConnection) Read() ([]byte, error) { 34 35 deadLine := time.Now().Add(c.options.ReadTimeout) 36 _ = c.conn.SetReadDeadline(deadLine) 37 38 msgType, bytes, err := c.conn.ReadMessage() 39 if err != nil { 40 return nil, c.wrapError(err) 41 } 42 43 switch msgType { 44 case websocket.TextMessage: 45 case websocket.PingMessage: 46 case websocket.BinaryMessage: 47 default: 48 return nil, ErrBadPackage 49 } 50 51 return bytes, err 52 } 53 54 func (c *WsConnection) Close() error { 55 return c.wrapError(c.conn.Close()) 56 } 57 58 func (c *WsConnection) GetConnInfo() *ConnectionInfo { 59 c.conn.UnderlyingConn() 60 remoteAddr := c.conn.RemoteAddr().(*net.TCPAddr) 61 info := ConnectionInfo{ 62 Ip: remoteAddr.IP.String(), 63 Port: remoteAddr.Port, 64 Addr: c.conn.RemoteAddr().String(), 65 } 66 return &info 67 } 68 69 func (c *WsConnection) wrapError(err error) error { 70 if err == nil { 71 return nil 72 } 73 if websocket.IsUnexpectedCloseError(err) { 74 return ErrClosed 75 } 76 if websocket.IsCloseError(err) { 77 return ErrClosed 78 } 79 if strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host") { 80 _ = c.conn.Close() 81 return ErrClosed 82 } 83 if strings.Contains(err.Error(), "use of closed network conn") { 84 _ = c.conn.Close() 85 return ErrClosed 86 } 87 if strings.Contains(err.Error(), "i/o timeout") { 88 return ErrReadTimeout 89 } 90 return err 91 }