github.com/eagleql/xray-core@v1.4.4/transport/internet/websocket/connection.go (about) 1 package websocket 2 3 import ( 4 "io" 5 "net" 6 "time" 7 8 "github.com/eagleql/xray-core/common/buf" 9 "github.com/eagleql/xray-core/common/errors" 10 "github.com/eagleql/xray-core/common/serial" 11 "github.com/gorilla/websocket" 12 ) 13 14 var ( 15 _ buf.Writer = (*connection)(nil) 16 ) 17 18 // connection is a wrapper for net.Conn over WebSocket connection. 19 type connection struct { 20 conn *websocket.Conn 21 reader io.Reader 22 remoteAddr net.Addr 23 } 24 25 func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection { 26 return &connection{ 27 conn: conn, 28 remoteAddr: remoteAddr, 29 reader: extraReader, 30 } 31 } 32 33 // Read implements net.Conn.Read() 34 func (c *connection) Read(b []byte) (int, error) { 35 for { 36 reader, err := c.getReader() 37 if err != nil { 38 return 0, err 39 } 40 41 nBytes, err := reader.Read(b) 42 if errors.Cause(err) == io.EOF { 43 c.reader = nil 44 continue 45 } 46 return nBytes, err 47 } 48 } 49 50 func (c *connection) getReader() (io.Reader, error) { 51 if c.reader != nil { 52 return c.reader, nil 53 } 54 55 _, reader, err := c.conn.NextReader() 56 if err != nil { 57 return nil, err 58 } 59 c.reader = reader 60 return reader, nil 61 } 62 63 // Write implements io.Writer. 64 func (c *connection) Write(b []byte) (int, error) { 65 if err := c.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { 66 return 0, err 67 } 68 return len(b), nil 69 } 70 71 func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { 72 mb = buf.Compact(mb) 73 mb, err := buf.WriteMultiBuffer(c, mb) 74 buf.ReleaseMulti(mb) 75 return err 76 } 77 78 func (c *connection) Close() error { 79 var errors []interface{} 80 if err := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { 81 errors = append(errors, err) 82 } 83 if err := c.conn.Close(); err != nil { 84 errors = append(errors, err) 85 } 86 if len(errors) > 0 { 87 return newError("failed to close connection").Base(newError(serial.Concat(errors...))) 88 } 89 return nil 90 } 91 92 func (c *connection) LocalAddr() net.Addr { 93 return c.conn.LocalAddr() 94 } 95 96 func (c *connection) RemoteAddr() net.Addr { 97 return c.remoteAddr 98 } 99 100 func (c *connection) SetDeadline(t time.Time) error { 101 if err := c.SetReadDeadline(t); err != nil { 102 return err 103 } 104 return c.SetWriteDeadline(t) 105 } 106 107 func (c *connection) SetReadDeadline(t time.Time) error { 108 return c.conn.SetReadDeadline(t) 109 } 110 111 func (c *connection) SetWriteDeadline(t time.Time) error { 112 return c.conn.SetWriteDeadline(t) 113 }