github.com/igoogolx/clash@v1.19.8/transport/vmess/websocket.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/gorilla/websocket"
    20  )
    21  
    22  type websocketConn struct {
    23  	conn       *websocket.Conn
    24  	reader     io.Reader
    25  	remoteAddr net.Addr
    26  
    27  	// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
    28  	rMux sync.Mutex
    29  	wMux sync.Mutex
    30  }
    31  
    32  type websocketWithEarlyDataConn struct {
    33  	net.Conn
    34  	underlay net.Conn
    35  	closed   bool
    36  	dialed   chan bool
    37  	cancel   context.CancelFunc
    38  	ctx      context.Context
    39  	config   *WebsocketConfig
    40  }
    41  
    42  type WebsocketConfig struct {
    43  	Host                string
    44  	Port                string
    45  	Path                string
    46  	Headers             http.Header
    47  	TLS                 bool
    48  	TLSConfig           *tls.Config
    49  	MaxEarlyData        int
    50  	EarlyDataHeaderName string
    51  }
    52  
    53  // Read implements net.Conn.Read()
    54  func (wsc *websocketConn) Read(b []byte) (int, error) {
    55  	wsc.rMux.Lock()
    56  	defer wsc.rMux.Unlock()
    57  	for {
    58  		reader, err := wsc.getReader()
    59  		if err != nil {
    60  			return 0, err
    61  		}
    62  
    63  		nBytes, err := reader.Read(b)
    64  		if err == io.EOF {
    65  			wsc.reader = nil
    66  			continue
    67  		}
    68  		return nBytes, err
    69  	}
    70  }
    71  
    72  // Write implements io.Writer.
    73  func (wsc *websocketConn) Write(b []byte) (int, error) {
    74  	wsc.wMux.Lock()
    75  	defer wsc.wMux.Unlock()
    76  	if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
    77  		return 0, err
    78  	}
    79  	return len(b), nil
    80  }
    81  
    82  func (wsc *websocketConn) Close() error {
    83  	var errors []string
    84  	if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
    85  		errors = append(errors, err.Error())
    86  	}
    87  	if err := wsc.conn.Close(); err != nil {
    88  		errors = append(errors, err.Error())
    89  	}
    90  	if len(errors) > 0 {
    91  		return fmt.Errorf("failed to close connection: %s", strings.Join(errors, ","))
    92  	}
    93  	return nil
    94  }
    95  
    96  func (wsc *websocketConn) getReader() (io.Reader, error) {
    97  	if wsc.reader != nil {
    98  		return wsc.reader, nil
    99  	}
   100  
   101  	_, reader, err := wsc.conn.NextReader()
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	wsc.reader = reader
   106  	return reader, nil
   107  }
   108  
   109  func (wsc *websocketConn) LocalAddr() net.Addr {
   110  	return wsc.conn.LocalAddr()
   111  }
   112  
   113  func (wsc *websocketConn) RemoteAddr() net.Addr {
   114  	return wsc.remoteAddr
   115  }
   116  
   117  func (wsc *websocketConn) SetDeadline(t time.Time) error {
   118  	if err := wsc.SetReadDeadline(t); err != nil {
   119  		return err
   120  	}
   121  	return wsc.SetWriteDeadline(t)
   122  }
   123  
   124  func (wsc *websocketConn) SetReadDeadline(t time.Time) error {
   125  	return wsc.conn.SetReadDeadline(t)
   126  }
   127  
   128  func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
   129  	return wsc.conn.SetWriteDeadline(t)
   130  }
   131  
   132  func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
   133  	base64DataBuf := &bytes.Buffer{}
   134  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
   135  
   136  	earlyDataBuf := bytes.NewBuffer(earlyData)
   137  	if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil {
   138  		return errors.New("failed to encode early data: " + err.Error())
   139  	}
   140  
   141  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   142  		return errors.New("failed to encode early data tail: " + errc.Error())
   143  	}
   144  
   145  	var err error
   146  	if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil {
   147  		wsedc.Close()
   148  		return errors.New("failed to dial WebSocket: " + err.Error())
   149  	}
   150  
   151  	wsedc.dialed <- true
   152  	if earlyDataBuf.Len() != 0 {
   153  		_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
   154  	}
   155  
   156  	return err
   157  }
   158  
   159  func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
   160  	if wsedc.closed {
   161  		return 0, io.ErrClosedPipe
   162  	}
   163  	if wsedc.Conn == nil {
   164  		if err := wsedc.Dial(b); err != nil {
   165  			return 0, err
   166  		}
   167  		return len(b), nil
   168  	}
   169  
   170  	return wsedc.Conn.Write(b)
   171  }
   172  
   173  func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
   174  	if wsedc.closed {
   175  		return 0, io.ErrClosedPipe
   176  	}
   177  	if wsedc.Conn == nil {
   178  		select {
   179  		case <-wsedc.ctx.Done():
   180  			return 0, io.ErrUnexpectedEOF
   181  		case <-wsedc.dialed:
   182  		}
   183  	}
   184  	return wsedc.Conn.Read(b)
   185  }
   186  
   187  func (wsedc *websocketWithEarlyDataConn) Close() error {
   188  	wsedc.closed = true
   189  	wsedc.cancel()
   190  	if wsedc.Conn == nil {
   191  		return nil
   192  	}
   193  	return wsedc.Conn.Close()
   194  }
   195  
   196  func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
   197  	if wsedc.Conn == nil {
   198  		return wsedc.underlay.LocalAddr()
   199  	}
   200  	return wsedc.Conn.LocalAddr()
   201  }
   202  
   203  func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
   204  	if wsedc.Conn == nil {
   205  		return wsedc.underlay.RemoteAddr()
   206  	}
   207  	return wsedc.Conn.RemoteAddr()
   208  }
   209  
   210  func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
   211  	if err := wsedc.SetReadDeadline(t); err != nil {
   212  		return err
   213  	}
   214  	return wsedc.SetWriteDeadline(t)
   215  }
   216  
   217  func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
   218  	if wsedc.Conn == nil {
   219  		return nil
   220  	}
   221  	return wsedc.Conn.SetReadDeadline(t)
   222  }
   223  
   224  func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
   225  	if wsedc.Conn == nil {
   226  		return nil
   227  	}
   228  	return wsedc.Conn.SetWriteDeadline(t)
   229  }
   230  
   231  func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   232  	ctx, cancel := context.WithCancel(context.Background())
   233  	conn = &websocketWithEarlyDataConn{
   234  		dialed:   make(chan bool, 1),
   235  		cancel:   cancel,
   236  		ctx:      ctx,
   237  		underlay: conn,
   238  		config:   c,
   239  	}
   240  	return conn, nil
   241  }
   242  
   243  func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
   244  	dialer := &websocket.Dialer{
   245  		NetDial: func(network, addr string) (net.Conn, error) {
   246  			return conn, nil
   247  		},
   248  		ReadBufferSize:   4 * 1024,
   249  		WriteBufferSize:  4 * 1024,
   250  		HandshakeTimeout: time.Second * 8,
   251  	}
   252  
   253  	scheme := "ws"
   254  	if c.TLS {
   255  		scheme = "wss"
   256  		dialer.TLSClientConfig = c.TLSConfig
   257  	}
   258  
   259  	u, err := url.Parse(c.Path)
   260  	if err != nil {
   261  		return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
   262  	}
   263  
   264  	uri := url.URL{
   265  		Scheme:   scheme,
   266  		Host:     net.JoinHostPort(c.Host, c.Port),
   267  		Path:     u.Path,
   268  		RawQuery: u.RawQuery,
   269  	}
   270  
   271  	headers := http.Header{}
   272  	if c.Headers != nil {
   273  		for k := range c.Headers {
   274  			headers.Add(k, c.Headers.Get(k))
   275  		}
   276  	}
   277  
   278  	if earlyData != nil {
   279  		if c.EarlyDataHeaderName == "" {
   280  			uri.Path += earlyData.String()
   281  		} else {
   282  			headers.Set(c.EarlyDataHeaderName, earlyData.String())
   283  		}
   284  	}
   285  
   286  	wsConn, resp, err := dialer.Dial(uri.String(), headers)
   287  	if err != nil {
   288  		reason := err.Error()
   289  		if resp != nil {
   290  			reason = resp.Status
   291  		}
   292  		return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason)
   293  	}
   294  
   295  	return &websocketConn{
   296  		conn:       wsConn,
   297  		remoteAddr: conn.RemoteAddr(),
   298  	}, nil
   299  }
   300  
   301  func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   302  	if u, err := url.Parse(c.Path); err == nil {
   303  		if q := u.Query(); q.Get("ed") != "" {
   304  			if ed, err := strconv.Atoi(q.Get("ed")); err == nil {
   305  				c.MaxEarlyData = ed
   306  				c.EarlyDataHeaderName = "Sec-WebSocket-Protocol"
   307  				q.Del("ed")
   308  				u.RawQuery = q.Encode()
   309  				c.Path = u.String()
   310  			}
   311  		}
   312  	}
   313  
   314  	if c.MaxEarlyData > 0 {
   315  		return streamWebsocketWithEarlyDataConn(conn, c)
   316  	}
   317  
   318  	return streamWebsocketConn(conn, c, nil)
   319  }