
     1  package http
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"encoding/base64"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"sync"
    12  	"time"
    14  	""
    16  	core ""
    17  	""
    18  	""
    19  	""
    20  	""
    21  	""
    22  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  	""
    29  	""
    30  	""
    31  )
// Client is an outbound handler that reaches its target by opening a tunnel
// through an upstream HTTP proxy with the CONNECT method.
type Client struct {
	// serverPicker chooses the upstream proxy server for each connection attempt.
	serverPicker       protocol.ServerPicker
	// policyManager supplies the per-user-level timeouts used while relaying.
	policyManager      policy.Manager
	// h1SkipWaitForReply sends the first payload together with the HTTP/1.1
	// CONNECT request instead of after the proxy's reply, and lengthens the wait
	// for that first payload.
	h1SkipWaitForReply bool
}
// h2Conn is an HTTP/2 connection to a proxy server that is cached so later
// tunnels to the same destination can be opened as additional streams on it.
type h2Conn struct {
	rawConn net.Conn          // underlying (possibly TLS) connection
	h2Conn  *http2.ClientConn // HTTP/2 client connection running over rawConn
}
    44  var (
    45  	cachedH2Mutex sync.Mutex
    46  	cachedH2Conns map[net.Destination]h2Conn
    47  )
    49  // NewClient create a new http client based on the given config.
    50  func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
    51  	serverList := protocol.NewServerList()
    52  	for _, rec := range config.Server {
    53  		s, err := protocol.NewServerSpecFromPB(rec)
    54  		if err != nil {
    55  			return nil, newError("failed to get server spec").Base(err)
    56  		}
    57  		serverList.AddServer(s)
    58  	}
    59  	if serverList.Size() == 0 {
    60  		return nil, newError("0 target server")
    61  	}
    63  	v := core.MustFromContext(ctx)
    64  	return &Client{
    65  		serverPicker:       protocol.NewRoundRobinServerPicker(serverList),
    66  		policyManager:      v.GetFeature(policy.ManagerType()).(policy.Manager),
    67  		h1SkipWaitForReply: config.H1SkipWaitForReply,
    68  	}, nil
    69  }
    71  // Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel.
    72  func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    73  	outbound := session.OutboundFromContext(ctx)
    74  	if outbound == nil || !outbound.Target.IsValid() {
    75  		return newError("target not specified.")
    76  	}
    77  	target := outbound.Target
    78  	targetAddr := target.NetAddr()
    80  	if target.Network == net.Network_UDP {
    81  		return newError("UDP is not supported by HTTP outbound")
    82  	}
    84  	var user *protocol.MemoryUser
    85  	var conn internet.Connection
    87  	var firstPayload []byte
    89  	if reader, ok := link.Reader.(buf.TimeoutReader); ok {
    90  		// 0-RTT optimization for HTTP/2: If the payload comes very soon, it can be
    91  		// transmitted together. Note we should not get stuck here, as the payload may
    92  		// not exist (considering to access MySQL database via a HTTP proxy, where the
    93  		// server sends hello to the client first).
    94  		waitTime := proxy.FirstPayloadTimeout
    95  		if c.h1SkipWaitForReply {
    96  			// Some server require first write to be present in client hello.
    97  			// Increase timeout to if the client have explicitly requested to skip waiting for reply.
    98  			waitTime = time.Second
    99  		}
   100  		if mbuf, _ := reader.ReadMultiBufferTimeout(waitTime); mbuf != nil {
   101  			mlen := mbuf.Len()
   102  			firstPayload = bytespool.Alloc(mlen)
   103  			mbuf, _ = buf.SplitBytes(mbuf, firstPayload)
   104  			firstPayload = firstPayload[:mlen]
   106  			buf.ReleaseMulti(mbuf)
   107  			defer bytespool.Free(firstPayload)
   108  		}
   109  	}
   111  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   112  		server := c.serverPicker.PickServer()
   113  		dest := server.Destination()
   114  		user = server.PickUser()
   116  		netConn, firstResp, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, firstPayload, c.h1SkipWaitForReply)
   117  		if netConn != nil {
   118  			if _, ok := netConn.(*http2Conn); !ok && !c.h1SkipWaitForReply {
   119  				if _, err := netConn.Write(firstPayload); err != nil {
   120  					netConn.Close()
   121  					return err
   122  				}
   123  			}
   124  			if firstResp != nil {
   125  				if err := link.Writer.WriteMultiBuffer(firstResp); err != nil {
   126  					return err
   127  				}
   128  			}
   129  			conn = internet.Connection(netConn)
   130  		}
   131  		return err
   132  	}); err != nil {
   133  		return newError("failed to find an available destination").Base(err)
   134  	}
   136  	defer func() {
   137  		if err := conn.Close(); err != nil {
   138  			newError("failed to closed connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   139  		}
   140  	}()
   142  	p := c.policyManager.ForLevel(0)
   143  	if user != nil {
   144  		p = c.policyManager.ForLevel(user.Level)
   145  	}
   147  	ctx, cancel := context.WithCancel(ctx)
   148  	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
   150  	requestFunc := func() error {
   151  		defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   152  		return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   153  	}
   154  	responseFunc := func() error {
   155  		defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   156  		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   157  	}
   159  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   160  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   161  		return newError("connection ends").Base(err)
   162  	}
   164  	return nil
   165  }
   167  // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method
   168  func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, firstPayload []byte, writeFirstPayloadInH1 bool,
   169  ) (net.Conn, buf.MultiBuffer, error) {
   170  	req := &http.Request{
   171  		Method: http.MethodConnect,
   172  		URL:    &url.URL{Host: target},
   173  		Header: make(http.Header),
   174  		Host:   target,
   175  	}
   177  	if user != nil && user.Account != nil {
   178  		account := user.Account.(*Account)
   179  		auth := account.GetUsername() + ":" + account.GetPassword()
   180  		req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
   181  	}
   183  	connectHTTP1 := func(rawConn net.Conn) (net.Conn, buf.MultiBuffer, error) {
   184  		req.Header.Set("Proxy-Connection", "Keep-Alive")
   186  		if !writeFirstPayloadInH1 {
   187  			err := req.Write(rawConn)
   188  			if err != nil {
   189  				rawConn.Close()
   190  				return nil, nil, err
   191  			}
   192  		} else {
   193  			buffer := bytes.NewBuffer(nil)
   194  			err := req.Write(buffer)
   195  			if err != nil {
   196  				rawConn.Close()
   197  				return nil, nil, err
   198  			}
   199  			_, err = io.Copy(buffer, bytes.NewReader(firstPayload))
   200  			if err != nil {
   201  				rawConn.Close()
   202  				return nil, nil, err
   203  			}
   204  			_, err = rawConn.Write(buffer.Bytes())
   205  			if err != nil {
   206  				rawConn.Close()
   207  				return nil, nil, err
   208  			}
   209  		}
   210  		bufferedReader := bufio.NewReader(rawConn)
   211  		resp, err := http.ReadResponse(bufferedReader, req)
   212  		if err != nil {
   213  			rawConn.Close()
   214  			return nil, nil, err
   215  		}
   216  		defer resp.Body.Close()
   218  		if resp.StatusCode != http.StatusOK {
   219  			rawConn.Close()
   220  			return nil, nil, newError("Proxy responded with non 200 code: " + resp.Status)
   221  		}
   222  		if bufferedReader.Buffered() > 0 {
   223  			payload, err := buf.ReadFrom(io.LimitReader(bufferedReader, int64(bufferedReader.Buffered())))
   224  			if err != nil {
   225  				return nil, nil, newError("unable to drain buffer: ").Base(err)
   226  			}
   227  			return rawConn, payload, nil
   228  		}
   229  		return rawConn, nil, nil
   230  	}
   232  	connectHTTP2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) {
   233  		pr, pw := io.Pipe()
   234  		req.Body = pr
   236  		var pErr error
   237  		var wg sync.WaitGroup
   238  		wg.Add(1)
   240  		go func() {
   241  			_, pErr = pw.Write(firstPayload)
   242  			wg.Done()
   243  		}()
   245  		resp, err := h2clientConn.RoundTrip(req) // nolint: bodyclose
   246  		if err != nil {
   247  			rawConn.Close()
   248  			return nil, err
   249  		}
   251  		wg.Wait()
   252  		if pErr != nil {
   253  			rawConn.Close()
   254  			return nil, pErr
   255  		}
   257  		if resp.StatusCode != http.StatusOK {
   258  			rawConn.Close()
   259  			return nil, newError("Proxy responded with non 200 code: " + resp.Status)
   260  		}
   261  		return newHTTP2Conn(rawConn, pw, resp.Body), nil
   262  	}
   264  	cachedH2Mutex.Lock()
   265  	cachedConn, cachedConnFound := cachedH2Conns[dest]
   266  	cachedH2Mutex.Unlock()
   268  	if cachedConnFound {
   269  		rc, cc := cachedConn.rawConn, cachedConn.h2Conn
   270  		if cc.CanTakeNewRequest() {
   271  			proxyConn, err := connectHTTP2(rc, cc)
   272  			if err != nil {
   273  				return nil, nil, err
   274  			}
   276  			return proxyConn, nil, nil
   277  		}
   278  	}
   280  	rawConn, err := dialer.Dial(ctx, dest)
   281  	if err != nil {
   282  		return nil, nil, err
   283  	}
   285  	iConn := rawConn
   286  	if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
   287  		iConn = statConn.Connection
   288  	}
   290  	nextProto := ""
   291  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   292  		if err := tlsConn.Handshake(); err != nil {
   293  			rawConn.Close()
   294  			return nil, nil, err
   295  		}
   296  		nextProto = tlsConn.ConnectionState().NegotiatedProtocol
   297  	}
   299  	switch nextProto {
   300  	case "", "http/1.1":
   301  		return connectHTTP1(rawConn)
   302  	case "h2":
   303  		t := http2.Transport{}
   304  		h2clientConn, err := t.NewClientConn(rawConn)
   305  		if err != nil {
   306  			rawConn.Close()
   307  			return nil, nil, err
   308  		}
   310  		proxyConn, err := connectHTTP2(rawConn, h2clientConn)
   311  		if err != nil {
   312  			rawConn.Close()
   313  			return nil, nil, err
   314  		}
   316  		cachedH2Mutex.Lock()
   317  		if cachedH2Conns == nil {
   318  			cachedH2Conns = make(map[net.Destination]h2Conn)
   319  		}
   321  		cachedH2Conns[dest] = h2Conn{
   322  			rawConn: rawConn,
   323  			h2Conn:  h2clientConn,
   324  		}
   325  		cachedH2Mutex.Unlock()
   327  		return proxyConn, nil, err
   328  	default:
   329  		return nil, nil, newError("negotiated unsupported application layer protocol: " + nextProto)
   330  	}
   331  }
   333  func newHTTP2Conn(c net.Conn, pipedReqBody *io.PipeWriter, respBody io.ReadCloser) net.Conn {
   334  	return &http2Conn{Conn: c, in: pipedReqBody, out: respBody}
   335  }
   337  type http2Conn struct {
   338  	net.Conn
   339  	in  *io.PipeWriter
   340  	out io.ReadCloser
   341  }
   343  func (h *http2Conn) Read(p []byte) (n int, err error) {
   344  	return h.out.Read(p)
   345  }
   347  func (h *http2Conn) Write(p []byte) (n int, err error) {
   348  	return
   349  }
   351  func (h *http2Conn) Close() error {
   353  	return h.out.Close()
   354  }
   356  func init() {
   357  	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   358  		return NewClient(ctx, config.(*ClientConfig))
   359  	}))
   360  }