github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raywebsocket/conn.go (about) 1 package v2raywebsocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "io" 7 "net" 8 "os" 9 "sync" 10 "time" 11 12 C "github.com/sagernet/sing-box/constant" 13 "github.com/sagernet/sing/common" 14 "github.com/sagernet/sing/common/buf" 15 "github.com/sagernet/sing/common/debug" 16 E "github.com/sagernet/sing/common/exceptions" 17 M "github.com/sagernet/sing/common/metadata" 18 "github.com/sagernet/ws" 19 "github.com/sagernet/ws/wsutil" 20 ) 21 22 type WebsocketConn struct { 23 net.Conn 24 *Writer 25 state ws.State 26 reader *wsutil.Reader 27 controlHandler wsutil.FrameHandlerFunc 28 remoteAddr net.Addr 29 } 30 31 func NewConn(conn net.Conn, remoteAddr net.Addr, state ws.State) *WebsocketConn { 32 controlHandler := wsutil.ControlFrameHandler(conn, state) 33 return &WebsocketConn{ 34 Conn: conn, 35 state: state, 36 reader: &wsutil.Reader{ 37 Source: conn, 38 State: state, 39 SkipHeaderCheck: !debug.Enabled, 40 OnIntermediate: controlHandler, 41 }, 42 controlHandler: controlHandler, 43 remoteAddr: remoteAddr, 44 Writer: NewWriter(conn, state), 45 } 46 } 47 48 func (c *WebsocketConn) Close() error { 49 c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout)) 50 frame := ws.NewCloseFrame(ws.NewCloseFrameBody( 51 ws.StatusNormalClosure, "", 52 )) 53 if c.state == ws.StateClientSide { 54 frame = ws.MaskFrameInPlace(frame) 55 } 56 ws.WriteFrame(c.Conn, frame) 57 c.Conn.Close() 58 return nil 59 } 60 61 func (c *WebsocketConn) Read(b []byte) (n int, err error) { 62 var header ws.Header 63 for { 64 n, err = c.reader.Read(b) 65 if n > 0 { 66 err = nil 67 return 68 } 69 if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) { 70 return 71 } 72 header, err = c.reader.NextFrame() 73 if err != nil { 74 return 75 } 76 if header.OpCode.IsControl() { 77 err = c.controlHandler(header, c.reader) 78 if err != nil { 79 return 80 } 81 continue 82 } 83 if header.OpCode&ws.OpBinary == 0 { 84 err = c.reader.Discard() 85 if err != nil { 86 return 87 } 88 continue 89 } 90 } 91 } 92 93 func (c *WebsocketConn) Write(p []byte) (n int, err error) { 94 err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p) 95 if err != nil { 96 return 97 } 98 n = len(p) 99 return 100 } 101 102 func (c *WebsocketConn) RemoteAddr() net.Addr { 103 if c.remoteAddr != nil { 104 return c.remoteAddr 105 } 106 return c.Conn.RemoteAddr() 107 } 108 109 func (c *WebsocketConn) SetDeadline(t time.Time) error { 110 return os.ErrInvalid 111 } 112 113 func (c *WebsocketConn) SetReadDeadline(t time.Time) error { 114 return os.ErrInvalid 115 } 116 117 func (c *WebsocketConn) SetWriteDeadline(t time.Time) error { 118 return os.ErrInvalid 119 } 120 121 func (c *WebsocketConn) NeedAdditionalReadDeadline() bool { 122 return true 123 } 124 125 func (c *WebsocketConn) Upstream() any { 126 return c.Conn 127 } 128 129 type EarlyWebsocketConn struct { 130 *Client 131 ctx context.Context 132 conn *WebsocketConn 133 access sync.Mutex 134 create chan struct{} 135 err error 136 } 137 138 func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { 139 if c.conn == nil { 140 <-c.create 141 if c.err != nil { 142 return 0, c.err 143 } 144 } 145 return c.conn.Read(b) 146 } 147 148 func (c *EarlyWebsocketConn) writeRequest(content []byte) error { 149 var ( 150 earlyData []byte 151 lateData []byte 152 conn *WebsocketConn 153 err error 154 ) 155 if len(content) > int(c.maxEarlyData) { 156 earlyData = content[:c.maxEarlyData] 157 lateData = content[c.maxEarlyData:] 158 } else { 159 earlyData = content 160 } 161 if len(earlyData) > 0 { 162 earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) 163 if c.earlyDataHeaderName == "" { 164 requestURL := c.requestURL 165 requestURL.Path += earlyDataString 166 conn, err = c.dialContext(c.ctx, &requestURL, c.headers) 167 } else { 168 headers := c.headers.Clone() 169 headers.Set(c.earlyDataHeaderName, earlyDataString) 170 conn, err = c.dialContext(c.ctx, &c.requestURL, headers) 171 } 172 } else { 173 conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers) 174 } 175 if err != nil { 176 return err 177 } 178 if len(lateData) > 0 { 179 _, err = conn.Write(lateData) 180 if err != nil { 181 return err 182 } 183 } 184 c.conn = conn 185 return nil 186 } 187 188 func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { 189 if c.conn != nil { 190 return c.conn.Write(b) 191 } 192 c.access.Lock() 193 defer c.access.Unlock() 194 if c.err != nil { 195 return 0, c.err 196 } 197 if c.conn != nil { 198 return c.conn.Write(b) 199 } 200 err = c.writeRequest(b) 201 c.err = err 202 close(c.create) 203 if err != nil { 204 return 205 } 206 return len(b), nil 207 } 208 209 func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { 210 if c.conn != nil { 211 return c.conn.WriteBuffer(buffer) 212 } 213 c.access.Lock() 214 defer c.access.Unlock() 215 if c.conn != nil { 216 return c.conn.WriteBuffer(buffer) 217 } 218 if c.err != nil { 219 return c.err 220 } 221 err := c.writeRequest(buffer.Bytes()) 222 c.err = err 223 close(c.create) 224 return err 225 } 226 227 func (c *EarlyWebsocketConn) Close() error { 228 if c.conn == nil { 229 return nil 230 } 231 return c.conn.Close() 232 } 233 234 func (c *EarlyWebsocketConn) LocalAddr() net.Addr { 235 if c.conn == nil { 236 return M.Socksaddr{} 237 } 238 return c.conn.LocalAddr() 239 } 240 241 func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { 242 if c.conn == nil { 243 return M.Socksaddr{} 244 } 245 return c.conn.RemoteAddr() 246 } 247 248 func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { 249 return os.ErrInvalid 250 } 251 252 func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error { 253 return os.ErrInvalid 254 } 255 256 func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error { 257 return os.ErrInvalid 258 } 259 260 func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool { 261 return true 262 } 263 264 func (c *EarlyWebsocketConn) Upstream() any { 265 return common.PtrOrNil(c.conn) 266 } 267 268 func (c *EarlyWebsocketConn) LazyHeadroom() bool { 269 return c.conn == nil 270 }