github.com/metacubex/mihomo@v1.18.5/transport/vmess/websocket.go (about) 1 package vmess 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/sha1" 8 "crypto/tls" 9 "encoding/base64" 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14 "net" 15 "net/http" 16 "net/url" 17 "strconv" 18 "strings" 19 "time" 20 21 "github.com/metacubex/mihomo/common/buf" 22 N "github.com/metacubex/mihomo/common/net" 23 tlsC "github.com/metacubex/mihomo/component/tls" 24 "github.com/metacubex/mihomo/log" 25 26 "github.com/gobwas/ws" 27 "github.com/gobwas/ws/wsutil" 28 "github.com/zhangyunhao116/fastrand" 29 ) 30 31 type websocketConn struct { 32 net.Conn 33 state ws.State 34 reader *wsutil.Reader 35 controlHandler wsutil.FrameHandlerFunc 36 37 rawWriter N.ExtendedWriter 38 } 39 40 type websocketWithEarlyDataConn struct { 41 net.Conn 42 wsWriter N.ExtendedWriter 43 underlay net.Conn 44 closed bool 45 dialed chan bool 46 cancel context.CancelFunc 47 ctx context.Context 48 config *WebsocketConfig 49 } 50 51 type WebsocketConfig struct { 52 Host string 53 Port string 54 Path string 55 Headers http.Header 56 TLS bool 57 TLSConfig *tls.Config 58 MaxEarlyData int 59 EarlyDataHeaderName string 60 ClientFingerprint string 61 V2rayHttpUpgrade bool 62 V2rayHttpUpgradeFastOpen bool 63 } 64 65 // Read implements net.Conn.Read() 66 // modify from gobwas/ws/wsutil.readData 67 func (wsc *websocketConn) Read(b []byte) (n int, err error) { 68 defer func() { // avoid gobwas/ws pbytes.GetLen panic 69 if value := recover(); value != nil { 70 err = fmt.Errorf("websocket error: %s", value) 71 } 72 }() 73 var header ws.Header 74 for { 75 n, err = wsc.reader.Read(b) 76 // in gobwas/ws: "The error is io.EOF only if all of message bytes were read." 77 // but maybe next frame still have data, so drop it 78 if errors.Is(err, io.EOF) { 79 err = nil 80 } 81 if !errors.Is(err, wsutil.ErrNoFrameAdvance) { 82 return 83 } 84 header, err = wsc.reader.NextFrame() 85 if err != nil { 86 return 87 } 88 if header.OpCode.IsControl() { 89 err = wsc.controlHandler(header, wsc.reader) 90 if err != nil { 91 return 92 } 93 continue 94 } 95 if header.OpCode&(ws.OpBinary|ws.OpText) == 0 { 96 err = wsc.reader.Discard() 97 if err != nil { 98 return 99 } 100 continue 101 } 102 } 103 } 104 105 // Write implements io.Writer. 106 func (wsc *websocketConn) Write(b []byte) (n int, err error) { 107 err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b) 108 if err != nil { 109 return 110 } 111 n = len(b) 112 return 113 } 114 115 func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { 116 var payloadBitLength int 117 dataLen := buffer.Len() 118 data := buffer.Bytes() 119 if dataLen < 126 { 120 payloadBitLength = 1 121 } else if dataLen < 65536 { 122 payloadBitLength = 3 123 } else { 124 payloadBitLength = 9 125 } 126 127 var headerLen int 128 headerLen += 1 // FIN / RSV / OPCODE 129 headerLen += payloadBitLength 130 if wsc.state.ClientSide() { 131 headerLen += 4 // MASK KEY 132 } 133 134 header := buffer.ExtendHeader(headerLen) 135 header[0] = byte(ws.OpBinary) | 0x80 136 if wsc.state.ClientSide() { 137 header[1] = 1 << 7 138 } else { 139 header[1] = 0 140 } 141 142 if dataLen < 126 { 143 header[1] |= byte(dataLen) 144 } else if dataLen < 65536 { 145 header[1] |= 126 146 binary.BigEndian.PutUint16(header[2:], uint16(dataLen)) 147 } else { 148 header[1] |= 127 149 binary.BigEndian.PutUint64(header[2:], uint64(dataLen)) 150 } 151 152 if wsc.state.ClientSide() { 153 maskKey := fastrand.Uint32() 154 binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey) 155 N.MaskWebSocket(maskKey, data) 156 } 157 158 return wsc.rawWriter.WriteBuffer(buffer) 159 } 160 161 func (wsc *websocketConn) FrontHeadroom() int { 162 return 14 163 } 164 165 func (wsc *websocketConn) Upstream() any { 166 return wsc.Conn 167 } 168 169 func (wsc *websocketConn) Close() error { 170 _ = wsc.Conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) 171 _ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, "")) 172 _ = wsc.Conn.Close() 173 return nil 174 } 175 176 func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { 177 base64DataBuf := &bytes.Buffer{} 178 base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) 179 180 earlyDataBuf := bytes.NewBuffer(earlyData) 181 if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { 182 return fmt.Errorf("failed to encode early data: %w", err) 183 } 184 185 if errc := base64EarlyDataEncoder.Close(); errc != nil { 186 return fmt.Errorf("failed to encode early data tail: %w", errc) 187 } 188 189 var err error 190 if wsedc.Conn, err = streamWebsocketConn(wsedc.ctx, wsedc.underlay, wsedc.config, base64DataBuf); err != nil { 191 wsedc.Close() 192 return fmt.Errorf("failed to dial WebSocket: %w", err) 193 } 194 195 wsedc.dialed <- true 196 wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn) 197 if earlyDataBuf.Len() != 0 { 198 _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) 199 } 200 201 return err 202 } 203 204 func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { 205 if wsedc.closed { 206 return 0, io.ErrClosedPipe 207 } 208 if wsedc.Conn == nil { 209 if err := wsedc.Dial(b); err != nil { 210 return 0, err 211 } 212 return len(b), nil 213 } 214 215 return wsedc.Conn.Write(b) 216 } 217 218 func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error { 219 if wsedc.closed { 220 return io.ErrClosedPipe 221 } 222 if wsedc.Conn == nil { 223 if err := wsedc.Dial(buffer.Bytes()); err != nil { 224 return err 225 } 226 return nil 227 } 228 229 return wsedc.wsWriter.WriteBuffer(buffer) 230 } 231 232 func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { 233 if wsedc.closed { 234 return 0, io.ErrClosedPipe 235 } 236 if wsedc.Conn == nil { 237 select { 238 case <-wsedc.ctx.Done(): 239 return 0, io.ErrUnexpectedEOF 240 case <-wsedc.dialed: 241 } 242 } 243 return wsedc.Conn.Read(b) 244 } 245 246 func (wsedc *websocketWithEarlyDataConn) Close() error { 247 wsedc.closed = true 248 wsedc.cancel() 249 if wsedc.Conn == nil { 250 return nil 251 } 252 return wsedc.Conn.Close() 253 } 254 255 func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { 256 if wsedc.Conn == nil { 257 return wsedc.underlay.LocalAddr() 258 } 259 return wsedc.Conn.LocalAddr() 260 } 261 262 func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { 263 if wsedc.Conn == nil { 264 return wsedc.underlay.RemoteAddr() 265 } 266 return wsedc.Conn.RemoteAddr() 267 } 268 269 func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { 270 if err := wsedc.SetReadDeadline(t); err != nil { 271 return err 272 } 273 return wsedc.SetWriteDeadline(t) 274 } 275 276 func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { 277 if wsedc.Conn == nil { 278 return nil 279 } 280 return wsedc.Conn.SetReadDeadline(t) 281 } 282 283 func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { 284 if wsedc.Conn == nil { 285 return nil 286 } 287 return wsedc.Conn.SetWriteDeadline(t) 288 } 289 290 func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int { 291 return 14 292 } 293 294 func (wsedc *websocketWithEarlyDataConn) Upstream() any { 295 return wsedc.underlay 296 } 297 298 //func (wsedc *websocketWithEarlyDataConn) LazyHeadroom() bool { 299 // return wsedc.Conn == nil 300 //} 301 // 302 //func (wsedc *websocketWithEarlyDataConn) Upstream() any { 303 // if wsedc.Conn == nil { // ensure return a nil interface not an interface with nil value 304 // return nil 305 // } 306 // return wsedc.Conn 307 //} 308 309 func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool { 310 return wsedc.Conn == nil 311 } 312 313 func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 314 ctx, cancel := context.WithCancel(context.Background()) 315 conn = &websocketWithEarlyDataConn{ 316 dialed: make(chan bool, 1), 317 cancel: cancel, 318 ctx: ctx, 319 underlay: conn, 320 config: c, 321 } 322 // websocketWithEarlyDataConn can't correct handle Deadline 323 // it will not apply the already set Deadline after Dial() 324 // so call N.NewDeadlineConn to add a safe wrapper 325 return N.NewDeadlineConn(conn), nil 326 } 327 328 func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { 329 u, err := url.Parse(c.Path) 330 if err != nil { 331 return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) 332 } 333 334 uri := url.URL{ 335 Scheme: "ws", 336 Host: net.JoinHostPort(c.Host, c.Port), 337 Path: u.Path, 338 RawQuery: u.RawQuery, 339 } 340 341 if !strings.HasPrefix(uri.Path, "/") { 342 uri.Path = "/" + uri.Path 343 } 344 345 if c.TLS { 346 uri.Scheme = "wss" 347 config := c.TLSConfig 348 if config == nil { // The config cannot be nil 349 config = &tls.Config{NextProtos: []string{"http/1.1"}} 350 } 351 if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config. 352 config = config.Clone() 353 config.ServerName = uri.Host 354 } 355 356 if len(c.ClientFingerprint) != 0 { 357 if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { 358 utlsConn := tlsC.UClient(conn, config, fingerprint) 359 if err = utlsConn.BuildWebsocketHandshakeState(); err != nil { 360 return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) 361 } 362 conn = utlsConn 363 } 364 } else { 365 conn = tls.Client(conn, config) 366 } 367 368 if tlsConn, ok := conn.(interface { 369 HandshakeContext(ctx context.Context) error 370 }); ok { 371 if err = tlsConn.HandshakeContext(ctx); err != nil { 372 return nil, err 373 } 374 } 375 } 376 377 request := &http.Request{ 378 Method: http.MethodGet, 379 URL: &uri, 380 Header: c.Headers.Clone(), 381 Host: c.Host, 382 } 383 384 request.Header.Set("Connection", "Upgrade") 385 request.Header.Set("Upgrade", "websocket") 386 387 if host := request.Header.Get("Host"); host != "" { 388 // For client requests, Host optionally overrides the Host 389 // header to send. If empty, the Request.Write method uses 390 // the value of URL.Host. Host may contain an international 391 // domain name. 392 request.Host = host 393 } 394 request.Header.Del("Host") 395 396 var secKey string 397 if !c.V2rayHttpUpgrade { 398 const nonceKeySize = 16 399 // NOTE: bts does not escape. 400 bts := make([]byte, nonceKeySize) 401 if _, err = fastrand.Read(bts); err != nil { 402 return nil, fmt.Errorf("rand read error: %w", err) 403 } 404 secKey = base64.StdEncoding.EncodeToString(bts) 405 request.Header.Set("Sec-WebSocket-Version", "13") 406 request.Header.Set("Sec-WebSocket-Key", secKey) 407 } 408 409 if earlyData != nil { 410 earlyDataString := earlyData.String() 411 if c.EarlyDataHeaderName == "" { 412 uri.Path += earlyDataString 413 } else { 414 request.Header.Set(c.EarlyDataHeaderName, earlyDataString) 415 } 416 } 417 418 if ctx.Done() != nil { 419 done := N.SetupContextForConn(ctx, conn) 420 defer done(&err) 421 } 422 423 err = request.Write(conn) 424 if err != nil { 425 return nil, err 426 } 427 bufferedConn := N.NewBufferedConn(conn) 428 429 if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen { 430 return N.NewEarlyConn(bufferedConn, func() error { 431 response, err := http.ReadResponse(bufferedConn.Reader(), request) 432 if err != nil { 433 return err 434 } 435 if response.StatusCode != http.StatusSwitchingProtocols || 436 !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || 437 !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { 438 return fmt.Errorf("unexpected status: %s", response.Status) 439 } 440 return nil 441 }), nil 442 } 443 444 response, err := http.ReadResponse(bufferedConn.Reader(), request) 445 if err != nil { 446 return nil, err 447 } 448 if response.StatusCode != http.StatusSwitchingProtocols || 449 !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || 450 !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { 451 return nil, fmt.Errorf("unexpected status: %s", response.Status) 452 } 453 454 if c.V2rayHttpUpgrade { 455 return bufferedConn, nil 456 } 457 458 if log.Level() == log.DEBUG { // we might not check this for performance 459 secAccept := response.Header.Get("Sec-Websocket-Accept") 460 const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) 461 if lenSecAccept := len(secAccept); lenSecAccept != acceptSize { 462 return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept) 463 } 464 if getSecAccept(secKey) != secAccept { 465 return nil, errors.New("unexpected Sec-Websocket-Accept") 466 } 467 } 468 469 conn = newWebsocketConn(conn, ws.StateClientSide) 470 // websocketConn can't correct handle ReadDeadline 471 // so call N.NewDeadlineConn to add a safe wrapper 472 return N.NewDeadlineConn(conn), nil 473 } 474 475 func getSecAccept(secKey string) string { 476 const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 477 const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) 478 p := make([]byte, nonceSize+len(magic)) 479 copy(p[:nonceSize], secKey) 480 copy(p[nonceSize:], magic) 481 sum := sha1.Sum(p) 482 return base64.StdEncoding.EncodeToString(sum[:]) 483 } 484 485 func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 486 if u, err := url.Parse(c.Path); err == nil { 487 if q := u.Query(); q.Get("ed") != "" { 488 if ed, err := strconv.Atoi(q.Get("ed")); err == nil { 489 c.MaxEarlyData = ed 490 c.EarlyDataHeaderName = "Sec-WebSocket-Protocol" 491 q.Del("ed") 492 u.RawQuery = q.Encode() 493 c.Path = u.String() 494 } 495 } 496 } 497 498 if c.MaxEarlyData > 0 { 499 return streamWebsocketWithEarlyDataConn(conn, c) 500 } 501 502 return streamWebsocketConn(ctx, conn, c, nil) 503 } 504 505 func newWebsocketConn(conn net.Conn, state ws.State) *websocketConn { 506 controlHandler := wsutil.ControlFrameHandler(conn, state) 507 return &websocketConn{ 508 Conn: conn, 509 state: state, 510 reader: &wsutil.Reader{ 511 Source: conn, 512 State: state, 513 SkipHeaderCheck: true, 514 CheckUTF8: false, 515 OnIntermediate: controlHandler, 516 }, 517 controlHandler: controlHandler, 518 rawWriter: N.NewExtendedWriter(conn), 519 } 520 } 521 522 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 523 524 func decodeEd(s string) ([]byte, error) { 525 return base64.RawURLEncoding.DecodeString(replacer.Replace(s)) 526 } 527 528 func decodeXray0rtt(requestHeader http.Header) []byte { 529 // read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws 530 if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { 531 if edBuf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode 532 return edBuf 533 } 534 } 535 return nil 536 } 537 538 func IsWebSocketUpgrade(r *http.Request) bool { 539 return r.Header.Get("Upgrade") == "websocket" 540 } 541 542 func IsV2rayHttpUpdate(r *http.Request) bool { 543 return IsWebSocketUpgrade(r) && r.Header.Get("Sec-WebSocket-Key") == "" 544 } 545 546 func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) { 547 var conn net.Conn 548 var rw *bufio.ReadWriter 549 var err error 550 isRaw := IsV2rayHttpUpdate(r) 551 w.Header().Set("Connection", "upgrade") 552 w.Header().Set("Upgrade", "websocket") 553 if !isRaw { 554 w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key"))) 555 } 556 w.WriteHeader(http.StatusSwitchingProtocols) 557 if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher { 558 err = flusher.FlushError() 559 if err != nil { 560 return nil, fmt.Errorf("flush response: %w", err) 561 } 562 } 563 hijacker, canHijack := w.(http.Hijacker) 564 if !canHijack { 565 return nil, errors.New("invalid connection, maybe HTTP/2") 566 } 567 conn, rw, err = hijacker.Hijack() 568 if err != nil { 569 return nil, fmt.Errorf("hijack failed: %w", err) 570 } 571 572 // rw.Writer was flushed, so we only need warp rw.Reader 573 conn = N.WarpConnWithBioReader(conn, rw.Reader) 574 575 if !isRaw { 576 conn = newWebsocketConn(conn, ws.StateServerSide) 577 // websocketConn can't correct handle ReadDeadline 578 // so call N.NewDeadlineConn to add a safe wrapper 579 conn = N.NewDeadlineConn(conn) 580 } 581 582 if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 { 583 appendOk := false 584 if bufConn, ok := conn.(*N.BufferedConn); ok { 585 appendOk = bufConn.AppendData(edBuf) 586 } 587 if !appendOk { 588 conn = N.NewCachedConn(conn, edBuf) 589 } 590 591 } 592 593 return conn, nil 594 }