github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2rayhttpupgrade/client.go (about) 1 package v2rayhttpupgrade 2 3 import ( 4 std_bufio "bufio" 5 "context" 6 "net" 7 "net/http" 8 "net/url" 9 "strings" 10 11 "github.com/sagernet/sing-box/adapter" 12 "github.com/sagernet/sing-box/common/tls" 13 "github.com/sagernet/sing-box/option" 14 "github.com/sagernet/sing/common/buf" 15 "github.com/sagernet/sing/common/bufio" 16 E "github.com/sagernet/sing/common/exceptions" 17 M "github.com/sagernet/sing/common/metadata" 18 N "github.com/sagernet/sing/common/network" 19 sHTTP "github.com/sagernet/sing/protocol/http" 20 ) 21 22 var _ adapter.V2RayClientTransport = (*Client)(nil) 23 24 type Client struct { 25 dialer N.Dialer 26 tlsConfig tls.Config 27 serverAddr M.Socksaddr 28 requestURL url.URL 29 headers http.Header 30 host string 31 } 32 33 func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayHTTPUpgradeOptions, tlsConfig tls.Config) (*Client, error) { 34 if tlsConfig != nil { 35 if len(tlsConfig.NextProtos()) == 0 { 36 tlsConfig.SetNextProtos([]string{"http/1.1"}) 37 } 38 } 39 var host string 40 if options.Host != "" { 41 host = options.Host 42 } else if tlsConfig != nil && tlsConfig.ServerName() != "" { 43 host = tlsConfig.ServerName() 44 } else { 45 host = serverAddr.String() 46 } 47 var requestURL url.URL 48 if tlsConfig == nil { 49 requestURL.Scheme = "http" 50 } else { 51 requestURL.Scheme = "https" 52 } 53 requestURL.Host = serverAddr.String() 54 requestURL.Path = options.Path 55 err := sHTTP.URLSetPath(&requestURL, options.Path) 56 if err != nil { 57 return nil, E.Cause(err, "parse path") 58 } 59 if !strings.HasPrefix(requestURL.Path, "/") { 60 requestURL.Path = "/" + requestURL.Path 61 } 62 headers := make(http.Header) 63 for key, value := range options.Headers { 64 headers[key] = value 65 } 66 return &Client{ 67 dialer: dialer, 68 tlsConfig: tlsConfig, 69 serverAddr: serverAddr, 70 requestURL: requestURL, 71 headers: headers, 72 host: host, 73 }, nil 74 } 75 76 func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { 77 conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) 78 if err != nil { 79 return nil, err 80 } 81 if c.tlsConfig != nil { 82 conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig) 83 if err != nil { 84 return nil, err 85 } 86 } 87 request := &http.Request{ 88 Method: http.MethodGet, 89 URL: &c.requestURL, 90 Header: c.headers.Clone(), 91 Host: c.host, 92 } 93 request.Header.Set("Connection", "Upgrade") 94 request.Header.Set("Upgrade", "websocket") 95 err = request.Write(conn) 96 if err != nil { 97 return nil, err 98 } 99 bufReader := std_bufio.NewReader(conn) 100 response, err := http.ReadResponse(bufReader, request) 101 if err != nil { 102 return nil, err 103 } 104 if response.StatusCode != 101 || 105 !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || 106 !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { 107 return nil, E.New("unexpected status: ", response.Status) 108 } 109 if bufReader.Buffered() > 0 { 110 buffer := buf.NewSize(bufReader.Buffered()) 111 _, err = buffer.ReadFullFrom(bufReader, buffer.Len()) 112 if err != nil { 113 return nil, err 114 } 115 conn = bufio.NewCachedConn(conn, buffer) 116 } 117 return conn, nil 118 }