github.com/yaling888/clash@v1.53.0/transport/vmess/websocket.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/gorilla/websocket"
    19  
    20  	"github.com/yaling888/clash/common/errors2"
    21  )
    22  
// websocketConn adapts a gorilla/websocket client connection to net.Conn:
// every Write is sent as one binary message, and Read presents the incoming
// messages as a single continuous byte stream.
type websocketConn struct {
	// conn is the established WebSocket connection.
	conn       *websocket.Conn
	// reader is the reader of the message currently being consumed by Read;
	// it is reset to nil once that message is exhausted.
	reader     io.Reader
	// remoteAddr is the remote address of the transport the handshake was
	// performed over, captured when the connection was created.
	remoteAddr net.Addr

	// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
	// rMux serializes Read; wMux serializes Write.
	rMux sync.Mutex
	wMux sync.Mutex
}
    32  
// websocketWithEarlyDataConn is a net.Conn whose WebSocket handshake is
// deferred until the first Write, so that the beginning of that write can be
// carried inside the handshake request as early data.
type websocketWithEarlyDataConn struct {
	// Conn is the established WebSocket connection; it stays nil until Dial
	// succeeds.
	net.Conn
	// underlay is the raw transport the deferred handshake is performed over.
	underlay net.Conn
	// closed is set by Close; Read and Write then fail with io.ErrClosedPipe.
	// NOTE(review): this flag (and Conn) is accessed without synchronization,
	// so concurrent Read/Write/Close may race — confirm callers tolerate this.
	closed   bool
	// dialed receives a value once Dial succeeds, releasing a Read that is
	// waiting for the handshake to complete.
	dialed   chan bool
	// cancel cancels ctx; Close calls it so that a Read blocked waiting on
	// dialed gives up instead of waiting forever.
	cancel   context.CancelFunc
	ctx      context.Context
	// config describes the endpoint and the early-data settings used by Dial.
	config   *WebsocketConfig
}
    42  
// WebsocketConfig describes how a WebSocket client handshake is performed over
// an already established transport connection.
type WebsocketConfig struct {
	// Host and Port are joined with net.JoinHostPort to form the host of the
	// request URL.
	Host                string
	Port                string
	// Path is parsed as a URL; its path and raw query become the request URI.
	// StreamWebsocketConn treats an integer "ed" query parameter as the
	// early-data budget and strips it from the path.
	Path                string
	// Headers are copied into the handshake request.
	Headers             http.Header
	// TLS selects the "wss" scheme, in which case TLSConfig is used as the
	// dialer's TLS client configuration.
	TLS                 bool
	TLSConfig           *tls.Config
	// MaxEarlyData is the maximum number of bytes of the first write that are
	// sent inside the handshake request; values <= 0 disable early data.
	MaxEarlyData        int
	// EarlyDataHeaderName names the request header that carries the
	// base64url-encoded early data. When empty, the encoded early data is
	// appended to the request path instead.
	EarlyDataHeaderName string
}
    53  
    54  // Read implements net.Conn.Read()
    55  func (wsc *websocketConn) Read(b []byte) (int, error) {
    56  	wsc.rMux.Lock()
    57  	defer wsc.rMux.Unlock()
    58  	for {
    59  		reader, err := wsc.getReader()
    60  		if err != nil {
    61  			return 0, err
    62  		}
    63  
    64  		nBytes, err := reader.Read(b)
    65  		if err == io.EOF {
    66  			wsc.reader = nil
    67  			continue
    68  		}
    69  		return nBytes, err
    70  	}
    71  }
    72  
    73  // Write implements io.Writer.
    74  func (wsc *websocketConn) Write(b []byte) (int, error) {
    75  	wsc.wMux.Lock()
    76  	defer wsc.wMux.Unlock()
    77  	if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
    78  		return 0, err
    79  	}
    80  	return len(b), nil
    81  }
    82  
    83  func (wsc *websocketConn) Close() error {
    84  	var errs error
    85  	if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
    86  		errs = errors.Join(errs, err)
    87  	}
    88  	if err := wsc.conn.Close(); err != nil {
    89  		errs = errors.Join(errs, err)
    90  	}
    91  	if errs != nil {
    92  		errs = errors.Join(errors.New("failed to close connection"), errs)
    93  		return errors2.Cause(errs)
    94  	}
    95  	return nil
    96  }
    97  
    98  func (wsc *websocketConn) getReader() (io.Reader, error) {
    99  	if wsc.reader != nil {
   100  		return wsc.reader, nil
   101  	}
   102  
   103  	_, reader, err := wsc.conn.NextReader()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	wsc.reader = reader
   108  	return reader, nil
   109  }
   110  
   111  func (wsc *websocketConn) LocalAddr() net.Addr {
   112  	return wsc.conn.LocalAddr()
   113  }
   114  
   115  func (wsc *websocketConn) RemoteAddr() net.Addr {
   116  	return wsc.remoteAddr
   117  }
   118  
   119  func (wsc *websocketConn) SetDeadline(t time.Time) error {
   120  	if err := wsc.SetReadDeadline(t); err != nil {
   121  		return err
   122  	}
   123  	return wsc.SetWriteDeadline(t)
   124  }
   125  
   126  func (wsc *websocketConn) SetReadDeadline(t time.Time) error {
   127  	return wsc.conn.SetReadDeadline(t)
   128  }
   129  
   130  func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
   131  	return wsc.conn.SetWriteDeadline(t)
   132  }
   133  
   134  func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
   135  	base64DataBuf := &bytes.Buffer{}
   136  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
   137  
   138  	earlyDataBuf := bytes.NewBuffer(earlyData)
   139  	if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil {
   140  		return fmt.Errorf("failed to encode early data: %w", err)
   141  	}
   142  
   143  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   144  		return fmt.Errorf("failed to encode early data tail: %w", errc)
   145  	}
   146  
   147  	var err error
   148  	if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil {
   149  		_ = wsedc.Close()
   150  		return fmt.Errorf("failed to dial WebSocket: %w", err)
   151  	}
   152  
   153  	wsedc.dialed <- true
   154  	if earlyDataBuf.Len() != 0 {
   155  		_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
   156  	}
   157  
   158  	return err
   159  }
   160  
   161  func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
   162  	if wsedc.closed {
   163  		return 0, io.ErrClosedPipe
   164  	}
   165  	if wsedc.Conn == nil {
   166  		if err := wsedc.Dial(b); err != nil {
   167  			return 0, err
   168  		}
   169  		return len(b), nil
   170  	}
   171  
   172  	return wsedc.Conn.Write(b)
   173  }
   174  
   175  func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
   176  	if wsedc.closed {
   177  		return 0, io.ErrClosedPipe
   178  	}
   179  	if wsedc.Conn == nil {
   180  		select {
   181  		case <-wsedc.ctx.Done():
   182  			return 0, io.ErrUnexpectedEOF
   183  		case <-wsedc.dialed:
   184  		}
   185  	}
   186  	return wsedc.Conn.Read(b)
   187  }
   188  
   189  func (wsedc *websocketWithEarlyDataConn) Close() error {
   190  	wsedc.closed = true
   191  	wsedc.cancel()
   192  	if wsedc.Conn == nil {
   193  		return nil
   194  	}
   195  	return wsedc.Conn.Close()
   196  }
   197  
   198  func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
   199  	if wsedc.Conn == nil {
   200  		return wsedc.underlay.LocalAddr()
   201  	}
   202  	return wsedc.Conn.LocalAddr()
   203  }
   204  
   205  func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
   206  	if wsedc.Conn == nil {
   207  		return wsedc.underlay.RemoteAddr()
   208  	}
   209  	return wsedc.Conn.RemoteAddr()
   210  }
   211  
   212  func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
   213  	if err := wsedc.SetReadDeadline(t); err != nil {
   214  		return err
   215  	}
   216  	return wsedc.SetWriteDeadline(t)
   217  }
   218  
   219  func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
   220  	if wsedc.Conn == nil {
   221  		return nil
   222  	}
   223  	return wsedc.Conn.SetReadDeadline(t)
   224  }
   225  
   226  func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
   227  	if wsedc.Conn == nil {
   228  		return nil
   229  	}
   230  	return wsedc.Conn.SetWriteDeadline(t)
   231  }
   232  
   233  func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   234  	ctx, cancel := context.WithCancel(context.Background())
   235  	conn = &websocketWithEarlyDataConn{
   236  		dialed:   make(chan bool, 1),
   237  		cancel:   cancel,
   238  		ctx:      ctx,
   239  		underlay: conn,
   240  		config:   c,
   241  	}
   242  	return conn, nil
   243  }
   244  
   245  func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
   246  	dialer := &websocket.Dialer{
   247  		NetDial: func(network, addr string) (net.Conn, error) {
   248  			return conn, nil
   249  		},
   250  		ReadBufferSize:   4 * 1024,
   251  		WriteBufferSize:  4 * 1024,
   252  		HandshakeTimeout: time.Second * 8,
   253  	}
   254  
   255  	scheme := "ws"
   256  	if c.TLS {
   257  		scheme = "wss"
   258  		dialer.TLSClientConfig = c.TLSConfig
   259  	}
   260  
   261  	u, err := url.Parse(c.Path)
   262  	if err != nil {
   263  		return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
   264  	}
   265  
   266  	uri := url.URL{
   267  		Scheme:   scheme,
   268  		Host:     net.JoinHostPort(c.Host, c.Port),
   269  		Path:     u.Path,
   270  		RawQuery: u.RawQuery,
   271  	}
   272  
   273  	headers := http.Header{}
   274  	if c.Headers != nil {
   275  		for k := range c.Headers {
   276  			headers.Add(k, c.Headers.Get(k))
   277  		}
   278  	}
   279  
   280  	if earlyData != nil {
   281  		if c.EarlyDataHeaderName == "" {
   282  			uri.Path += earlyData.String()
   283  		} else {
   284  			headers.Set(c.EarlyDataHeaderName, earlyData.String())
   285  		}
   286  	}
   287  
   288  	wsConn, resp, err := dialer.Dial(uri.String(), headers)
   289  	if err != nil {
   290  		if resp != nil {
   291  			err = errors.Join(err, errors.New(resp.Status))
   292  		}
   293  		return nil, errors2.Cause(errors.Join(fmt.Errorf("dial %s error", uri.Host), err))
   294  	}
   295  
   296  	return &websocketConn{
   297  		conn:       wsConn,
   298  		remoteAddr: conn.RemoteAddr(),
   299  	}, nil
   300  }
   301  
   302  func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   303  	if u, err := url.Parse(c.Path); err == nil {
   304  		if q := u.Query(); q.Get("ed") != "" {
   305  			if ed, err := strconv.Atoi(q.Get("ed")); err == nil {
   306  				c.MaxEarlyData = ed
   307  				c.EarlyDataHeaderName = "Sec-WebSocket-Protocol"
   308  				q.Del("ed")
   309  				u.RawQuery = q.Encode()
   310  				c.Path = u.String()
   311  			}
   312  		}
   313  	}
   314  
   315  	if c.MaxEarlyData > 0 {
   316  		return streamWebsocketWithEarlyDataConn(conn, c)
   317  	}
   318  
   319  	return streamWebsocketConn(conn, c, nil)
   320  }