github.com/metacubex/mihomo@v1.18.5/transport/vmess/websocket.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/sha1"
     8  	"crypto/tls"
     9  	"encoding/base64"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/http"
    16  	"net/url"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/metacubex/mihomo/common/buf"
    22  	N "github.com/metacubex/mihomo/common/net"
    23  	tlsC "github.com/metacubex/mihomo/component/tls"
    24  	"github.com/metacubex/mihomo/log"
    25  
    26  	"github.com/gobwas/ws"
    27  	"github.com/gobwas/ws/wsutil"
    28  	"github.com/zhangyunhao116/fastrand"
    29  )
    30  
// websocketConn presents a WebSocket connection as a net.Conn: Read yields the
// payload of incoming data frames and Write/WriteBuffer send each payload as
// one binary frame.
type websocketConn struct {
	// Conn is the transport the WebSocket frames are written to.
	net.Conn
	// state is handed to wsutil when writing and decides, via ClientSide(),
	// whether outgoing frames built by WriteBuffer are masked.
	state ws.State
	// reader parses incoming frames; Read pulls payload bytes from it.
	reader *wsutil.Reader
	// controlHandler is invoked by Read for control frames.
	controlHandler wsutil.FrameHandlerFunc

	// rawWriter sends buffers that WriteBuffer has already framed in place.
	rawWriter N.ExtendedWriter
}
    39  
// websocketWithEarlyDataConn delays the WebSocket handshake until the first
// write so that the beginning of that write can be carried inside the
// handshake request (see Dial).
type websocketWithEarlyDataConn struct {
	// Conn is nil until Dial has completed the handshake.
	net.Conn
	// wsWriter wraps Conn once dialed and serves WriteBuffer.
	wsWriter N.ExtendedWriter
	// underlay is the connection the handshake is performed over.
	underlay net.Conn
	// closed is set by Close; Read, Write and WriteBuffer then fail with
	// io.ErrClosedPipe.
	closed bool
	// dialed receives a value once Dial succeeds; Read waits on it while
	// Conn is still nil.
	dialed chan bool
	// cancel cancels ctx; it is called from Close.
	cancel context.CancelFunc
	// ctx is passed to the handshake and unblocks a Read that is waiting for
	// the dial when it is cancelled.
	ctx context.Context
	// config holds the handshake parameters used by Dial.
	config *WebsocketConfig
}
    50  
// WebsocketConfig describes how the client side performs the WebSocket (or
// V2Ray HTTP upgrade) handshake.
type WebsocketConfig struct {
	// Host and Port are joined to form the request URL host; Host is also the
	// default Host of the request.
	Host string
	Port string
	// Path is the request path and may carry a query string. An integer "ed"
	// query parameter is consumed by StreamWebsocketConn to enable early data.
	Path string
	// Headers are cloned into the upgrade request. A "Host" entry overrides
	// the request Host instead of being sent as a regular header.
	Headers http.Header
	// TLS wraps the connection in TLS and switches the scheme to "wss".
	TLS bool
	// TLSConfig is optional; when nil a config offering only "http/1.1" via
	// ALPN is used.
	TLSConfig *tls.Config
	// MaxEarlyData, when positive, defers the handshake to the first write and
	// bounds how many bytes of that write travel inside the handshake.
	MaxEarlyData int
	// EarlyDataHeaderName is the request header that carries the encoded early
	// data; when empty the data is appended to the request path instead.
	EarlyDataHeaderName string
	// ClientFingerprint, when non-empty, is looked up with tlsC.GetFingerprint
	// to select a uTLS client hello.
	ClientFingerprint string
	// V2rayHttpUpgrade omits the Sec-WebSocket-* request headers and returns
	// the raw buffered connection after the 101 response (no WebSocket framing).
	V2rayHttpUpgrade bool
	// V2rayHttpUpgradeFastOpen, together with V2rayHttpUpgrade, returns the
	// connection right after the request is written and validates the response
	// lazily.
	V2rayHttpUpgradeFastOpen bool
}
    64  
    65  // Read implements net.Conn.Read()
    66  // modify from gobwas/ws/wsutil.readData
    67  func (wsc *websocketConn) Read(b []byte) (n int, err error) {
    68  	defer func() { // avoid gobwas/ws pbytes.GetLen panic
    69  		if value := recover(); value != nil {
    70  			err = fmt.Errorf("websocket error: %s", value)
    71  		}
    72  	}()
    73  	var header ws.Header
    74  	for {
    75  		n, err = wsc.reader.Read(b)
    76  		// in gobwas/ws: "The error is io.EOF only if all of message bytes were read."
    77  		// but maybe next frame still have data, so drop it
    78  		if errors.Is(err, io.EOF) {
    79  			err = nil
    80  		}
    81  		if !errors.Is(err, wsutil.ErrNoFrameAdvance) {
    82  			return
    83  		}
    84  		header, err = wsc.reader.NextFrame()
    85  		if err != nil {
    86  			return
    87  		}
    88  		if header.OpCode.IsControl() {
    89  			err = wsc.controlHandler(header, wsc.reader)
    90  			if err != nil {
    91  				return
    92  			}
    93  			continue
    94  		}
    95  		if header.OpCode&(ws.OpBinary|ws.OpText) == 0 {
    96  			err = wsc.reader.Discard()
    97  			if err != nil {
    98  				return
    99  			}
   100  			continue
   101  		}
   102  	}
   103  }
   104  
   105  // Write implements io.Writer.
   106  func (wsc *websocketConn) Write(b []byte) (n int, err error) {
   107  	err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b)
   108  	if err != nil {
   109  		return
   110  	}
   111  	n = len(b)
   112  	return
   113  }
   114  
   115  func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
   116  	var payloadBitLength int
   117  	dataLen := buffer.Len()
   118  	data := buffer.Bytes()
   119  	if dataLen < 126 {
   120  		payloadBitLength = 1
   121  	} else if dataLen < 65536 {
   122  		payloadBitLength = 3
   123  	} else {
   124  		payloadBitLength = 9
   125  	}
   126  
   127  	var headerLen int
   128  	headerLen += 1 // FIN / RSV / OPCODE
   129  	headerLen += payloadBitLength
   130  	if wsc.state.ClientSide() {
   131  		headerLen += 4 // MASK KEY
   132  	}
   133  
   134  	header := buffer.ExtendHeader(headerLen)
   135  	header[0] = byte(ws.OpBinary) | 0x80
   136  	if wsc.state.ClientSide() {
   137  		header[1] = 1 << 7
   138  	} else {
   139  		header[1] = 0
   140  	}
   141  
   142  	if dataLen < 126 {
   143  		header[1] |= byte(dataLen)
   144  	} else if dataLen < 65536 {
   145  		header[1] |= 126
   146  		binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
   147  	} else {
   148  		header[1] |= 127
   149  		binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
   150  	}
   151  
   152  	if wsc.state.ClientSide() {
   153  		maskKey := fastrand.Uint32()
   154  		binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey)
   155  		N.MaskWebSocket(maskKey, data)
   156  	}
   157  
   158  	return wsc.rawWriter.WriteBuffer(buffer)
   159  }
   160  
   161  func (wsc *websocketConn) FrontHeadroom() int {
   162  	return 14
   163  }
   164  
   165  func (wsc *websocketConn) Upstream() any {
   166  	return wsc.Conn
   167  }
   168  
   169  func (wsc *websocketConn) Close() error {
   170  	_ = wsc.Conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   171  	_ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, ""))
   172  	_ = wsc.Conn.Close()
   173  	return nil
   174  }
   175  
   176  func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
   177  	base64DataBuf := &bytes.Buffer{}
   178  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
   179  
   180  	earlyDataBuf := bytes.NewBuffer(earlyData)
   181  	if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil {
   182  		return fmt.Errorf("failed to encode early data: %w", err)
   183  	}
   184  
   185  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   186  		return fmt.Errorf("failed to encode early data tail: %w", errc)
   187  	}
   188  
   189  	var err error
   190  	if wsedc.Conn, err = streamWebsocketConn(wsedc.ctx, wsedc.underlay, wsedc.config, base64DataBuf); err != nil {
   191  		wsedc.Close()
   192  		return fmt.Errorf("failed to dial WebSocket: %w", err)
   193  	}
   194  
   195  	wsedc.dialed <- true
   196  	wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn)
   197  	if earlyDataBuf.Len() != 0 {
   198  		_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
   199  	}
   200  
   201  	return err
   202  }
   203  
   204  func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
   205  	if wsedc.closed {
   206  		return 0, io.ErrClosedPipe
   207  	}
   208  	if wsedc.Conn == nil {
   209  		if err := wsedc.Dial(b); err != nil {
   210  			return 0, err
   211  		}
   212  		return len(b), nil
   213  	}
   214  
   215  	return wsedc.Conn.Write(b)
   216  }
   217  
   218  func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
   219  	if wsedc.closed {
   220  		return io.ErrClosedPipe
   221  	}
   222  	if wsedc.Conn == nil {
   223  		if err := wsedc.Dial(buffer.Bytes()); err != nil {
   224  			return err
   225  		}
   226  		return nil
   227  	}
   228  
   229  	return wsedc.wsWriter.WriteBuffer(buffer)
   230  }
   231  
   232  func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
   233  	if wsedc.closed {
   234  		return 0, io.ErrClosedPipe
   235  	}
   236  	if wsedc.Conn == nil {
   237  		select {
   238  		case <-wsedc.ctx.Done():
   239  			return 0, io.ErrUnexpectedEOF
   240  		case <-wsedc.dialed:
   241  		}
   242  	}
   243  	return wsedc.Conn.Read(b)
   244  }
   245  
   246  func (wsedc *websocketWithEarlyDataConn) Close() error {
   247  	wsedc.closed = true
   248  	wsedc.cancel()
   249  	if wsedc.Conn == nil {
   250  		return nil
   251  	}
   252  	return wsedc.Conn.Close()
   253  }
   254  
   255  func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
   256  	if wsedc.Conn == nil {
   257  		return wsedc.underlay.LocalAddr()
   258  	}
   259  	return wsedc.Conn.LocalAddr()
   260  }
   261  
   262  func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
   263  	if wsedc.Conn == nil {
   264  		return wsedc.underlay.RemoteAddr()
   265  	}
   266  	return wsedc.Conn.RemoteAddr()
   267  }
   268  
   269  func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
   270  	if err := wsedc.SetReadDeadline(t); err != nil {
   271  		return err
   272  	}
   273  	return wsedc.SetWriteDeadline(t)
   274  }
   275  
   276  func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
   277  	if wsedc.Conn == nil {
   278  		return nil
   279  	}
   280  	return wsedc.Conn.SetReadDeadline(t)
   281  }
   282  
   283  func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
   284  	if wsedc.Conn == nil {
   285  		return nil
   286  	}
   287  	return wsedc.Conn.SetWriteDeadline(t)
   288  }
   289  
   290  func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int {
   291  	return 14
   292  }
   293  
   294  func (wsedc *websocketWithEarlyDataConn) Upstream() any {
   295  	return wsedc.underlay
   296  }
   297  
   298  //func (wsedc *websocketWithEarlyDataConn) LazyHeadroom() bool {
   299  //	return wsedc.Conn == nil
   300  //}
   301  //
   302  //func (wsedc *websocketWithEarlyDataConn) Upstream() any {
   303  //	if wsedc.Conn == nil { // ensure return a nil interface not an interface with nil value
   304  //		return nil
   305  //	}
   306  //	return wsedc.Conn
   307  //}
   308  
   309  func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool {
   310  	return wsedc.Conn == nil
   311  }
   312  
   313  func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   314  	ctx, cancel := context.WithCancel(context.Background())
   315  	conn = &websocketWithEarlyDataConn{
   316  		dialed:   make(chan bool, 1),
   317  		cancel:   cancel,
   318  		ctx:      ctx,
   319  		underlay: conn,
   320  		config:   c,
   321  	}
   322  	// websocketWithEarlyDataConn can't correct handle Deadline
   323  	// it will not apply the already set Deadline after Dial()
   324  	// so call N.NewDeadlineConn to add a safe wrapper
   325  	return N.NewDeadlineConn(conn), nil
   326  }
   327  
   328  func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
   329  	u, err := url.Parse(c.Path)
   330  	if err != nil {
   331  		return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
   332  	}
   333  
   334  	uri := url.URL{
   335  		Scheme:   "ws",
   336  		Host:     net.JoinHostPort(c.Host, c.Port),
   337  		Path:     u.Path,
   338  		RawQuery: u.RawQuery,
   339  	}
   340  
   341  	if !strings.HasPrefix(uri.Path, "/") {
   342  		uri.Path = "/" + uri.Path
   343  	}
   344  
   345  	if c.TLS {
   346  		uri.Scheme = "wss"
   347  		config := c.TLSConfig
   348  		if config == nil { // The config cannot be nil
   349  			config = &tls.Config{NextProtos: []string{"http/1.1"}}
   350  		}
   351  		if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config.
   352  			config = config.Clone()
   353  			config.ServerName = uri.Host
   354  		}
   355  
   356  		if len(c.ClientFingerprint) != 0 {
   357  			if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists {
   358  				utlsConn := tlsC.UClient(conn, config, fingerprint)
   359  				if err = utlsConn.BuildWebsocketHandshakeState(); err != nil {
   360  					return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
   361  				}
   362  				conn = utlsConn
   363  			}
   364  		} else {
   365  			conn = tls.Client(conn, config)
   366  		}
   367  
   368  		if tlsConn, ok := conn.(interface {
   369  			HandshakeContext(ctx context.Context) error
   370  		}); ok {
   371  			if err = tlsConn.HandshakeContext(ctx); err != nil {
   372  				return nil, err
   373  			}
   374  		}
   375  	}
   376  
   377  	request := &http.Request{
   378  		Method: http.MethodGet,
   379  		URL:    &uri,
   380  		Header: c.Headers.Clone(),
   381  		Host:   c.Host,
   382  	}
   383  
   384  	request.Header.Set("Connection", "Upgrade")
   385  	request.Header.Set("Upgrade", "websocket")
   386  
   387  	if host := request.Header.Get("Host"); host != "" {
   388  		// For client requests, Host optionally overrides the Host
   389  		// header to send. If empty, the Request.Write method uses
   390  		// the value of URL.Host. Host may contain an international
   391  		// domain name.
   392  		request.Host = host
   393  	}
   394  	request.Header.Del("Host")
   395  
   396  	var secKey string
   397  	if !c.V2rayHttpUpgrade {
   398  		const nonceKeySize = 16
   399  		// NOTE: bts does not escape.
   400  		bts := make([]byte, nonceKeySize)
   401  		if _, err = fastrand.Read(bts); err != nil {
   402  			return nil, fmt.Errorf("rand read error: %w", err)
   403  		}
   404  		secKey = base64.StdEncoding.EncodeToString(bts)
   405  		request.Header.Set("Sec-WebSocket-Version", "13")
   406  		request.Header.Set("Sec-WebSocket-Key", secKey)
   407  	}
   408  
   409  	if earlyData != nil {
   410  		earlyDataString := earlyData.String()
   411  		if c.EarlyDataHeaderName == "" {
   412  			uri.Path += earlyDataString
   413  		} else {
   414  			request.Header.Set(c.EarlyDataHeaderName, earlyDataString)
   415  		}
   416  	}
   417  
   418  	if ctx.Done() != nil {
   419  		done := N.SetupContextForConn(ctx, conn)
   420  		defer done(&err)
   421  	}
   422  
   423  	err = request.Write(conn)
   424  	if err != nil {
   425  		return nil, err
   426  	}
   427  	bufferedConn := N.NewBufferedConn(conn)
   428  
   429  	if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen {
   430  		return N.NewEarlyConn(bufferedConn, func() error {
   431  			response, err := http.ReadResponse(bufferedConn.Reader(), request)
   432  			if err != nil {
   433  				return err
   434  			}
   435  			if response.StatusCode != http.StatusSwitchingProtocols ||
   436  				!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
   437  				!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
   438  				return fmt.Errorf("unexpected status: %s", response.Status)
   439  			}
   440  			return nil
   441  		}), nil
   442  	}
   443  
   444  	response, err := http.ReadResponse(bufferedConn.Reader(), request)
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  	if response.StatusCode != http.StatusSwitchingProtocols ||
   449  		!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
   450  		!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
   451  		return nil, fmt.Errorf("unexpected status: %s", response.Status)
   452  	}
   453  
   454  	if c.V2rayHttpUpgrade {
   455  		return bufferedConn, nil
   456  	}
   457  
   458  	if log.Level() == log.DEBUG { // we might not check this for performance
   459  		secAccept := response.Header.Get("Sec-Websocket-Accept")
   460  		const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
   461  		if lenSecAccept := len(secAccept); lenSecAccept != acceptSize {
   462  			return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept)
   463  		}
   464  		if getSecAccept(secKey) != secAccept {
   465  			return nil, errors.New("unexpected Sec-Websocket-Accept")
   466  		}
   467  	}
   468  
   469  	conn = newWebsocketConn(conn, ws.StateClientSide)
   470  	// websocketConn can't correct handle ReadDeadline
   471  	// so call N.NewDeadlineConn to add a safe wrapper
   472  	return N.NewDeadlineConn(conn), nil
   473  }
   474  
   475  func getSecAccept(secKey string) string {
   476  	const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
   477  	const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
   478  	p := make([]byte, nonceSize+len(magic))
   479  	copy(p[:nonceSize], secKey)
   480  	copy(p[nonceSize:], magic)
   481  	sum := sha1.Sum(p)
   482  	return base64.StdEncoding.EncodeToString(sum[:])
   483  }
   484  
   485  func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   486  	if u, err := url.Parse(c.Path); err == nil {
   487  		if q := u.Query(); q.Get("ed") != "" {
   488  			if ed, err := strconv.Atoi(q.Get("ed")); err == nil {
   489  				c.MaxEarlyData = ed
   490  				c.EarlyDataHeaderName = "Sec-WebSocket-Protocol"
   491  				q.Del("ed")
   492  				u.RawQuery = q.Encode()
   493  				c.Path = u.String()
   494  			}
   495  		}
   496  	}
   497  
   498  	if c.MaxEarlyData > 0 {
   499  		return streamWebsocketWithEarlyDataConn(conn, c)
   500  	}
   501  
   502  	return streamWebsocketConn(ctx, conn, c, nil)
   503  }
   504  
   505  func newWebsocketConn(conn net.Conn, state ws.State) *websocketConn {
   506  	controlHandler := wsutil.ControlFrameHandler(conn, state)
   507  	return &websocketConn{
   508  		Conn:  conn,
   509  		state: state,
   510  		reader: &wsutil.Reader{
   511  			Source:          conn,
   512  			State:           state,
   513  			SkipHeaderCheck: true,
   514  			CheckUTF8:       false,
   515  			OnIntermediate:  controlHandler,
   516  		},
   517  		controlHandler: controlHandler,
   518  		rawWriter:      N.NewExtendedWriter(conn),
   519  	}
   520  }
   521  
   522  var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "")
   523  
   524  func decodeEd(s string) ([]byte, error) {
   525  	return base64.RawURLEncoding.DecodeString(replacer.Replace(s))
   526  }
   527  
   528  func decodeXray0rtt(requestHeader http.Header) []byte {
   529  	// read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws
   530  	if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
   531  		if edBuf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode
   532  			return edBuf
   533  		}
   534  	}
   535  	return nil
   536  }
   537  
   538  func IsWebSocketUpgrade(r *http.Request) bool {
   539  	return r.Header.Get("Upgrade") == "websocket"
   540  }
   541  
   542  func IsV2rayHttpUpdate(r *http.Request) bool {
   543  	return IsWebSocketUpgrade(r) && r.Header.Get("Sec-WebSocket-Key") == ""
   544  }
   545  
   546  func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
   547  	var conn net.Conn
   548  	var rw *bufio.ReadWriter
   549  	var err error
   550  	isRaw := IsV2rayHttpUpdate(r)
   551  	w.Header().Set("Connection", "upgrade")
   552  	w.Header().Set("Upgrade", "websocket")
   553  	if !isRaw {
   554  		w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key")))
   555  	}
   556  	w.WriteHeader(http.StatusSwitchingProtocols)
   557  	if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher {
   558  		err = flusher.FlushError()
   559  		if err != nil {
   560  			return nil, fmt.Errorf("flush response: %w", err)
   561  		}
   562  	}
   563  	hijacker, canHijack := w.(http.Hijacker)
   564  	if !canHijack {
   565  		return nil, errors.New("invalid connection, maybe HTTP/2")
   566  	}
   567  	conn, rw, err = hijacker.Hijack()
   568  	if err != nil {
   569  		return nil, fmt.Errorf("hijack failed: %w", err)
   570  	}
   571  
   572  	// rw.Writer was flushed, so we only need warp rw.Reader
   573  	conn = N.WarpConnWithBioReader(conn, rw.Reader)
   574  
   575  	if !isRaw {
   576  		conn = newWebsocketConn(conn, ws.StateServerSide)
   577  		// websocketConn can't correct handle ReadDeadline
   578  		// so call N.NewDeadlineConn to add a safe wrapper
   579  		conn = N.NewDeadlineConn(conn)
   580  	}
   581  
   582  	if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 {
   583  		appendOk := false
   584  		if bufConn, ok := conn.(*N.BufferedConn); ok {
   585  			appendOk = bufConn.AppendData(edBuf)
   586  		}
   587  		if !appendOk {
   588  			conn = N.NewCachedConn(conn, edBuf)
   589  		}
   590  
   591  	}
   592  
   593  	return conn, nil
   594  }