github.com/gocuntian/go@v0.0.0-20160610041250-fee02d270bf8/src/net/http/httputil/reverseproxy.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // HTTP reverse proxy handler
     6  
     7  package httputil
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/url"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine. When non-nil, flushLoop invokes it after receiving
// the stop signal, just before the goroutine returns.
var onExitFlushLoop func()
    23  
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. Its response is then copied
	// back to the original client, except that
	// hop-by-hop headers (see hopHeaders) are removed.
	// Director must be non-nil: ServeHTTP calls it
	// unconditionally.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	// If the transport has a CancelRequest method and the
	// ResponseWriter implements http.CloseNotifier, the
	// backend request may be canceled when the client
	// disconnects while the request is in flight.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	// Periodic flushing additionally requires the
	// ResponseWriter to implement http.Flusher.
	FlushInterval time.Duration

	// ErrorLog specifies an optional logger for errors
	// that occur when attempting to proxy the request.
	// If nil, logging goes to os.Stderr via the log package's
	// standard logger.
	ErrorLog *log.Logger

	// BufferPool optionally specifies a buffer pool to
	// get byte slices for use by io.CopyBuffer when
	// copying HTTP response bodies. If nil, no buffer is
	// supplied and io.CopyBuffer allocates its own as needed.
	BufferPool BufferPool
}
    55  
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
//
// Get is called once per proxied response body to obtain the copy buffer;
// Put is called with that same slice once the copy has finished.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
    62  
    63  func singleJoiningSlash(a, b string) string {
    64  	aslash := strings.HasSuffix(a, "/")
    65  	bslash := strings.HasPrefix(b, "/")
    66  	switch {
    67  	case aslash && bslash:
    68  		return a + b[1:]
    69  	case !aslash && !bslash:
    70  		return a + "/" + b
    71  	}
    72  	return a + b
    73  }
    74  
    75  // NewSingleHostReverseProxy returns a new ReverseProxy that routes
    76  // URLs to the scheme, host, and base path provided in target. If the
    77  // target's path is "/base" and the incoming request was for "/dir",
    78  // the target request will be for /base/dir.
    79  // NewSingleHostReverseProxy does not rewrite the Host header.
    80  // To rewrite Host headers, use ReverseProxy directly with a custom
    81  // Director policy.
    82  func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
    83  	targetQuery := target.RawQuery
    84  	director := func(req *http.Request) {
    85  		req.URL.Scheme = target.Scheme
    86  		req.URL.Host = target.Host
    87  		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
    88  		if targetQuery == "" || req.URL.RawQuery == "" {
    89  			req.URL.RawQuery = targetQuery + req.URL.RawQuery
    90  		} else {
    91  			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
    92  		}
    93  		if _, ok := req.Header["User-Agent"]; !ok {
    94  			// explicitly disable User-Agent so it's not set to default value
    95  			req.Header.Set("User-Agent", "")
    96  		}
    97  	}
    98  	return &ReverseProxy{Director: director}
    99  }
   100  
   101  func copyHeader(dst, src http.Header) {
   102  	for k, vv := range src {
   103  		for _, v := range vv {
   104  			dst.Add(k, v)
   105  		}
   106  	}
   107  }
   108  
// Hop-by-hop headers. These are removed from requests before they are sent
// to the backend, and from backend responses before they are copied to the
// client.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// (RFC 2616 has since been superseded; see RFC 7230, section 6.1.)
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // canonicalized version of "TE"
	"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}
   122  
// requestCanceler is the optional transport capability ServeHTTP looks for
// (via a type assertion on the RoundTripper) so it can cancel the outgoing
// request when the client's connection goes away.
type requestCanceler interface {
	CancelRequest(*http.Request)
}
   126  
   127  type runOnFirstRead struct {
   128  	io.Reader // optional; nil means empty body
   129  
   130  	fn func() // Run before first Read, then set to nil
   131  }
   132  
   133  func (c *runOnFirstRead) Read(bs []byte) (int, error) {
   134  	if c.fn != nil {
   135  		c.fn()
   136  		c.fn = nil
   137  	}
   138  	if c.Reader == nil {
   139  		return 0, io.EOF
   140  	}
   141  	return c.Reader.Read(bs)
   142  }
   143  
   144  func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   145  	transport := p.Transport
   146  	if transport == nil {
   147  		transport = http.DefaultTransport
   148  	}
   149  
   150  	outreq := new(http.Request)
   151  	*outreq = *req // includes shallow copies of maps, but okay
   152  
   153  	if closeNotifier, ok := rw.(http.CloseNotifier); ok {
   154  		if requestCanceler, ok := transport.(requestCanceler); ok {
   155  			reqDone := make(chan struct{})
   156  			defer close(reqDone)
   157  
   158  			clientGone := closeNotifier.CloseNotify()
   159  
   160  			outreq.Body = struct {
   161  				io.Reader
   162  				io.Closer
   163  			}{
   164  				Reader: &runOnFirstRead{
   165  					Reader: outreq.Body,
   166  					fn: func() {
   167  						go func() {
   168  							select {
   169  							case <-clientGone:
   170  								requestCanceler.CancelRequest(outreq)
   171  							case <-reqDone:
   172  							}
   173  						}()
   174  					},
   175  				},
   176  				Closer: outreq.Body,
   177  			}
   178  		}
   179  	}
   180  
   181  	p.Director(outreq)
   182  	outreq.Proto = "HTTP/1.1"
   183  	outreq.ProtoMajor = 1
   184  	outreq.ProtoMinor = 1
   185  	outreq.Close = false
   186  
   187  	// Remove hop-by-hop headers to the backend. Especially
   188  	// important is "Connection" because we want a persistent
   189  	// connection, regardless of what the client sent to us. This
   190  	// is modifying the same underlying map from req (shallow
   191  	// copied above) so we only copy it if necessary.
   192  	copiedHeaders := false
   193  	for _, h := range hopHeaders {
   194  		if outreq.Header.Get(h) != "" {
   195  			if !copiedHeaders {
   196  				outreq.Header = make(http.Header)
   197  				copyHeader(outreq.Header, req.Header)
   198  				copiedHeaders = true
   199  			}
   200  			outreq.Header.Del(h)
   201  		}
   202  	}
   203  
   204  	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
   205  		// If we aren't the first proxy retain prior
   206  		// X-Forwarded-For information as a comma+space
   207  		// separated list and fold multiple headers into one.
   208  		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
   209  			clientIP = strings.Join(prior, ", ") + ", " + clientIP
   210  		}
   211  		outreq.Header.Set("X-Forwarded-For", clientIP)
   212  	}
   213  
   214  	res, err := transport.RoundTrip(outreq)
   215  	if err != nil {
   216  		p.logf("http: proxy error: %v", err)
   217  		rw.WriteHeader(http.StatusBadGateway)
   218  		return
   219  	}
   220  
   221  	for _, h := range hopHeaders {
   222  		res.Header.Del(h)
   223  	}
   224  
   225  	copyHeader(rw.Header(), res.Header)
   226  
   227  	// The "Trailer" header isn't included in the Transport's response,
   228  	// at least for *http.Transport. Build it up from Trailer.
   229  	if len(res.Trailer) > 0 {
   230  		var trailerKeys []string
   231  		for k := range res.Trailer {
   232  			trailerKeys = append(trailerKeys, k)
   233  		}
   234  		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
   235  	}
   236  
   237  	rw.WriteHeader(res.StatusCode)
   238  	if len(res.Trailer) > 0 {
   239  		// Force chunking if we saw a response trailer.
   240  		// This prevents net/http from calculating the length for short
   241  		// bodies and adding a Content-Length.
   242  		if fl, ok := rw.(http.Flusher); ok {
   243  			fl.Flush()
   244  		}
   245  	}
   246  	p.copyResponse(rw, res.Body)
   247  	res.Body.Close() // close now, instead of defer, to populate res.Trailer
   248  	copyHeader(rw.Header(), res.Trailer)
   249  }
   250  
   251  func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
   252  	if p.FlushInterval != 0 {
   253  		if wf, ok := dst.(writeFlusher); ok {
   254  			mlw := &maxLatencyWriter{
   255  				dst:     wf,
   256  				latency: p.FlushInterval,
   257  				done:    make(chan bool),
   258  			}
   259  			go mlw.flushLoop()
   260  			defer mlw.stop()
   261  			dst = mlw
   262  		}
   263  	}
   264  
   265  	var buf []byte
   266  	if p.BufferPool != nil {
   267  		buf = p.BufferPool.Get()
   268  	}
   269  	io.CopyBuffer(dst, src, buf)
   270  	if p.BufferPool != nil {
   271  		p.BufferPool.Put(buf)
   272  	}
   273  }
   274  
   275  func (p *ReverseProxy) logf(format string, args ...interface{}) {
   276  	if p.ErrorLog != nil {
   277  		p.ErrorLog.Printf(format, args...)
   278  	} else {
   279  		log.Printf(format, args...)
   280  	}
   281  }
   282  
// writeFlusher is a writer that can also flush buffered output. copyResponse
// enables periodic flushing only when its destination satisfies this
// interface.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
   287  
// maxLatencyWriter wraps a writeFlusher so that, while its flushLoop
// goroutine runs, buffered output is flushed every latency interval.
type maxLatencyWriter struct {
	dst     writeFlusher  // underlying writer that is written to and flushed
	latency time.Duration // period between flushes in flushLoop

	mu   sync.Mutex // protects Write + Flush
	done chan bool  // stop sends on this to make flushLoop return
}
   295  
   296  func (m *maxLatencyWriter) Write(p []byte) (int, error) {
   297  	m.mu.Lock()
   298  	defer m.mu.Unlock()
   299  	return m.dst.Write(p)
   300  }
   301  
   302  func (m *maxLatencyWriter) flushLoop() {
   303  	t := time.NewTicker(m.latency)
   304  	defer t.Stop()
   305  	for {
   306  		select {
   307  		case <-m.done:
   308  			if onExitFlushLoop != nil {
   309  				onExitFlushLoop()
   310  			}
   311  			return
   312  		case <-t.C:
   313  			m.mu.Lock()
   314  			m.dst.Flush()
   315  			m.mu.Unlock()
   316  		}
   317  	}
   318  }
   319  
   320  func (m *maxLatencyWriter) stop() { m.done <- true }