github.com/google/martian/v3@v3.3.3/proxy.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package martian
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"crypto/tls"
    21  	"errors"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httputil"
    26  	"net/url"
    27  	"regexp"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/google/martian/v3/log"
    32  	"github.com/google/martian/v3/mitm"
    33  	"github.com/google/martian/v3/nosigpipe"
    34  	"github.com/google/martian/v3/proxyutil"
    35  	"github.com/google/martian/v3/trafficshape"
    36  )
    37  
    38  var errClose = errors.New("closing connection")
    39  var noop = Noop("martian")
    40  
    41  func isCloseable(err error) bool {
    42  	if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
    43  		return true
    44  	}
    45  
    46  	switch err {
    47  	case io.EOF, io.ErrClosedPipe, errClose:
    48  		return true
    49  	}
    50  
    51  	return false
    52  }
    53  
// Proxy is an HTTP proxy with support for TLS MITM and customizable behavior.
//
// Fields:
//   - roundTripper performs the upstream round trip for non-CONNECT requests.
//   - dial is the dial func installed by SetDial; it is used to open CONNECT
//     tunnels and is also set on the round tripper when that is an
//     *http.Transport.
//   - timeout is the deadline applied to a client connection before each
//     request is handled.
//   - mitm, when non-nil, makes CONNECT requests be intercepted instead of
//     tunnelled.
//   - proxyURL is the optional downstream proxy that traffic is forwarded to.
//   - conns counts the connection handling goroutines that are still running;
//     Close waits on it.
//   - closing is closed by Close to signal that the proxy is shutting down.
//   - reqmod and resmod are the request and response modifiers; they default
//     to a no-op modifier.
type Proxy struct {
	roundTripper http.RoundTripper
	dial         func(string, string) (net.Conn, error)
	timeout      time.Duration
	mitm         *mitm.Config
	proxyURL     *url.URL
	conns        sync.WaitGroup
	connsMu      sync.Mutex // protects conns.Add/Wait from concurrent access
	closing      chan bool

	reqmod RequestModifier
	resmod ResponseModifier
}
    68  
    69  // NewProxy returns a new HTTP proxy.
    70  func NewProxy() *Proxy {
    71  	proxy := &Proxy{
    72  		roundTripper: &http.Transport{
    73  			// TODO(adamtanner): This forces the http.Transport to not upgrade requests
    74  			// to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
    75  			TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
    76  			Proxy:                 http.ProxyFromEnvironment,
    77  			TLSHandshakeTimeout:   10 * time.Second,
    78  			ExpectContinueTimeout: time.Second,
    79  		},
    80  		timeout: 5 * time.Minute,
    81  		closing: make(chan bool),
    82  		reqmod:  noop,
    83  		resmod:  noop,
    84  	}
    85  	proxy.SetDial((&net.Dialer{
    86  		Timeout:   30 * time.Second,
    87  		KeepAlive: 30 * time.Second,
    88  	}).Dial)
    89  	return proxy
    90  }
    91  
// GetRoundTripper gets the http.RoundTripper of the proxy, i.e. the value
// that the proxy uses to perform upstream round trips.
func (p *Proxy) GetRoundTripper() http.RoundTripper {
	return p.roundTripper
}
    96  
    97  // SetRoundTripper sets the http.RoundTripper of the proxy.
    98  func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
    99  	p.roundTripper = rt
   100  
   101  	if tr, ok := p.roundTripper.(*http.Transport); ok {
   102  		tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
   103  		tr.Proxy = http.ProxyURL(p.proxyURL)
   104  		tr.Dial = p.dial
   105  	}
   106  }
   107  
   108  // SetDownstreamProxy sets the proxy that receives requests from the upstream
   109  // proxy.
   110  func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) {
   111  	p.proxyURL = proxyURL
   112  
   113  	if tr, ok := p.roundTripper.(*http.Transport); ok {
   114  		tr.Proxy = http.ProxyURL(p.proxyURL)
   115  	}
   116  }
   117  
// SetTimeout sets the request timeout of the proxy. The value is used as the
// deadline that is put on a client connection before each request on that
// connection is handled.
func (p *Proxy) SetTimeout(timeout time.Duration) {
	p.timeout = timeout
}
   122  
// SetMITM sets the config to use for MITMing of CONNECT requests. With a nil
// config (the default) CONNECT requests are tunnelled to their destination
// instead of being intercepted.
func (p *Proxy) SetMITM(config *mitm.Config) {
	p.mitm = config
}
   127  
   128  // SetDial sets the dial func used to establish a connection.
   129  func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) {
   130  	p.dial = func(a, b string) (net.Conn, error) {
   131  		c, e := dial(a, b)
   132  		nosigpipe.IgnoreSIGPIPE(c)
   133  		return c, e
   134  	}
   135  
   136  	if tr, ok := p.roundTripper.(*http.Transport); ok {
   137  		tr.Dial = p.dial
   138  	}
   139  }
   140  
// Close sets the proxy to the closing state so it stops receiving new connections,
// finishes processing any inflight requests, and closes existing connections without
// reading anymore requests from them.
//
// Close blocks until every connection handling goroutine has finished. It
// must be called at most once: the closing channel is closed unconditionally,
// so a second call panics.
func (p *Proxy) Close() {
	log.Infof("martian: closing down proxy")

	// Closing the channel is what flips Closing() to true and unblocks any
	// goroutine selecting on p.closing.
	close(p.closing)

	log.Infof("martian: waiting for connections to close")
	// connsMu is held for the duration of Wait so that handleLoop, which takes
	// the same mutex around conns.Add, cannot add to the WaitGroup while it is
	// being waited on.
	p.connsMu.Lock()
	p.conns.Wait()
	p.connsMu.Unlock()
	log.Infof("martian: all connections closed")
}
   155  
// Closing returns whether the proxy is in the closing state.
//
// It never blocks: the closing channel is polled with a non-blocking receive.
func (p *Proxy) Closing() bool {
	select {
	case <-p.closing:
		// Close closes p.closing, and a closed channel is always ready to
		// receive.
		return true
	default:
		return false
	}
}
   165  
   166  // SetRequestModifier sets the request modifier.
   167  func (p *Proxy) SetRequestModifier(reqmod RequestModifier) {
   168  	if reqmod == nil {
   169  		reqmod = noop
   170  	}
   171  
   172  	p.reqmod = reqmod
   173  }
   174  
   175  // SetResponseModifier sets the response modifier.
   176  func (p *Proxy) SetResponseModifier(resmod ResponseModifier) {
   177  	if resmod == nil {
   178  		resmod = noop
   179  	}
   180  
   181  	p.resmod = resmod
   182  }
   183  
   184  // Serve accepts connections from the listener and handles the requests.
   185  func (p *Proxy) Serve(l net.Listener) error {
   186  	defer l.Close()
   187  
   188  	var delay time.Duration
   189  	for {
   190  		if p.Closing() {
   191  			return nil
   192  		}
   193  
   194  		conn, err := l.Accept()
   195  		nosigpipe.IgnoreSIGPIPE(conn)
   196  		if err != nil {
   197  			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
   198  				if delay == 0 {
   199  					delay = 5 * time.Millisecond
   200  				} else {
   201  					delay *= 2
   202  				}
   203  				if max := time.Second; delay > max {
   204  					delay = max
   205  				}
   206  
   207  				log.Debugf("martian: temporary error on accept: %v", err)
   208  				time.Sleep(delay)
   209  				continue
   210  			}
   211  
   212  			if errors.Is(err, net.ErrClosed) {
   213  				log.Debugf("martian: listener closed, returning")
   214  				return err
   215  			}
   216  
   217  			log.Errorf("martian: failed to accept: %v", err)
   218  			return err
   219  		}
   220  		delay = 0
   221  		log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())
   222  
   223  		if tconn, ok := conn.(*net.TCPConn); ok {
   224  			tconn.SetKeepAlive(true)
   225  			tconn.SetKeepAlivePeriod(3 * time.Minute)
   226  		}
   227  
   228  		go p.handleLoop(conn)
   229  	}
   230  }
   231  
   232  func (p *Proxy) handleLoop(conn net.Conn) {
   233  	p.connsMu.Lock()
   234  	p.conns.Add(1)
   235  	p.connsMu.Unlock()
   236  	defer p.conns.Done()
   237  	defer conn.Close()
   238  	if p.Closing() {
   239  		return
   240  	}
   241  
   242  	brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
   243  
   244  	s, err := newSession(conn, brw)
   245  	if err != nil {
   246  		log.Errorf("martian: failed to create session: %v", err)
   247  		return
   248  	}
   249  
   250  	ctx, err := withSession(s)
   251  	if err != nil {
   252  		log.Errorf("martian: failed to create context: %v", err)
   253  		return
   254  	}
   255  
   256  	for {
   257  		deadline := time.Now().Add(p.timeout)
   258  		conn.SetDeadline(deadline)
   259  
   260  		if err := p.handle(ctx, conn, brw); isCloseable(err) {
   261  			log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
   262  			return
   263  		}
   264  	}
   265  }
   266  
   267  func (p *Proxy) readRequest(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) (*http.Request, error) {
   268  	var req *http.Request
   269  	reqc := make(chan *http.Request, 1)
   270  	errc := make(chan error, 1)
   271  	go func() {
   272  		r, err := http.ReadRequest(brw.Reader)
   273  		if err != nil {
   274  			errc <- err
   275  			return
   276  		}
   277  		reqc <- r
   278  	}()
   279  	select {
   280  	case err := <-errc:
   281  		if isCloseable(err) {
   282  			log.Debugf("martian: connection closed prematurely: %v", err)
   283  		} else {
   284  			log.Errorf("martian: failed to read request: %v", err)
   285  		}
   286  
   287  		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.
   288  
   289  		return nil, errClose
   290  	case req = <-reqc:
   291  	case <-p.closing:
   292  		return nil, errClose
   293  	}
   294  
   295  	return req, nil
   296  }
   297  
// handleConnectRequest serves a CONNECT request that was read from brw.
//
// After running the request modifier it takes one of two paths:
//
//   - With a MITM config set, it answers the client with a 200 response,
//     peeks at the first byte the client sends to tell TLS from plaintext,
//     and then hands the (possibly TLS-wrapped) connection back to p.handle
//     so that the tunnelled request is processed by the proxy itself.
//   - Without a MITM config, it opens a connection via p.connect, relays the
//     CONNECT response to the client and copies bytes in both directions
//     until both copies have finished.
//
// It returns nil when a modifier hijacked the session, the result of p.handle
// or of the HTTP/2 proxy on the MITM path, the TLS handshake error if the
// handshake fails, the flush error (possibly nil) when the outbound
// connection could not be established, and errClose once a tunnel has
// finished.
func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) error {
	// A modifier error is logged and recorded on the request headers; it does
	// not abort the CONNECT.
	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying CONNECT request: %v", err)
		proxyutil.Warning(req.Header, err)
	}
	if session.Hijacked() {
		log.Debugf("martian: connection hijacked by request modifier")
		return nil
	}

	if p.mitm != nil {
		log.Debugf("martian: attempting MITM for connection: %s / %s", req.Host, req.URL.String())

		// The client is told the tunnel is established without any upstream
		// connection being opened here.
		res := proxyutil.NewResponse(200, nil, req)

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by response modifier")
			return nil
		}

		// Write and flush failures are only logged; processing continues.
		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		log.Debugf("martian: completed MITM for connection: %s", req.Host)

		// Read the first byte of whatever the client sends through the tunnel
		// to decide whether it is a TLS handshake. A read error is only
		// logged, in which case b keeps its zero value.
		b := make([]byte, 1)
		if _, err := brw.Read(b); err != nil {
			log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
		}

		// Drain all of the rest of the buffered data.
		buf := make([]byte, brw.Reader.Buffered())
		brw.Read(buf)

		// 22 is the TLS handshake.
		// https://tools.ietf.org/html/rfc5246#section-6.2.1
		if b[0] == 22 {
			// Prepend the previously read data to be read again by
			// http.ReadRequest.
			tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))

			if err := tlsconn.Handshake(); err != nil {
				p.mitm.HandshakeErrorCallback(req, err)
				return err
			}
			// Connections that negotiated HTTP/2 are delegated to the MITM
			// config's HTTP/2 proxy.
			if tlsconn.ConnectionState().NegotiatedProtocol == "h2" {
				return p.mitm.H2Config().Proxy(p.closing, tlsconn, req.URL)
			}

			var nconn net.Conn
			nconn = tlsconn
			// If the original connection is a traffic shaped connection, wrap the tls
			// connection inside a traffic shaped connection too.
			if ptsconn, ok := conn.(*trafficshape.Conn); ok {
				nconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn)
			}
			// Point the existing buffered reader/writer at the new connection
			// and handle the first tunnelled request.
			brw.Writer.Reset(nconn)
			brw.Reader.Reset(nconn)
			return p.handle(ctx, nconn, brw)
		}

		// Prepend the previously read data to be read again by http.ReadRequest.
		brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
		return p.handle(ctx, conn, brw)
	}

	log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
	res, cconn, cerr := p.connect(req)
	if cerr != nil {
		// The outbound connection failed: answer the client with a 502 that
		// carries the error as a warning.
		log.Errorf("martian: failed to CONNECT: %v", cerr)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, cerr)

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by response modifier")
			return nil
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		err := brw.Flush()
		if err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}
		return err
	}
	defer res.Body.Close()
	defer cconn.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying CONNECT response: %v", err)
		proxyutil.Warning(res.Header, err)
	}
	if session.Hijacked() {
		log.Infof("martian: connection hijacked by response modifier")
		return nil
	}

	// NOTE(review): -1 marks the length as unknown for http.Response.Write;
	// presumably this keeps a Content-Length from being sent on the tunnel
	// response - confirm.
	res.ContentLength = -1
	if err := res.Write(brw); err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
	}
	if err := brw.Flush(); err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
	}

	cbw := bufio.NewWriter(cconn)
	cbr := bufio.NewReader(cconn)
	defer cbw.Flush()

	// copySync copies r to w until the copy ends, logs any error other than
	// io.EOF, and then reports completion on donec.
	copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
		if _, err := io.Copy(w, r); err != nil && err != io.EOF {
			log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
		}

		log.Debugf("martian: CONNECT tunnel finished copying")
		donec <- true
	}

	// One copy per direction. The channel has room for both completion
	// signals, so neither goroutine blocks on its send.
	donec := make(chan bool, 2)
	go copySync(cbw, brw, donec)
	go copySync(brw, cbr, donec)

	log.Debugf("martian: established CONNECT tunnel, proxying traffic")
	// Wait for both directions to finish before tearing the tunnel down.
	<-donec
	<-donec
	log.Debugf("martian: closed CONNECT tunnel")

	return errClose
}
   441  
   442  func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
   443  	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())
   444  
   445  	req, err := p.readRequest(ctx, conn, brw)
   446  	if err != nil {
   447  		return err
   448  	}
   449  	defer req.Body.Close()
   450  
   451  	session := ctx.Session()
   452  	ctx, err = withSession(session)
   453  	if err != nil {
   454  		log.Errorf("martian: failed to build new context: %v", err)
   455  		return err
   456  	}
   457  
   458  	link(req, ctx)
   459  	defer unlink(req)
   460  
   461  	if tsconn, ok := conn.(*trafficshape.Conn); ok {
   462  		wrconn := tsconn.GetWrappedConn()
   463  		if sconn, ok := wrconn.(*tls.Conn); ok {
   464  			session.MarkSecure()
   465  
   466  			cs := sconn.ConnectionState()
   467  			req.TLS = &cs
   468  		}
   469  	}
   470  
   471  	if tconn, ok := conn.(*tls.Conn); ok {
   472  		session.MarkSecure()
   473  
   474  		cs := tconn.ConnectionState()
   475  		req.TLS = &cs
   476  	}
   477  
   478  	req.URL.Scheme = "http"
   479  	if session.IsSecure() {
   480  		log.Infof("martian: forcing HTTPS inside secure session")
   481  		req.URL.Scheme = "https"
   482  	}
   483  
   484  	req.RemoteAddr = conn.RemoteAddr().String()
   485  	if req.URL.Host == "" {
   486  		req.URL.Host = req.Host
   487  	}
   488  
   489  	if req.Method == "CONNECT" {
   490  		return p.handleConnectRequest(ctx, req, session, brw, conn)
   491  	}
   492  
   493  	// Not a CONNECT request
   494  	if err := p.reqmod.ModifyRequest(req); err != nil {
   495  		log.Errorf("martian: error modifying request: %v", err)
   496  		proxyutil.Warning(req.Header, err)
   497  	}
   498  	if session.Hijacked() {
   499  		return nil
   500  	}
   501  
   502  	// perform the HTTP roundtrip
   503  	res, err := p.roundTrip(ctx, req)
   504  	if err != nil {
   505  		log.Errorf("martian: failed to round trip: %v", err)
   506  		res = proxyutil.NewResponse(502, nil, req)
   507  		proxyutil.Warning(res.Header, err)
   508  	}
   509  	defer res.Body.Close()
   510  
   511  	// set request to original request manually, res.Request may be changed in transport.
   512  	// see https://github.com/google/martian/issues/298
   513  	res.Request = req
   514  
   515  	if err := p.resmod.ModifyResponse(res); err != nil {
   516  		log.Errorf("martian: error modifying response: %v", err)
   517  		proxyutil.Warning(res.Header, err)
   518  	}
   519  	if session.Hijacked() {
   520  		log.Infof("martian: connection hijacked by response modifier")
   521  		return nil
   522  	}
   523  
   524  	var closing error
   525  	if req.Close || res.Close || p.Closing() {
   526  		log.Debugf("martian: received close request: %v", req.RemoteAddr)
   527  		res.Close = true
   528  		closing = errClose
   529  	}
   530  
   531  	// check if conn is a traffic shaped connection.
   532  	if ptsconn, ok := conn.(*trafficshape.Conn); ok {
   533  		ptsconn.Context = &trafficshape.Context{}
   534  		// Check if the request URL matches any URLRegex in Shapes. If so, set the connections's Context
   535  		// with the required information, so that the Write() method of the Conn has access to it.
   536  		for urlregex, buckets := range ptsconn.LocalBuckets {
   537  			if match, _ := regexp.MatchString(urlregex, req.URL.String()); match {
   538  				if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 {
   539  					dump, err := httputil.DumpResponse(res, false)
   540  					if err != nil {
   541  						return err
   542  					}
   543  					ptsconn.Context = &trafficshape.Context{
   544  						Shaping:            true,
   545  						Buckets:            buckets,
   546  						GlobalBucket:       ptsconn.GlobalBuckets[urlregex],
   547  						URLRegex:           urlregex,
   548  						RangeStart:         rangeStart,
   549  						ByteOffset:         rangeStart,
   550  						HeaderLen:          int64(len(dump)),
   551  						HeaderBytesWritten: 0,
   552  					}
   553  					// Get the next action to perform, if there.
   554  					ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart)
   555  					// Check if response lies in a throttled byte range.
   556  					ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart)
   557  					if ptsconn.Context.ThrottleContext.ThrottleNow {
   558  						ptsconn.Context.Buckets.WriteBucket.SetCapacity(
   559  							ptsconn.Context.ThrottleContext.Bandwidth)
   560  					}
   561  					log.Infof(
   562  						"trafficshape: Request %s with Range Start: %d matches a Shaping request %s. Enforcing Traffic shaping.",
   563  						req.URL, rangeStart, urlregex)
   564  				}
   565  				break
   566  			}
   567  		}
   568  	}
   569  
   570  	err = res.Write(brw)
   571  	if err != nil {
   572  		log.Errorf("martian: got error while writing response back to client: %v", err)
   573  		if _, ok := err.(*trafficshape.ErrForceClose); ok {
   574  			closing = errClose
   575  		}
   576  	}
   577  	err = brw.Flush()
   578  	if err != nil {
   579  		log.Errorf("martian: got error while flushing response back to client: %v", err)
   580  		if _, ok := err.(*trafficshape.ErrForceClose); ok {
   581  			closing = errClose
   582  		}
   583  	}
   584  	return closing
   585  }
   586  
// A peekedConn subverts the net.Conn.Read implementation, primarily so that
// sniffed bytes can be transparently prepended.
//
// All other net.Conn methods are provided by the embedded connection; only
// Read is redirected to r.
type peekedConn struct {
	net.Conn
	r io.Reader
}

// Read allows control over the embedded net.Conn's read data. By using an
// io.MultiReader one can read from a conn, and then replace what they read, to
// be read again.
//
// Read reads from r and never from the embedded connection directly.
func (c *peekedConn) Read(buf []byte) (int, error) { return c.r.Read(buf) }
   598  
   599  func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
   600  	if ctx.SkippingRoundTrip() {
   601  		log.Debugf("martian: skipping round trip")
   602  		return proxyutil.NewResponse(200, nil, req), nil
   603  	}
   604  
   605  	return p.roundTripper.RoundTrip(req)
   606  }
   607  
   608  func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
   609  	if p.proxyURL != nil {
   610  		log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)
   611  
   612  		conn, err := p.dial("tcp", p.proxyURL.Host)
   613  		if err != nil {
   614  			return nil, nil, err
   615  		}
   616  		pbw := bufio.NewWriter(conn)
   617  		pbr := bufio.NewReader(conn)
   618  
   619  		req.Write(pbw)
   620  		pbw.Flush()
   621  
   622  		res, err := http.ReadResponse(pbr, req)
   623  		if err != nil {
   624  			return nil, nil, err
   625  		}
   626  
   627  		return res, conn, nil
   628  	}
   629  
   630  	log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)
   631  
   632  	conn, err := p.dial("tcp", req.URL.Host)
   633  	if err != nil {
   634  		return nil, nil, err
   635  	}
   636  
   637  	return proxyutil.NewResponse(200, nil, req), conn, nil
   638  }