github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/websocket/connection.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "time" 8 9 "github.com/gorilla/websocket" 10 11 "github.com/v2fly/v2ray-core/v5/common/buf" 12 "github.com/v2fly/v2ray-core/v5/common/errors" 13 "github.com/v2fly/v2ray-core/v5/common/serial" 14 ) 15 16 var _ buf.Writer = (*connection)(nil) 17 18 // connection is a wrapper for net.Conn over WebSocket connection. 19 type connection struct { 20 conn *websocket.Conn 21 reader io.Reader 22 remoteAddr net.Addr 23 24 shouldWait bool 25 delayedDialFinish context.Context 26 finishedDial context.CancelFunc 27 dialer DelayedDialer 28 } 29 30 type DelayedDialer interface { 31 Dial(earlyData []byte) (*websocket.Conn, error) 32 } 33 34 func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection { 35 return &connection{ 36 conn: conn, 37 remoteAddr: remoteAddr, 38 } 39 } 40 41 func newConnectionWithEarlyData(conn *websocket.Conn, remoteAddr net.Addr, earlyData io.Reader) *connection { 42 return &connection{ 43 conn: conn, 44 remoteAddr: remoteAddr, 45 reader: earlyData, 46 } 47 } 48 49 func newConnectionWithDelayedDial(dialer DelayedDialer) *connection { 50 delayedDialContext, cancelFunc := context.WithCancel(context.Background()) 51 return &connection{ 52 shouldWait: true, 53 delayedDialFinish: delayedDialContext, 54 finishedDial: cancelFunc, 55 dialer: dialer, 56 } 57 } 58 59 func newRelayedConnectionWithDelayedDial(dialer DelayedDialerForwarded) *connectionForwarder { 60 delayedDialContext, cancelFunc := context.WithCancel(context.Background()) 61 return &connectionForwarder{ 62 shouldWait: true, 63 delayedDialFinish: delayedDialContext, 64 finishedDial: cancelFunc, 65 dialer: dialer, 66 } 67 } 68 69 func newRelayedConnection(conn io.ReadWriteCloser) *connectionForwarder { 70 return &connectionForwarder{ 71 ReadWriteCloser: conn, 72 shouldWait: false, 73 } 74 } 75 76 // Read implements net.Conn.Read() 77 func (c *connection) Read(b []byte) (int, error) { 78 for { 79 reader, err := c.getReader() 80 if err != nil { 81 return 0, err 82 } 83 84 nBytes, err := reader.Read(b) 85 if errors.Cause(err) == io.EOF { 86 c.reader = nil 87 continue 88 } 89 return nBytes, err 90 } 91 } 92 93 func (c *connection) getReader() (io.Reader, error) { 94 if c.shouldWait { 95 <-c.delayedDialFinish.Done() 96 if c.conn == nil { 97 return nil, newError("unable to read delayed dial websocket connection as it do not exist") 98 } 99 } 100 if c.reader != nil { 101 return c.reader, nil 102 } 103 104 _, reader, err := c.conn.NextReader() 105 if err != nil { 106 return nil, err 107 } 108 c.reader = reader 109 return reader, nil 110 } 111 112 // Write implements io.Writer. 113 func (c *connection) Write(b []byte) (int, error) { 114 if c.shouldWait { 115 var err error 116 c.conn, err = c.dialer.Dial(b) 117 c.finishedDial() 118 if err != nil { 119 return 0, newError("Unable to proceed with delayed write").Base(err) 120 } 121 c.remoteAddr = c.conn.RemoteAddr() 122 c.shouldWait = false 123 return len(b), nil 124 } 125 if err := c.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { 126 return 0, err 127 } 128 return len(b), nil 129 } 130 131 func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { 132 mb = buf.Compact(mb) 133 mb, err := buf.WriteMultiBuffer(c, mb) 134 buf.ReleaseMulti(mb) 135 return err 136 } 137 138 func (c *connection) Close() error { 139 if c.shouldWait { 140 <-c.delayedDialFinish.Done() 141 if c.conn == nil { 142 return newError("unable to close delayed dial websocket connection as it do not exist") 143 } 144 } 145 var errors []interface{} 146 if err := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { 147 errors = append(errors, err) 148 } 149 if err := c.conn.Close(); err != nil { 150 errors = append(errors, err) 151 } 152 if len(errors) > 0 { 153 return newError("failed to close connection").Base(newError(serial.Concat(errors...))) 154 } 155 return nil 156 } 157 158 func (c *connection) LocalAddr() net.Addr { 159 if c.shouldWait { 160 <-c.delayedDialFinish.Done() 161 if c.conn == nil { 162 newError("websocket transport is not materialized when LocalAddr() is called").AtWarning().WriteToLog() 163 return &net.UnixAddr{ 164 Name: "@placeholder", 165 Net: "unix", 166 } 167 } 168 } 169 return c.conn.LocalAddr() 170 } 171 172 func (c *connection) RemoteAddr() net.Addr { 173 return c.remoteAddr 174 } 175 176 func (c *connection) SetDeadline(t time.Time) error { 177 if err := c.SetReadDeadline(t); err != nil { 178 return err 179 } 180 return c.SetWriteDeadline(t) 181 } 182 183 func (c *connection) SetReadDeadline(t time.Time) error { 184 if c.shouldWait { 185 <-c.delayedDialFinish.Done() 186 if c.conn == nil { 187 newError("websocket transport is not materialized when SetReadDeadline() is called").AtWarning().WriteToLog() 188 return nil 189 } 190 } 191 return c.conn.SetReadDeadline(t) 192 } 193 194 func (c *connection) SetWriteDeadline(t time.Time) error { 195 if c.shouldWait { 196 <-c.delayedDialFinish.Done() 197 if c.conn == nil { 198 newError("websocket transport is not materialized when SetWriteDeadline() is called").AtWarning().WriteToLog() 199 return nil 200 } 201 } 202 return c.conn.SetWriteDeadline(t) 203 }