github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/http/client.go (about)

     1  package http
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"encoding/base64"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"sync"
    12  	"text/template"
    13  
    14  	"github.com/xmplusdev/xmcore/common"
    15  	"github.com/xmplusdev/xmcore/common/buf"
    16  	"github.com/xmplusdev/xmcore/common/bytespool"
    17  	"github.com/xmplusdev/xmcore/common/net"
    18  	"github.com/xmplusdev/xmcore/common/protocol"
    19  	"github.com/xmplusdev/xmcore/common/retry"
    20  	"github.com/xmplusdev/xmcore/common/session"
    21  	"github.com/xmplusdev/xmcore/common/signal"
    22  	"github.com/xmplusdev/xmcore/common/task"
    23  	"github.com/xmplusdev/xmcore/core"
    24  	"github.com/xmplusdev/xmcore/features/policy"
    25  	"github.com/xmplusdev/xmcore/transport"
    26  	"github.com/xmplusdev/xmcore/transport/internet"
    27  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    28  	"github.com/xmplusdev/xmcore/transport/internet/tls"
    29  	"golang.org/x/net/http2"
    30  )
    31  
// Client is the HTTP proxy outbound handler: it reaches targets by asking an
// upstream HTTP proxy server to open a tunnel with the CONNECT method.
type Client struct {
	serverPicker  protocol.ServerPicker // picks the proxy server on every connection attempt
	policyManager policy.Manager        // source of per-user-level timeout policies
	header        []*Header             // extra CONNECT headers; values are text/template strings
}
    37  
    38  type h2Conn struct {
    39  	rawConn net.Conn
    40  	h2Conn  *http2.ClientConn
    41  }
    42  
    43  var (
    44  	cachedH2Mutex sync.Mutex
    45  	cachedH2Conns map[net.Destination]h2Conn
    46  )
    47  
    48  // NewClient create a new http client based on the given config.
    49  func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
    50  	serverList := protocol.NewServerList()
    51  	for _, rec := range config.Server {
    52  		s, err := protocol.NewServerSpecFromPB(rec)
    53  		if err != nil {
    54  			return nil, newError("failed to get server spec").Base(err)
    55  		}
    56  		serverList.AddServer(s)
    57  	}
    58  	if serverList.Size() == 0 {
    59  		return nil, newError("0 target server")
    60  	}
    61  
    62  	v := core.MustFromContext(ctx)
    63  	return &Client{
    64  		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
    65  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    66  		header:        config.Header,
    67  	}, nil
    68  }
    69  
    70  // Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel.
    71  func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    72  	outbound := session.OutboundFromContext(ctx)
    73  	if outbound == nil || !outbound.Target.IsValid() {
    74  		return newError("target not specified.")
    75  	}
    76  	outbound.Name = "http"
    77  	inbound := session.InboundFromContext(ctx)
    78  	if inbound != nil {
    79  		inbound.SetCanSpliceCopy(2)
    80  	}
    81  	target := outbound.Target
    82  	targetAddr := target.NetAddr()
    83  
    84  	if target.Network == net.Network_UDP {
    85  		return newError("UDP is not supported by HTTP outbound")
    86  	}
    87  
    88  	var user *protocol.MemoryUser
    89  	var conn stat.Connection
    90  
    91  	mbuf, _ := link.Reader.ReadMultiBuffer()
    92  	len := mbuf.Len()
    93  	firstPayload := bytespool.Alloc(len)
    94  	mbuf, _ = buf.SplitBytes(mbuf, firstPayload)
    95  	firstPayload = firstPayload[:len]
    96  
    97  	buf.ReleaseMulti(mbuf)
    98  	defer bytespool.Free(firstPayload)
    99  
   100  	header, err := fillRequestHeader(ctx, c.header)
   101  	if err != nil {
   102  		return newError("failed to fill out header").Base(err)
   103  	}
   104  
   105  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   106  		server := c.serverPicker.PickServer()
   107  		dest := server.Destination()
   108  		user = server.PickUser()
   109  
   110  		netConn, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, header, firstPayload)
   111  		if netConn != nil {
   112  			if _, ok := netConn.(*http2Conn); !ok {
   113  				if _, err := netConn.Write(firstPayload); err != nil {
   114  					netConn.Close()
   115  					return err
   116  				}
   117  			}
   118  			conn = stat.Connection(netConn)
   119  		}
   120  		return err
   121  	}); err != nil {
   122  		return newError("failed to find an available destination").Base(err)
   123  	}
   124  
   125  	defer func() {
   126  		if err := conn.Close(); err != nil {
   127  			newError("failed to closed connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   128  		}
   129  	}()
   130  
   131  	p := c.policyManager.ForLevel(0)
   132  	if user != nil {
   133  		p = c.policyManager.ForLevel(user.Level)
   134  	}
   135  
   136  	var newCtx context.Context
   137  	var newCancel context.CancelFunc
   138  	if session.TimeoutOnlyFromContext(ctx) {
   139  		newCtx, newCancel = context.WithCancel(context.Background())
   140  	}
   141  
   142  	ctx, cancel := context.WithCancel(ctx)
   143  	timer := signal.CancelAfterInactivity(ctx, func() {
   144  		cancel()
   145  		if newCancel != nil {
   146  			newCancel()
   147  		}
   148  	}, p.Timeouts.ConnectionIdle)
   149  
   150  	requestFunc := func() error {
   151  		defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   152  		return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   153  	}
   154  	responseFunc := func() error {
   155  		defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   156  		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   157  	}
   158  
   159  	if newCtx != nil {
   160  		ctx = newCtx
   161  	}
   162  
   163  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   164  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   165  		return newError("connection ends").Base(err)
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  // fillRequestHeader will fill out the template of the headers
   172  func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) {
   173  	if len(header) == 0 {
   174  		return header, nil
   175  	}
   176  
   177  	inbound := session.InboundFromContext(ctx)
   178  	outbound := session.OutboundFromContext(ctx)
   179  
   180  	if inbound == nil || outbound == nil {
   181  		return nil, newError("missing inbound or outbound metadata from context")
   182  	}
   183  
   184  	data := struct {
   185  		Source net.Destination
   186  		Target net.Destination
   187  	}{
   188  		Source: inbound.Source,
   189  		Target: outbound.Target,
   190  	}
   191  
   192  	filled := make([]*Header, len(header))
   193  	for i, h := range header {
   194  		tmpl, err := template.New(h.Key).Parse(h.Value)
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  		var buf bytes.Buffer
   199  
   200  		if err = tmpl.Execute(&buf, data); err != nil {
   201  			return nil, err
   202  		}
   203  		filled[i] = &Header{Key: h.Key, Value: buf.String()}
   204  	}
   205  
   206  	return filled, nil
   207  }
   208  
   209  // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method
   210  func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, header []*Header, firstPayload []byte) (net.Conn, error) {
   211  	req := &http.Request{
   212  		Method: http.MethodConnect,
   213  		URL:    &url.URL{Host: target},
   214  		Header: make(http.Header),
   215  		Host:   target,
   216  	}
   217  
   218  	if user != nil && user.Account != nil {
   219  		account := user.Account.(*Account)
   220  		auth := account.GetUsername() + ":" + account.GetPassword()
   221  		req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
   222  	}
   223  
   224  	for _, h := range header {
   225  		req.Header.Set(h.Key, h.Value)
   226  	}
   227  
   228  	connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) {
   229  		req.Header.Set("Proxy-Connection", "Keep-Alive")
   230  
   231  		err := req.Write(rawConn)
   232  		if err != nil {
   233  			rawConn.Close()
   234  			return nil, err
   235  		}
   236  
   237  		resp, err := http.ReadResponse(bufio.NewReader(rawConn), req)
   238  		if err != nil {
   239  			rawConn.Close()
   240  			return nil, err
   241  		}
   242  		defer resp.Body.Close()
   243  
   244  		if resp.StatusCode != http.StatusOK {
   245  			rawConn.Close()
   246  			return nil, newError("Proxy responded with non 200 code: " + resp.Status)
   247  		}
   248  		return rawConn, nil
   249  	}
   250  
   251  	connectHTTP2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) {
   252  		pr, pw := io.Pipe()
   253  		req.Body = pr
   254  
   255  		var pErr error
   256  		var wg sync.WaitGroup
   257  		wg.Add(1)
   258  
   259  		go func() {
   260  			_, pErr = pw.Write(firstPayload)
   261  			wg.Done()
   262  		}()
   263  
   264  		resp, err := h2clientConn.RoundTrip(req)
   265  		if err != nil {
   266  			rawConn.Close()
   267  			return nil, err
   268  		}
   269  
   270  		wg.Wait()
   271  		if pErr != nil {
   272  			rawConn.Close()
   273  			return nil, pErr
   274  		}
   275  
   276  		if resp.StatusCode != http.StatusOK {
   277  			rawConn.Close()
   278  			return nil, newError("Proxy responded with non 200 code: " + resp.Status)
   279  		}
   280  		return newHTTP2Conn(rawConn, pw, resp.Body), nil
   281  	}
   282  
   283  	cachedH2Mutex.Lock()
   284  	cachedConn, cachedConnFound := cachedH2Conns[dest]
   285  	cachedH2Mutex.Unlock()
   286  
   287  	if cachedConnFound {
   288  		rc, cc := cachedConn.rawConn, cachedConn.h2Conn
   289  		if cc.CanTakeNewRequest() {
   290  			proxyConn, err := connectHTTP2(rc, cc)
   291  			if err != nil {
   292  				return nil, err
   293  			}
   294  
   295  			return proxyConn, nil
   296  		}
   297  	}
   298  
   299  	rawConn, err := dialer.Dial(ctx, dest)
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  
   304  	iConn := rawConn
   305  	if statConn, ok := iConn.(*stat.CounterConnection); ok {
   306  		iConn = statConn.Connection
   307  	}
   308  
   309  	nextProto := ""
   310  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   311  		if err := tlsConn.HandshakeContext(ctx); err != nil {
   312  			rawConn.Close()
   313  			return nil, err
   314  		}
   315  		nextProto = tlsConn.ConnectionState().NegotiatedProtocol
   316  	}
   317  
   318  	switch nextProto {
   319  	case "", "http/1.1":
   320  		return connectHTTP1(rawConn)
   321  	case "h2":
   322  		t := http2.Transport{}
   323  		h2clientConn, err := t.NewClientConn(rawConn)
   324  		if err != nil {
   325  			rawConn.Close()
   326  			return nil, err
   327  		}
   328  
   329  		proxyConn, err := connectHTTP2(rawConn, h2clientConn)
   330  		if err != nil {
   331  			rawConn.Close()
   332  			return nil, err
   333  		}
   334  
   335  		cachedH2Mutex.Lock()
   336  		if cachedH2Conns == nil {
   337  			cachedH2Conns = make(map[net.Destination]h2Conn)
   338  		}
   339  
   340  		cachedH2Conns[dest] = h2Conn{
   341  			rawConn: rawConn,
   342  			h2Conn:  h2clientConn,
   343  		}
   344  		cachedH2Mutex.Unlock()
   345  
   346  		return proxyConn, err
   347  	default:
   348  		return nil, newError("negotiated unsupported application layer protocol: " + nextProto)
   349  	}
   350  }
   351  
   352  func newHTTP2Conn(c net.Conn, pipedReqBody *io.PipeWriter, respBody io.ReadCloser) net.Conn {
   353  	return &http2Conn{Conn: c, in: pipedReqBody, out: respBody}
   354  }
   355  
   356  type http2Conn struct {
   357  	net.Conn
   358  	in  *io.PipeWriter
   359  	out io.ReadCloser
   360  }
   361  
   362  func (h *http2Conn) Read(p []byte) (n int, err error) {
   363  	return h.out.Read(p)
   364  }
   365  
   366  func (h *http2Conn) Write(p []byte) (n int, err error) {
   367  	return h.in.Write(p)
   368  }
   369  
   370  func (h *http2Conn) Close() error {
   371  	h.in.Close()
   372  	return h.out.Close()
   373  }
   374  
   375  func init() {
   376  	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   377  		return NewClient(ctx, config.(*ClientConfig))
   378  	}))
   379  }