github.com/v2fly/v2ray-core/v4@v4.45.2/transport/internet/websocket/dialer.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package websocket
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"encoding/base64"
    10  	"io"
    11  	"net/http"
    12  	"time"
    13  
    14  	"github.com/gorilla/websocket"
    15  
    16  	core "github.com/v2fly/v2ray-core/v4"
    17  	"github.com/v2fly/v2ray-core/v4/common"
    18  	"github.com/v2fly/v2ray-core/v4/common/net"
    19  	"github.com/v2fly/v2ray-core/v4/common/session"
    20  	"github.com/v2fly/v2ray-core/v4/features/extension"
    21  	"github.com/v2fly/v2ray-core/v4/transport/internet"
    22  	"github.com/v2fly/v2ray-core/v4/transport/internet/tls"
    23  )
    24  
    25  // Dial dials a WebSocket connection to the given destination.
    26  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
    27  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    28  
    29  	conn, err := dialWebsocket(ctx, dest, streamSettings)
    30  	if err != nil {
    31  		return nil, newError("failed to dial WebSocket").Base(err)
    32  	}
    33  	return internet.Connection(conn), nil
    34  }
    35  
    36  func init() {
    37  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    38  }
    39  
    40  func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
    41  	wsSettings := streamSettings.ProtocolSettings.(*Config)
    42  
    43  	dialer := &websocket.Dialer{
    44  		NetDial: func(network, addr string) (net.Conn, error) {
    45  			return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
    46  		},
    47  		ReadBufferSize:   4 * 1024,
    48  		WriteBufferSize:  4 * 1024,
    49  		HandshakeTimeout: time.Second * 8,
    50  	}
    51  
    52  	protocol := "ws"
    53  
    54  	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
    55  		protocol = "wss"
    56  		dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
    57  	}
    58  
    59  	host := dest.NetAddr()
    60  	if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
    61  		host = dest.Address.String()
    62  	}
    63  	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
    64  
    65  	if wsSettings.UseBrowserForwarding {
    66  		var forwarder extension.BrowserForwarder
    67  		err := core.RequireFeatures(ctx, func(Forwarder extension.BrowserForwarder) {
    68  			forwarder = Forwarder
    69  		})
    70  		if err != nil {
    71  			return nil, newError("cannot find browser forwarder service").Base(err)
    72  		}
    73  		if wsSettings.MaxEarlyData != 0 {
    74  			return newRelayedConnectionWithDelayedDial(&dialerWithEarlyDataRelayed{
    75  				forwarder: forwarder,
    76  				uriBase:   uri,
    77  				config:    wsSettings,
    78  			}), nil
    79  		}
    80  		conn, err := forwarder.DialWebsocket(uri, nil)
    81  		if err != nil {
    82  			return nil, newError("cannot dial with browser forwarder service").Base(err)
    83  		}
    84  		return newRelayedConnection(conn), nil
    85  	}
    86  
    87  	if wsSettings.MaxEarlyData != 0 {
    88  		return newConnectionWithDelayedDial(&dialerWithEarlyData{
    89  			dialer:  dialer,
    90  			uriBase: uri,
    91  			config:  wsSettings,
    92  		}), nil
    93  	}
    94  
    95  	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) // nolint: bodyclose
    96  	if err != nil {
    97  		var reason string
    98  		if resp != nil {
    99  			reason = resp.Status
   100  		}
   101  		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
   102  	}
   103  
   104  	return newConnection(conn, conn.RemoteAddr()), nil
   105  }
   106  
   107  type dialerWithEarlyData struct {
   108  	dialer  *websocket.Dialer
   109  	uriBase string
   110  	config  *Config
   111  }
   112  
   113  func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) {
   114  	earlyDataBuf := bytes.NewBuffer(nil)
   115  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
   116  
   117  	earlydata := bytes.NewReader(earlyData)
   118  	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
   119  	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
   120  	if encerr != nil {
   121  		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
   122  	}
   123  
   124  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   125  		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
   126  	}
   127  
   128  	dialFunction := func() (*websocket.Conn, *http.Response, error) {
   129  		return d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
   130  	}
   131  
   132  	if d.config.EarlyDataHeaderName != "" {
   133  		dialFunction = func() (*websocket.Conn, *http.Response, error) {
   134  			earlyDataStr := earlyDataBuf.String()
   135  			currentHeader := d.config.GetRequestHeader()
   136  			currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
   137  			return d.dialer.Dial(d.uriBase, currentHeader)
   138  		}
   139  	}
   140  
   141  	conn, resp, err := dialFunction() // nolint: bodyclose
   142  	if err != nil {
   143  		var reason string
   144  		if resp != nil {
   145  			reason = resp.Status
   146  		}
   147  		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
   148  	}
   149  	if n != int64(len(earlyData)) {
   150  		if errWrite := conn.WriteMessage(websocket.BinaryMessage, earlyData[n:]); errWrite != nil {
   151  			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
   152  		}
   153  	}
   154  	return conn, nil
   155  }
   156  
   157  type dialerWithEarlyDataRelayed struct {
   158  	forwarder extension.BrowserForwarder
   159  	uriBase   string
   160  	config    *Config
   161  }
   162  
   163  func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser, error) {
   164  	earlyDataBuf := bytes.NewBuffer(nil)
   165  	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
   166  
   167  	earlydata := bytes.NewReader(earlyData)
   168  	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
   169  	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
   170  	if encerr != nil {
   171  		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
   172  	}
   173  
   174  	if errc := base64EarlyDataEncoder.Close(); errc != nil {
   175  		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
   176  	}
   177  
   178  	dialFunction := func() (io.ReadWriteCloser, error) {
   179  		return d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
   180  	}
   181  
   182  	if d.config.EarlyDataHeaderName != "" {
   183  		earlyDataStr := earlyDataBuf.String()
   184  		currentHeader := d.config.GetRequestHeader()
   185  		currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
   186  		return d.forwarder.DialWebsocket(d.uriBase, currentHeader)
   187  	}
   188  
   189  	conn, err := dialFunction()
   190  	if err != nil {
   191  		var reason string
   192  		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
   193  	}
   194  	if n != int64(len(earlyData)) {
   195  		if _, errWrite := conn.Write(earlyData[n:]); errWrite != nil {
   196  			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
   197  		}
   198  	}
   199  	return conn, nil
   200  }