github.com/sagernet/sing-box@v1.2.7/transport/v2raywebsocket/conn.go (about) 1 package v2raywebsocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "io" 7 "net" 8 "net/http" 9 "os" 10 "time" 11 12 C "github.com/sagernet/sing-box/constant" 13 "github.com/sagernet/sing/common" 14 "github.com/sagernet/sing/common/buf" 15 E "github.com/sagernet/sing/common/exceptions" 16 "github.com/sagernet/websocket" 17 ) 18 19 type WebsocketConn struct { 20 *websocket.Conn 21 *Writer 22 remoteAddr net.Addr 23 reader io.Reader 24 } 25 26 func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn { 27 return &WebsocketConn{ 28 Conn: wsConn, 29 remoteAddr: remoteAddr, 30 Writer: NewWriter(wsConn, true), 31 } 32 } 33 34 func (c *WebsocketConn) Close() error { 35 err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout)) 36 if err != nil { 37 return c.Conn.Close() 38 } 39 return nil 40 } 41 42 func (c *WebsocketConn) Read(b []byte) (n int, err error) { 43 for { 44 if c.reader == nil { 45 _, c.reader, err = c.NextReader() 46 if err != nil { 47 err = wrapError(err) 48 return 49 } 50 } 51 n, err = c.reader.Read(b) 52 if E.IsMulti(err, io.EOF) { 53 c.reader = nil 54 continue 55 } 56 err = wrapError(err) 57 return 58 } 59 } 60 61 func (c *WebsocketConn) RemoteAddr() net.Addr { 62 if c.remoteAddr != nil { 63 return c.remoteAddr 64 } 65 return c.Conn.RemoteAddr() 66 } 67 68 func (c *WebsocketConn) SetDeadline(t time.Time) error { 69 return os.ErrInvalid 70 } 71 72 func (c *WebsocketConn) SetReadDeadline(t time.Time) error { 73 return os.ErrInvalid 74 } 75 76 func (c *WebsocketConn) SetWriteDeadline(t time.Time) error { 77 return os.ErrInvalid 78 } 79 80 func (c *WebsocketConn) NeedAdditionalReadDeadline() bool { 81 return true 82 } 83 84 func (c *WebsocketConn) Upstream() any { 85 return c.Conn.NetConn() 86 } 87 88 func (c *WebsocketConn) UpstreamWriter() any { 89 return c.Writer 90 } 91 92 type EarlyWebsocketConn struct { 93 *Client 94 ctx context.Context 95 conn *WebsocketConn 96 create chan struct{} 97 } 98 99 func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { 100 if c.conn == nil { 101 <-c.create 102 } 103 return c.conn.Read(b) 104 } 105 106 func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { 107 if c.conn != nil { 108 return c.conn.Write(b) 109 } 110 var ( 111 earlyData []byte 112 lateData []byte 113 conn *websocket.Conn 114 response *http.Response 115 ) 116 if len(b) > int(c.maxEarlyData) { 117 earlyData = b[:c.maxEarlyData] 118 lateData = b[c.maxEarlyData:] 119 } else { 120 earlyData = b 121 } 122 if len(earlyData) > 0 { 123 earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) 124 if c.earlyDataHeaderName == "" { 125 conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) 126 } else { 127 headers := c.headers.Clone() 128 headers.Set(c.earlyDataHeaderName, earlyDataString) 129 conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) 130 } 131 } else { 132 conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) 133 } 134 if err != nil { 135 return 0, wrapDialError(response, err) 136 } 137 c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} 138 close(c.create) 139 if len(lateData) > 0 { 140 _, err = c.conn.Write(lateData) 141 } 142 if err != nil { 143 return 144 } 145 return len(b), nil 146 } 147 148 func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { 149 if c.conn != nil { 150 return c.conn.WriteBuffer(buffer) 151 } 152 var ( 153 earlyData []byte 154 lateData []byte 155 conn *websocket.Conn 156 response *http.Response 157 err error 158 ) 159 if buffer.Len() > int(c.maxEarlyData) { 160 earlyData = buffer.Bytes()[:c.maxEarlyData] 161 lateData = buffer.Bytes()[c.maxEarlyData:] 162 } else { 163 earlyData = buffer.Bytes() 164 } 165 if len(earlyData) > 0 { 166 earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) 167 if c.earlyDataHeaderName == "" { 168 conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) 169 } else { 170 headers := c.headers.Clone() 171 headers.Set(c.earlyDataHeaderName, earlyDataString) 172 conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) 173 } 174 } else { 175 conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) 176 } 177 if err != nil { 178 return wrapDialError(response, err) 179 } 180 c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} 181 close(c.create) 182 if len(lateData) > 0 { 183 _, err = c.conn.Write(lateData) 184 } 185 return err 186 } 187 188 func (c *EarlyWebsocketConn) Close() error { 189 if c.conn == nil { 190 return nil 191 } 192 return c.conn.Close() 193 } 194 195 func (c *EarlyWebsocketConn) LocalAddr() net.Addr { 196 if c.conn == nil { 197 return nil 198 } 199 return c.conn.LocalAddr() 200 } 201 202 func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { 203 if c.conn == nil { 204 return nil 205 } 206 return c.conn.RemoteAddr() 207 } 208 209 func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { 210 return os.ErrInvalid 211 } 212 213 func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error { 214 return os.ErrInvalid 215 } 216 217 func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error { 218 return os.ErrInvalid 219 } 220 221 func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool { 222 return true 223 } 224 225 func (c *EarlyWebsocketConn) Upstream() any { 226 return common.PtrOrNil(c.conn) 227 } 228 229 func (c *EarlyWebsocketConn) LazyHeadroom() bool { 230 return c.conn == nil 231 } 232 233 func wrapError(err error) error { 234 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 235 return io.EOF 236 } 237 if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { 238 return net.ErrClosed 239 } 240 return err 241 }