github.com/yaling888/clash@v1.53.0/transport/vmess/websocket.go (about) 1 package vmess 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "net/http" 13 "net/url" 14 "strconv" 15 "sync" 16 "time" 17 18 "github.com/gorilla/websocket" 19 20 "github.com/yaling888/clash/common/errors2" 21 ) 22 23 type websocketConn struct { 24 conn *websocket.Conn 25 reader io.Reader 26 remoteAddr net.Addr 27 28 // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency 29 rMux sync.Mutex 30 wMux sync.Mutex 31 } 32 33 type websocketWithEarlyDataConn struct { 34 net.Conn 35 underlay net.Conn 36 closed bool 37 dialed chan bool 38 cancel context.CancelFunc 39 ctx context.Context 40 config *WebsocketConfig 41 } 42 43 type WebsocketConfig struct { 44 Host string 45 Port string 46 Path string 47 Headers http.Header 48 TLS bool 49 TLSConfig *tls.Config 50 MaxEarlyData int 51 EarlyDataHeaderName string 52 } 53 54 // Read implements net.Conn.Read() 55 func (wsc *websocketConn) Read(b []byte) (int, error) { 56 wsc.rMux.Lock() 57 defer wsc.rMux.Unlock() 58 for { 59 reader, err := wsc.getReader() 60 if err != nil { 61 return 0, err 62 } 63 64 nBytes, err := reader.Read(b) 65 if err == io.EOF { 66 wsc.reader = nil 67 continue 68 } 69 return nBytes, err 70 } 71 } 72 73 // Write implements io.Writer. 74 func (wsc *websocketConn) Write(b []byte) (int, error) { 75 wsc.wMux.Lock() 76 defer wsc.wMux.Unlock() 77 if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { 78 return 0, err 79 } 80 return len(b), nil 81 } 82 83 func (wsc *websocketConn) Close() error { 84 var errs error 85 if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { 86 errs = errors.Join(errs, err) 87 } 88 if err := wsc.conn.Close(); err != nil { 89 errs = errors.Join(errs, err) 90 } 91 if errs != nil { 92 errs = errors.Join(errors.New("failed to close connection"), errs) 93 return errors2.Cause(errs) 94 } 95 return nil 96 } 97 98 func (wsc *websocketConn) getReader() (io.Reader, error) { 99 if wsc.reader != nil { 100 return wsc.reader, nil 101 } 102 103 _, reader, err := wsc.conn.NextReader() 104 if err != nil { 105 return nil, err 106 } 107 wsc.reader = reader 108 return reader, nil 109 } 110 111 func (wsc *websocketConn) LocalAddr() net.Addr { 112 return wsc.conn.LocalAddr() 113 } 114 115 func (wsc *websocketConn) RemoteAddr() net.Addr { 116 return wsc.remoteAddr 117 } 118 119 func (wsc *websocketConn) SetDeadline(t time.Time) error { 120 if err := wsc.SetReadDeadline(t); err != nil { 121 return err 122 } 123 return wsc.SetWriteDeadline(t) 124 } 125 126 func (wsc *websocketConn) SetReadDeadline(t time.Time) error { 127 return wsc.conn.SetReadDeadline(t) 128 } 129 130 func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { 131 return wsc.conn.SetWriteDeadline(t) 132 } 133 134 func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { 135 base64DataBuf := &bytes.Buffer{} 136 base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) 137 138 earlyDataBuf := bytes.NewBuffer(earlyData) 139 if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { 140 return fmt.Errorf("failed to encode early data: %w", err) 141 } 142 143 if errc := base64EarlyDataEncoder.Close(); errc != nil { 144 return fmt.Errorf("failed to encode early data tail: %w", errc) 145 } 146 147 var err error 148 if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil { 149 _ = wsedc.Close() 150 return fmt.Errorf("failed to dial WebSocket: %w", err) 151 } 152 153 wsedc.dialed <- true 154 if earlyDataBuf.Len() != 0 { 155 _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) 156 } 157 158 return err 159 } 160 161 func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { 162 if wsedc.closed { 163 return 0, io.ErrClosedPipe 164 } 165 if wsedc.Conn == nil { 166 if err := wsedc.Dial(b); err != nil { 167 return 0, err 168 } 169 return len(b), nil 170 } 171 172 return wsedc.Conn.Write(b) 173 } 174 175 func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { 176 if wsedc.closed { 177 return 0, io.ErrClosedPipe 178 } 179 if wsedc.Conn == nil { 180 select { 181 case <-wsedc.ctx.Done(): 182 return 0, io.ErrUnexpectedEOF 183 case <-wsedc.dialed: 184 } 185 } 186 return wsedc.Conn.Read(b) 187 } 188 189 func (wsedc *websocketWithEarlyDataConn) Close() error { 190 wsedc.closed = true 191 wsedc.cancel() 192 if wsedc.Conn == nil { 193 return nil 194 } 195 return wsedc.Conn.Close() 196 } 197 198 func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { 199 if wsedc.Conn == nil { 200 return wsedc.underlay.LocalAddr() 201 } 202 return wsedc.Conn.LocalAddr() 203 } 204 205 func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { 206 if wsedc.Conn == nil { 207 return wsedc.underlay.RemoteAddr() 208 } 209 return wsedc.Conn.RemoteAddr() 210 } 211 212 func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { 213 if err := wsedc.SetReadDeadline(t); err != nil { 214 return err 215 } 216 return wsedc.SetWriteDeadline(t) 217 } 218 219 func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { 220 if wsedc.Conn == nil { 221 return nil 222 } 223 return wsedc.Conn.SetReadDeadline(t) 224 } 225 226 func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { 227 if wsedc.Conn == nil { 228 return nil 229 } 230 return wsedc.Conn.SetWriteDeadline(t) 231 } 232 233 func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 234 ctx, cancel := context.WithCancel(context.Background()) 235 conn = &websocketWithEarlyDataConn{ 236 dialed: make(chan bool, 1), 237 cancel: cancel, 238 ctx: ctx, 239 underlay: conn, 240 config: c, 241 } 242 return conn, nil 243 } 244 245 func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { 246 dialer := &websocket.Dialer{ 247 NetDial: func(network, addr string) (net.Conn, error) { 248 return conn, nil 249 }, 250 ReadBufferSize: 4 * 1024, 251 WriteBufferSize: 4 * 1024, 252 HandshakeTimeout: time.Second * 8, 253 } 254 255 scheme := "ws" 256 if c.TLS { 257 scheme = "wss" 258 dialer.TLSClientConfig = c.TLSConfig 259 } 260 261 u, err := url.Parse(c.Path) 262 if err != nil { 263 return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) 264 } 265 266 uri := url.URL{ 267 Scheme: scheme, 268 Host: net.JoinHostPort(c.Host, c.Port), 269 Path: u.Path, 270 RawQuery: u.RawQuery, 271 } 272 273 headers := http.Header{} 274 if c.Headers != nil { 275 for k := range c.Headers { 276 headers.Add(k, c.Headers.Get(k)) 277 } 278 } 279 280 if earlyData != nil { 281 if c.EarlyDataHeaderName == "" { 282 uri.Path += earlyData.String() 283 } else { 284 headers.Set(c.EarlyDataHeaderName, earlyData.String()) 285 } 286 } 287 288 wsConn, resp, err := dialer.Dial(uri.String(), headers) 289 if err != nil { 290 if resp != nil { 291 err = errors.Join(err, errors.New(resp.Status)) 292 } 293 return nil, errors2.Cause(errors.Join(fmt.Errorf("dial %s error", uri.Host), err)) 294 } 295 296 return &websocketConn{ 297 conn: wsConn, 298 remoteAddr: conn.RemoteAddr(), 299 }, nil 300 } 301 302 func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 303 if u, err := url.Parse(c.Path); err == nil { 304 if q := u.Query(); q.Get("ed") != "" { 305 if ed, err := strconv.Atoi(q.Get("ed")); err == nil { 306 c.MaxEarlyData = ed 307 c.EarlyDataHeaderName = "Sec-WebSocket-Protocol" 308 q.Del("ed") 309 u.RawQuery = q.Encode() 310 c.Path = u.String() 311 } 312 } 313 } 314 315 if c.MaxEarlyData > 0 { 316 return streamWebsocketWithEarlyDataConn(conn, c) 317 } 318 319 return streamWebsocketConn(conn, c, nil) 320 }