github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/transport/vmess/websocket.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/gorilla/websocket"
    20  )
    21  
// websocketConn adapts a gorilla *websocket.Conn to net.Conn. Each Write is
// sent as one binary WebSocket message, and Read presents the payloads of
// consecutive incoming messages as a single byte stream.
type websocketConn struct {
	// conn is the WebSocket connection all I/O goes through.
	conn *websocket.Conn
	// reader is the reader for the message currently being consumed; Read
	// resets it to nil once the message is exhausted.
	reader io.Reader
	// remoteAddr is reported by RemoteAddr; it is captured from the raw
	// connection the WebSocket was established over.
	remoteAddr net.Addr

	// rMux serializes Read and wMux serializes Write, following the
	// concurrency rules of the WebSocket library:
	// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
	rMux sync.Mutex
	wMux sync.Mutex
}
    31  
// websocketWithEarlyDataConn is a net.Conn whose WebSocket handshake is
// deferred until the first Write, so that the beginning of that write can be
// carried inside the handshake request as early data (see Dial).
type websocketWithEarlyDataConn struct {
	// Embedded WebSocket stream; nil until Dial has completed the handshake.
	net.Conn
	// underlay is the raw connection the handshake runs over.
	underlay net.Conn
	// closed is set by Close; Read and Write then fail with io.ErrClosedPipe.
	closed bool
	// dialed is signalled by Dial after a successful handshake; a Read issued
	// before that point waits on it.
	dialed chan bool
	// Close calls cancel, which makes ctx.Done() fire and releases a Read
	// that is still waiting for the handshake.
	cancel context.CancelFunc
	ctx    context.Context
	// config holds the handshake parameters, including MaxEarlyData.
	config *WebsocketConfig
}
    41  
// WebsocketConfig describes how the client WebSocket handshake is performed.
type WebsocketConfig struct {
	// Host and Port are joined to form the host part of the handshake URL.
	// Host is also the default TLS server name when TLSConfig is nil.
	Host string
	Port string
	// Path is parsed as a URL; its path and query are used for the handshake
	// URL. StreamWebsocketConn strips an integer "ed" query parameter and
	// uses its value as MaxEarlyData.
	Path string
	// Headers are copied into the handshake request. A "Host" entry is also
	// used as the TLS server name when ServerName is empty.
	Headers http.Header
	// TLS selects the wss:// scheme and configures the dialer's TLS settings.
	TLS bool
	// TLSConfig, when non-nil, is used instead of a config built from Host,
	// SkipCertVerify and SessionCache.
	TLSConfig *tls.Config
	// SkipCertVerify sets InsecureSkipVerify; only used when TLSConfig is nil.
	SkipCertVerify bool
	// ServerName, when non-empty, overrides the TLS server name.
	ServerName string
	// SessionCache is the TLS client session cache; only used when TLSConfig
	// is nil.
	SessionCache tls.ClientSessionCache
	// MaxEarlyData is the maximum number of bytes of the first Write that are
	// sent inside the handshake. A value > 0 makes StreamWebsocketConn defer
	// the handshake until that first Write.
	MaxEarlyData int
	// EarlyDataHeaderName names the request header carrying the base64url
	// (unpadded) early data; when empty the data is appended to the URL path.
	EarlyDataHeaderName string
}
    55  
    56  // Read implements net.Conn.Read()
    57  func (wsc *websocketConn) Read(b []byte) (int, error) {
    58  	wsc.rMux.Lock()
    59  	defer wsc.rMux.Unlock()
    60  	for {
    61  		reader, err := wsc.getReader()
    62  		if err != nil {
    63  			return 0, err
    64  		}
    65  
    66  		nBytes, err := reader.Read(b)
    67  		if err == io.EOF {
    68  			wsc.reader = nil
    69  			continue
    70  		}
    71  		return nBytes, err
    72  	}
    73  }
    74  
    75  // Write implements io.Writer.
    76  func (wsc *websocketConn) Write(b []byte) (int, error) {
    77  	wsc.wMux.Lock()
    78  	defer wsc.wMux.Unlock()
    79  	if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
    80  		return 0, err
    81  	}
    82  	return len(b), nil
    83  }
    84  
    85  func (wsc *websocketConn) Close() error {
    86  	var errors []string
    87  	if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
    88  		errors = append(errors, err.Error())
    89  	}
    90  	if err := wsc.conn.Close(); err != nil {
    91  		errors = append(errors, err.Error())
    92  	}
    93  	if len(errors) > 0 {
    94  		return fmt.Errorf("failed to close connection: %s", strings.Join(errors, ","))
    95  	}
    96  	return nil
    97  }
    98  
    99  func (wsc *websocketConn) getReader() (io.Reader, error) {
   100  	if wsc.reader != nil {
   101  		return wsc.reader, nil
   102  	}
   103  
   104  	_, reader, err := wsc.conn.NextReader()
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	wsc.reader = reader
   109  	return reader, nil
   110  }
   111  
   112  func (wsc *websocketConn) LocalAddr() net.Addr {
   113  	return wsc.conn.LocalAddr()
   114  }
   115  
   116  func (wsc *websocketConn) RemoteAddr() net.Addr {
   117  	return wsc.remoteAddr
   118  }
   119  
   120  func (wsc *websocketConn) SetDeadline(t time.Time) error {
   121  	if err := wsc.SetReadDeadline(t); err != nil {
   122  		return err
   123  	}
   124  	return wsc.SetWriteDeadline(t)
   125  }
   126  
   127  func (wsc *websocketConn) SetReadDeadline(t time.Time) error {
   128  	return wsc.conn.SetReadDeadline(t)
   129  }
   130  
   131  func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
   132  	return wsc.conn.SetWriteDeadline(t)
   133  }
   134  
   135  func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
   136  	base64DataBuf := &bytes.Buffer{}
   137  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
   138  
   139  	earlyDataBuf := bytes.NewBuffer(earlyData)
   140  	if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil {
   141  		return errors.New("failed to encode early data: " + err.Error())
   142  	}
   143  
   144  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   145  		return errors.New("failed to encode early data tail: " + errc.Error())
   146  	}
   147  
   148  	var err error
   149  	if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil {
   150  		wsedc.Close()
   151  		return errors.New("failed to dial WebSocket: " + err.Error())
   152  	}
   153  
   154  	wsedc.dialed <- true
   155  	if earlyDataBuf.Len() != 0 {
   156  		_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
   157  	}
   158  
   159  	return err
   160  }
   161  
   162  func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
   163  	if wsedc.closed {
   164  		return 0, io.ErrClosedPipe
   165  	}
   166  	if wsedc.Conn == nil {
   167  		if err := wsedc.Dial(b); err != nil {
   168  			return 0, err
   169  		}
   170  		return len(b), nil
   171  	}
   172  
   173  	return wsedc.Conn.Write(b)
   174  }
   175  
   176  func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
   177  	if wsedc.closed {
   178  		return 0, io.ErrClosedPipe
   179  	}
   180  	if wsedc.Conn == nil {
   181  		select {
   182  		case <-wsedc.ctx.Done():
   183  			return 0, io.ErrUnexpectedEOF
   184  		case <-wsedc.dialed:
   185  		}
   186  	}
   187  	return wsedc.Conn.Read(b)
   188  }
   189  
   190  func (wsedc *websocketWithEarlyDataConn) Close() error {
   191  	wsedc.closed = true
   192  	wsedc.cancel()
   193  	if wsedc.Conn == nil {
   194  		return nil
   195  	}
   196  	return wsedc.Conn.Close()
   197  }
   198  
   199  func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
   200  	if wsedc.Conn == nil {
   201  		return wsedc.underlay.LocalAddr()
   202  	}
   203  	return wsedc.Conn.LocalAddr()
   204  }
   205  
   206  func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
   207  	if wsedc.Conn == nil {
   208  		return wsedc.underlay.RemoteAddr()
   209  	}
   210  	return wsedc.Conn.RemoteAddr()
   211  }
   212  
   213  func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
   214  	if err := wsedc.SetReadDeadline(t); err != nil {
   215  		return err
   216  	}
   217  	return wsedc.SetWriteDeadline(t)
   218  }
   219  
   220  func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
   221  	if wsedc.Conn == nil {
   222  		return nil
   223  	}
   224  	return wsedc.Conn.SetReadDeadline(t)
   225  }
   226  
   227  func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
   228  	if wsedc.Conn == nil {
   229  		return nil
   230  	}
   231  	return wsedc.Conn.SetWriteDeadline(t)
   232  }
   233  
   234  func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   235  	ctx, cancel := context.WithCancel(context.Background())
   236  	conn = &websocketWithEarlyDataConn{
   237  		dialed:   make(chan bool, 1),
   238  		cancel:   cancel,
   239  		ctx:      ctx,
   240  		underlay: conn,
   241  		config:   c,
   242  	}
   243  	return conn, nil
   244  }
   245  
   246  func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
   247  	dialer := &websocket.Dialer{
   248  		NetDial: func(network, addr string) (net.Conn, error) {
   249  			return conn, nil
   250  		},
   251  		ReadBufferSize:   4 * 1024,
   252  		WriteBufferSize:  4 * 1024,
   253  		HandshakeTimeout: time.Second * 8,
   254  	}
   255  
   256  	scheme := "ws"
   257  	if c.TLS {
   258  		scheme = "wss"
   259  		if c.TLSConfig != nil {
   260  			dialer.TLSClientConfig = c.TLSConfig
   261  		} else {
   262  			dialer.TLSClientConfig = &tls.Config{
   263  				ServerName:         c.Host,
   264  				InsecureSkipVerify: c.SkipCertVerify,
   265  				ClientSessionCache: c.SessionCache,
   266  			}
   267  		}
   268  		if c.ServerName != "" {
   269  			dialer.TLSClientConfig.ServerName = c.ServerName
   270  		} else if host := c.Headers.Get("Host"); host != "" {
   271  			dialer.TLSClientConfig.ServerName = host
   272  		}
   273  	}
   274  
   275  	u, err := url.Parse(c.Path)
   276  	if err != nil {
   277  		return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
   278  	}
   279  
   280  	uri := url.URL{
   281  		Scheme:   scheme,
   282  		Host:     net.JoinHostPort(c.Host, c.Port),
   283  		Path:     u.Path,
   284  		RawQuery: u.RawQuery,
   285  	}
   286  
   287  	headers := http.Header{}
   288  	if c.Headers != nil {
   289  		for k := range c.Headers {
   290  			headers.Add(k, c.Headers.Get(k))
   291  		}
   292  	}
   293  
   294  	if earlyData != nil {
   295  		if c.EarlyDataHeaderName == "" {
   296  			uri.Path += earlyData.String()
   297  		} else {
   298  			headers.Set(c.EarlyDataHeaderName, earlyData.String())
   299  		}
   300  	}
   301  
   302  	wsConn, resp, err := dialer.Dial(uri.String(), headers)
   303  	if err != nil {
   304  		reason := err.Error()
   305  		if resp != nil {
   306  			reason = resp.Status
   307  		}
   308  		return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason)
   309  	}
   310  
   311  	return &websocketConn{
   312  		conn:       wsConn,
   313  		remoteAddr: conn.RemoteAddr(),
   314  	}, nil
   315  }
   316  
   317  func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
   318  
   319  	if u, err := url.Parse(c.Path); err == nil {
   320  		if q := u.Query(); q.Get("ed") != "" {
   321  			if ed, err := strconv.Atoi(q.Get("ed")); err == nil {
   322  				c.MaxEarlyData = ed
   323  				c.EarlyDataHeaderName = "Sec-WebSocket-Protocol"
   324  				q.Del("ed")
   325  				u.RawQuery = q.Encode()
   326  				c.Path = u.String()
   327  			}
   328  		}
   329  	}
   330  
   331  	// dialer := &websocket.Dialer{
   332  	// 	NetDial: func(network, addr string) (net.Conn, error) {
   333  	// 		return conn, nil
   334  	// 	},
   335  	// 	ReadBufferSize:   4 * 1024,
   336  	// 	WriteBufferSize:  4 * 1024,
   337  	// 	HandshakeTimeout: time.Second * 8,
   338  	// }
   339  
   340  	// scheme := "ws"
   341  	// if c.TLS {
   342  	// 	scheme = "wss"
   343  	// 	if c.TLSConfig != nil {
   344  	// 		dialer.TLSClientConfig = c.TLSConfig
   345  	// 	} else {
   346  	// 		dialer.TLSClientConfig = &tls.Config{
   347  	// 			ServerName:         c.Host,
   348  	// 			InsecureSkipVerify: c.SkipCertVerify,
   349  	// 			ClientSessionCache: c.SessionCache,
   350  	// 		}
   351  	// 	}
   352  	// 	if c.ServerName != "" {
   353  	// 		dialer.TLSClientConfig.ServerName = c.ServerName
   354  	// 	} else if host := c.Headers.Get("Host"); host != "" {
   355  	// 		dialer.TLSClientConfig.ServerName = host
   356  	// 	}
   357  	// }
   358  
   359  	// uri := url.URL{
   360  	// 	Scheme: scheme,
   361  	// 	Host:   net.JoinHostPort(c.Host, c.Port),
   362  	// 	Path:   c.Path,
   363  	// }
   364  
   365  	// headers := http.Header{}
   366  	// if c.Headers != nil {
   367  	// 	for k := range c.Headers {
   368  	// 		headers.Add(k, c.Headers.Get(k))
   369  	// 	}
   370  	// }
   371  
   372  	// wsConn, resp, err := dialer.Dial(uri.String(), headers)
   373  	// if err != nil {
   374  	// 	reason := err.Error()
   375  	// 	if resp != nil {
   376  	// 		reason = resp.Status
   377  	// 	}
   378  	// 	return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason)
   379  	// }
   380  
   381  	// return &websocketConn{
   382  	// 	conn:       wsConn,
   383  	// 	remoteAddr: conn.RemoteAddr(),
   384  	// }, nil
   385  	if c.MaxEarlyData > 0 {
   386  		return streamWebsocketWithEarlyDataConn(conn, c)
   387  	}
   388  
   389  	return streamWebsocketConn(conn, c, nil)
   390  }