github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raywebsocket/client.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"net/url"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/sagernet/sing-box/adapter"
    12  	"github.com/sagernet/sing-box/common/tls"
    13  	C "github.com/sagernet/sing-box/constant"
    14  	"github.com/sagernet/sing-box/option"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	"github.com/sagernet/sing/common/bufio/deadline"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  	sHTTP "github.com/sagernet/sing/protocol/http"
    22  	"github.com/sagernet/ws"
    23  )
    24  
    25  var _ adapter.V2RayClientTransport = (*Client)(nil)
    26  
    27  type Client struct {
    28  	dialer              N.Dialer
    29  	tlsConfig           tls.Config
    30  	serverAddr          M.Socksaddr
    31  	requestURL          url.URL
    32  	headers             http.Header
    33  	maxEarlyData        uint32
    34  	earlyDataHeaderName string
    35  }
    36  
    37  func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
    38  	if tlsConfig != nil {
    39  		if len(tlsConfig.NextProtos()) == 0 {
    40  			tlsConfig.SetNextProtos([]string{"http/1.1"})
    41  		}
    42  	}
    43  	var requestURL url.URL
    44  	if tlsConfig == nil {
    45  		requestURL.Scheme = "ws"
    46  	} else {
    47  		requestURL.Scheme = "wss"
    48  	}
    49  	requestURL.Host = serverAddr.String()
    50  	requestURL.Path = options.Path
    51  	err := sHTTP.URLSetPath(&requestURL, options.Path)
    52  	if err != nil {
    53  		return nil, E.Cause(err, "parse path")
    54  	}
    55  	if !strings.HasPrefix(requestURL.Path, "/") {
    56  		requestURL.Path = "/" + requestURL.Path
    57  	}
    58  	headers := options.Headers.Build()
    59  	if host := headers.Get("Host"); host != "" {
    60  		headers.Del("Host")
    61  		requestURL.Host = host
    62  	}
    63  	if headers.Get("User-Agent") == "" {
    64  		headers.Set("User-Agent", "Go-http-client/1.1")
    65  	}
    66  	return &Client{
    67  		dialer,
    68  		tlsConfig,
    69  		serverAddr,
    70  		requestURL,
    71  		headers,
    72  		options.MaxEarlyData,
    73  		options.EarlyDataHeaderName,
    74  	}, nil
    75  }
    76  
    77  func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers http.Header) (*WebsocketConn, error) {
    78  	conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	if c.tlsConfig != nil {
    83  		conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  	}
    88  	var deadlineConn net.Conn
    89  	if deadline.NeedAdditionalReadDeadline(conn) {
    90  		deadlineConn = deadline.NewConn(conn)
    91  	} else {
    92  		deadlineConn = conn
    93  	}
    94  	err = deadlineConn.SetDeadline(time.Now().Add(C.TCPTimeout))
    95  	if err != nil {
    96  		return nil, E.Cause(err, "set read deadline")
    97  	}
    98  	var protocols []string
    99  	if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
   100  		protocols = []string{protocolHeader}
   101  		headers.Del("Sec-WebSocket-Protocol")
   102  	}
   103  	reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(deadlineConn, requestURL)
   104  	deadlineConn.SetDeadline(time.Time{})
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	if reader != nil {
   109  		buffer := buf.NewSize(reader.Buffered())
   110  		_, err = buffer.ReadFullFrom(reader, buffer.Len())
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  		conn = bufio.NewCachedConn(conn, buffer)
   115  	}
   116  	return NewConn(conn, nil, ws.StateClientSide), nil
   117  }
   118  
   119  func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
   120  	if c.maxEarlyData <= 0 {
   121  		conn, err := c.dialContext(ctx, &c.requestURL, c.headers)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		return conn, nil
   126  	} else {
   127  		return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
   128  	}
   129  }