github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raywebsocket/client.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"net/url"
     8  	"time"
     9  
    10  	"github.com/inazumav/sing-box/adapter"
    11  	"github.com/inazumav/sing-box/common/tls"
    12  	"github.com/inazumav/sing-box/option"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  	sHTTP "github.com/sagernet/sing/protocol/http"
    17  	"github.com/sagernet/websocket"
    18  )
    19  
// Compile-time assertion that *Client satisfies adapter.V2RayClientTransport.
var _ adapter.V2RayClientTransport = (*Client)(nil)
    21  
// Client is the WebSocket implementation of adapter.V2RayClientTransport.
// Instances are built by NewClient, which fills the fields positionally, so
// the field order below must not change without updating that constructor.
type Client struct {
	// dialer performs the WebSocket handshake; NewClient wires it to the
	// outbound N.Dialer (and to TLS when a TLS config is supplied).
	dialer              *websocket.Dialer
	// requestURL is the ws:// or wss:// URL built from the server address and
	// the configured path.
	requestURL          url.URL
	// requestURLString is requestURL.String(), computed once in NewClient.
	requestURLString    string
	// headers are the extra request headers copied from the transport options
	// and passed to the dialer on every handshake.
	headers             http.Header
	// maxEarlyData selects the dial mode in DialContext: zero dials
	// immediately, any other value returns an EarlyWebsocketConn.
	maxEarlyData        uint32
	// earlyDataHeaderName is copied from the transport options; it is not read
	// in this file (presumably consumed by EarlyWebsocketConn — confirm there).
	earlyDataHeaderName string
}
    30  
    31  func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
    32  	wsDialer := &websocket.Dialer{
    33  		ReadBufferSize:   4 * 1024,
    34  		WriteBufferSize:  4 * 1024,
    35  		HandshakeTimeout: time.Second * 8,
    36  	}
    37  	if tlsConfig != nil {
    38  		if len(tlsConfig.NextProtos()) == 0 {
    39  			tlsConfig.SetNextProtos([]string{"http/1.1"})
    40  		}
    41  		wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
    42  			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
    43  			if err != nil {
    44  				return nil, err
    45  			}
    46  			tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
    47  			if err != nil {
    48  				return nil, err
    49  			}
    50  			return &deadConn{tlsConn}, nil
    51  		}
    52  	} else {
    53  		wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
    54  			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
    55  			if err != nil {
    56  				return nil, err
    57  			}
    58  			return &deadConn{conn}, nil
    59  		}
    60  	}
    61  	var requestURL url.URL
    62  	if tlsConfig == nil {
    63  		requestURL.Scheme = "ws"
    64  	} else {
    65  		requestURL.Scheme = "wss"
    66  	}
    67  	requestURL.Host = serverAddr.String()
    68  	requestURL.Path = options.Path
    69  	err := sHTTP.URLSetPath(&requestURL, options.Path)
    70  	if err != nil {
    71  		return nil
    72  	}
    73  	headers := make(http.Header)
    74  	for key, value := range options.Headers {
    75  		headers[key] = value
    76  	}
    77  	return &Client{
    78  		wsDialer,
    79  		requestURL,
    80  		requestURL.String(),
    81  		headers,
    82  		options.MaxEarlyData,
    83  		options.EarlyDataHeaderName,
    84  	}
    85  }
    86  
    87  func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
    88  	if c.maxEarlyData <= 0 {
    89  		conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
    90  		if err == nil {
    91  			return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
    92  		}
    93  		return nil, wrapDialError(response, err)
    94  	} else {
    95  		return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
    96  	}
    97  }
    98  
    99  func wrapDialError(response *http.Response, err error) error {
   100  	if response == nil {
   101  		return err
   102  	}
   103  	return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
   104  }