github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/websocket/dialer.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"io"
     8  	gonet "net"
     9  	"net/http"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  
    14  	core "github.com/v2fly/v2ray-core/v5"
    15  	"github.com/v2fly/v2ray-core/v5/common"
    16  	"github.com/v2fly/v2ray-core/v5/common/net"
    17  	"github.com/v2fly/v2ray-core/v5/common/session"
    18  	"github.com/v2fly/v2ray-core/v5/features/extension"
    19  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    20  	"github.com/v2fly/v2ray-core/v5/transport/internet/security"
    21  )
    22  
    23  // Dial dials a WebSocket connection to the given destination.
    24  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
    25  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    26  
    27  	conn, err := dialWebsocket(ctx, dest, streamSettings)
    28  	if err != nil {
    29  		return nil, newError("failed to dial WebSocket").Base(err)
    30  	}
    31  	return internet.Connection(conn), nil
    32  }
    33  
    34  func init() {
    35  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    36  }
    37  
    38  func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
    39  	wsSettings := streamSettings.ProtocolSettings.(*Config)
    40  
    41  	dialer := &websocket.Dialer{
    42  		NetDial: func(network, addr string) (net.Conn, error) {
    43  			return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
    44  		},
    45  		ReadBufferSize:   4 * 1024,
    46  		WriteBufferSize:  4 * 1024,
    47  		HandshakeTimeout: time.Second * 8,
    48  	}
    49  
    50  	protocol := "ws"
    51  
    52  	securityEngine, err := security.CreateSecurityEngineFromSettings(ctx, streamSettings)
    53  	if err != nil {
    54  		return nil, newError("unable to create security engine").Base(err)
    55  	}
    56  
    57  	if securityEngine != nil {
    58  		protocol = "wss"
    59  
    60  		dialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (gonet.Conn, error) {
    61  			conn, err := dialer.NetDial(network, addr)
    62  			if err != nil {
    63  				return nil, newError("dial TLS connection failed").Base(err)
    64  			}
    65  			conn, err = securityEngine.Client(conn,
    66  				security.OptionWithDestination{Dest: dest},
    67  				security.OptionWithALPN{ALPNs: []string{"http/1.1"}})
    68  			if err != nil {
    69  				return nil, newError("unable to create security protocol client from security engine").Base(err)
    70  			}
    71  			return conn, nil
    72  		}
    73  	}
    74  
    75  	host := dest.NetAddr()
    76  	if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
    77  		host = dest.Address.String()
    78  	}
    79  	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
    80  
    81  	if wsSettings.UseBrowserForwarding {
    82  		var forwarder extension.BrowserForwarder
    83  		err := core.RequireFeatures(ctx, func(Forwarder extension.BrowserForwarder) {
    84  			forwarder = Forwarder
    85  		})
    86  		if err != nil {
    87  			return nil, newError("cannot find browser forwarder service").Base(err)
    88  		}
    89  		if wsSettings.MaxEarlyData != 0 {
    90  			return newRelayedConnectionWithDelayedDial(&dialerWithEarlyDataRelayed{
    91  				forwarder: forwarder,
    92  				uriBase:   uri,
    93  				config:    wsSettings,
    94  			}), nil
    95  		}
    96  		conn, err := forwarder.DialWebsocket(uri, nil)
    97  		if err != nil {
    98  			return nil, newError("cannot dial with browser forwarder service").Base(err)
    99  		}
   100  		return newRelayedConnection(conn), nil
   101  	}
   102  
   103  	if wsSettings.MaxEarlyData != 0 {
   104  		return newConnectionWithDelayedDial(&dialerWithEarlyData{
   105  			dialer:  dialer,
   106  			uriBase: uri,
   107  			config:  wsSettings,
   108  		}), nil
   109  	}
   110  
   111  	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) // nolint: bodyclose
   112  	if err != nil {
   113  		var reason string
   114  		if resp != nil {
   115  			reason = resp.Status
   116  		}
   117  		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
   118  	}
   119  
   120  	return newConnection(conn, conn.RemoteAddr()), nil
   121  }
   122  
   123  type dialerWithEarlyData struct {
   124  	dialer  *websocket.Dialer
   125  	uriBase string
   126  	config  *Config
   127  }
   128  
   129  func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) {
   130  	earlyDataBuf := bytes.NewBuffer(nil)
   131  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
   132  
   133  	earlydata := bytes.NewReader(earlyData)
   134  	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
   135  	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
   136  	if encerr != nil {
   137  		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
   138  	}
   139  
   140  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   141  		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
   142  	}
   143  
   144  	dialFunction := func() (*websocket.Conn, *http.Response, error) {
   145  		return d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
   146  	}
   147  
   148  	if d.config.EarlyDataHeaderName != "" {
   149  		dialFunction = func() (*websocket.Conn, *http.Response, error) {
   150  			earlyDataStr := earlyDataBuf.String()
   151  			currentHeader := d.config.GetRequestHeader()
   152  			currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
   153  			return d.dialer.Dial(d.uriBase, currentHeader)
   154  		}
   155  	}
   156  
   157  	conn, resp, err := dialFunction() // nolint: bodyclose
   158  	if err != nil {
   159  		var reason string
   160  		if resp != nil {
   161  			reason = resp.Status
   162  		}
   163  		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
   164  	}
   165  	if n != int64(len(earlyData)) {
   166  		if errWrite := conn.WriteMessage(websocket.BinaryMessage, earlyData[n:]); errWrite != nil {
   167  			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
   168  		}
   169  	}
   170  	return conn, nil
   171  }
   172  
   173  type dialerWithEarlyDataRelayed struct {
   174  	forwarder extension.BrowserForwarder
   175  	uriBase   string
   176  	config    *Config
   177  }
   178  
   179  func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser, error) {
   180  	earlyDataBuf := bytes.NewBuffer(nil)
   181  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
   182  
   183  	earlydata := bytes.NewReader(earlyData)
   184  	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
   185  	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
   186  	if encerr != nil {
   187  		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
   188  	}
   189  
   190  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   191  		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
   192  	}
   193  
   194  	dialFunction := func() (io.ReadWriteCloser, error) {
   195  		return d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
   196  	}
   197  
   198  	if d.config.EarlyDataHeaderName != "" {
   199  		earlyDataStr := earlyDataBuf.String()
   200  		currentHeader := d.config.GetRequestHeader()
   201  		currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
   202  		return d.forwarder.DialWebsocket(d.uriBase, currentHeader)
   203  	}
   204  
   205  	conn, err := dialFunction()
   206  	if err != nil {
   207  		var reason string
   208  		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
   209  	}
   210  	if n != int64(len(earlyData)) {
   211  		if _, errWrite := conn.Write(earlyData[n:]); errWrite != nil {
   212  			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
   213  		}
   214  	}
   215  	return conn, nil
   216  }