github.com/sagernet/sing-box@v1.2.7/transport/v2raywebsocket/client.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"net/url"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing-box/adapter"
    11  	"github.com/sagernet/sing-box/common/tls"
    12  	"github.com/sagernet/sing-box/option"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  	sHTTP "github.com/sagernet/sing/protocol/http"
    17  	"github.com/sagernet/websocket"
    18  )
    19  
// Compile-time check that *Client satisfies adapter.V2RayClientTransport.
var _ adapter.V2RayClientTransport = (*Client)(nil)

// Client is the V2Ray WebSocket client transport returned by NewClient.
// NOTE(review): NewClient constructs this struct with a positional composite
// literal, so the field order below is load-bearing.
type Client struct {
	dialer              *websocket.Dialer // performs the handshake; its NetDial hooks dial through the outbound N.Dialer
	uri                 string            // full "ws://" or "wss://" endpoint URL, including path
	headers             http.Header       // extra request headers sent with the handshake
	maxEarlyData        uint32            // non-zero makes DialContext return a deferred EarlyWebsocketConn
	earlyDataHeaderName string            // copied from options.EarlyDataHeaderName; consumed outside this file
}
    29  
    30  func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
    31  	wsDialer := &websocket.Dialer{
    32  		ReadBufferSize:   4 * 1024,
    33  		WriteBufferSize:  4 * 1024,
    34  		HandshakeTimeout: time.Second * 8,
    35  	}
    36  	if tlsConfig != nil {
    37  		if len(tlsConfig.NextProtos()) == 0 {
    38  			tlsConfig.SetNextProtos([]string{"http/1.1"})
    39  		}
    40  		wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
    41  			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
    42  			if err != nil {
    43  				return nil, err
    44  			}
    45  			return tls.ClientHandshake(ctx, conn, tlsConfig)
    46  		}
    47  	} else {
    48  		wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
    49  			return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
    50  		}
    51  	}
    52  	var uri url.URL
    53  	if tlsConfig == nil {
    54  		uri.Scheme = "ws"
    55  	} else {
    56  		uri.Scheme = "wss"
    57  	}
    58  	uri.Host = serverAddr.String()
    59  	uri.Path = options.Path
    60  	err := sHTTP.URLSetPath(&uri, options.Path)
    61  	if err != nil {
    62  		return nil
    63  	}
    64  	headers := make(http.Header)
    65  	for key, value := range options.Headers {
    66  		headers[key] = value
    67  	}
    68  	return &Client{
    69  		wsDialer,
    70  		uri.String(),
    71  		headers,
    72  		options.MaxEarlyData,
    73  		options.EarlyDataHeaderName,
    74  	}
    75  }
    76  
    77  func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
    78  	if c.maxEarlyData <= 0 {
    79  		conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
    80  		if err == nil {
    81  			return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
    82  		}
    83  		return nil, wrapDialError(response, err)
    84  	} else {
    85  		return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
    86  	}
    87  }
    88  
    89  func wrapDialError(response *http.Response, err error) error {
    90  	if response == nil {
    91  		return err
    92  	}
    93  	return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
    94  }