github.com/bgentry/go@v0.0.0-20150121062915-6cf5a733d54d/src/net/http/httputil/reverseproxy.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // HTTP reverse proxy handler
     6  
     7  package httputil
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/url"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
// onExitFlushLoop is a hook set by tests to observe the flushLoop()
// goroutine. When non-nil, flushLoop calls it once, immediately before
// returning in response to a stop signal.
var onExitFlushLoop func()
    23  
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. ServeHTTP calls it with a copy
	// of the incoming request. The backend's response
	// is then copied back to the original client, with
	// hop-by-hop headers removed.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	// Periodic flushing also requires the ResponseWriter
	// to implement http.Flusher.
	FlushInterval time.Duration

	// ErrorLog specifies an optional logger for errors
	// that occur when attempting to proxy the request.
	// If nil, logging goes to os.Stderr via the log package's
	// standard logger.
	ErrorLog *log.Logger
}
    50  
    51  func singleJoiningSlash(a, b string) string {
    52  	aslash := strings.HasSuffix(a, "/")
    53  	bslash := strings.HasPrefix(b, "/")
    54  	switch {
    55  	case aslash && bslash:
    56  		return a + b[1:]
    57  	case !aslash && !bslash:
    58  		return a + "/" + b
    59  	}
    60  	return a + b
    61  }
    62  
    63  // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
    64  // URLs to the scheme, host, and base path provided in target. If the
    65  // target's path is "/base" and the incoming request was for "/dir",
    66  // the target request will be for /base/dir.
    67  func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
    68  	targetQuery := target.RawQuery
    69  	director := func(req *http.Request) {
    70  		req.URL.Scheme = target.Scheme
    71  		req.URL.Host = target.Host
    72  		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
    73  		if targetQuery == "" || req.URL.RawQuery == "" {
    74  			req.URL.RawQuery = targetQuery + req.URL.RawQuery
    75  		} else {
    76  			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
    77  		}
    78  	}
    79  	return &ReverseProxy{Director: director}
    80  }
    81  
    82  func copyHeader(dst, src http.Header) {
    83  	for k, vv := range src {
    84  		for _, v := range vv {
    85  			dst.Add(k, v)
    86  		}
    87  	}
    88  }
    89  
    90  // Hop-by-hop headers. These are removed when sent to the backend.
    91  // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
    92  var hopHeaders = []string{
    93  	"Connection",
    94  	"Keep-Alive",
    95  	"Proxy-Authenticate",
    96  	"Proxy-Authorization",
    97  	"Te", // canonicalized version of "TE"
    98  	"Trailers",
    99  	"Transfer-Encoding",
   100  	"Upgrade",
   101  }
   102  
// requestCanceler is implemented by transports that can abort an
// in-flight request. ServeHTTP type-asserts its transport to this
// interface and, when the assertion succeeds, uses CancelRequest to
// abandon the backend request once the client has gone away.
type requestCanceler interface {
	CancelRequest(*http.Request)
}
   106  
   107  type runOnFirstRead struct {
   108  	io.Reader
   109  
   110  	fn func() // Run before first Read, then set to nil
   111  }
   112  
   113  func (c *runOnFirstRead) Read(bs []byte) (int, error) {
   114  	if c.fn != nil {
   115  		c.fn()
   116  		c.fn = nil
   117  	}
   118  	return c.Reader.Read(bs)
   119  }
   120  
   121  func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   122  	transport := p.Transport
   123  	if transport == nil {
   124  		transport = http.DefaultTransport
   125  	}
   126  
   127  	outreq := new(http.Request)
   128  	*outreq = *req // includes shallow copies of maps, but okay
   129  
   130  	if closeNotifier, ok := rw.(http.CloseNotifier); ok {
   131  		if requestCanceler, ok := transport.(requestCanceler); ok {
   132  			reqDone := make(chan struct{})
   133  			defer close(reqDone)
   134  
   135  			clientGone := closeNotifier.CloseNotify()
   136  
   137  			outreq.Body = struct {
   138  				io.Reader
   139  				io.Closer
   140  			}{
   141  				Reader: &runOnFirstRead{
   142  					Reader: outreq.Body,
   143  					fn: func() {
   144  						go func() {
   145  							select {
   146  							case <-clientGone:
   147  								requestCanceler.CancelRequest(outreq)
   148  							case <-reqDone:
   149  							}
   150  						}()
   151  					},
   152  				},
   153  				Closer: outreq.Body,
   154  			}
   155  		}
   156  	}
   157  
   158  	p.Director(outreq)
   159  	outreq.Proto = "HTTP/1.1"
   160  	outreq.ProtoMajor = 1
   161  	outreq.ProtoMinor = 1
   162  	outreq.Close = false
   163  
   164  	// Remove hop-by-hop headers to the backend.  Especially
   165  	// important is "Connection" because we want a persistent
   166  	// connection, regardless of what the client sent to us.  This
   167  	// is modifying the same underlying map from req (shallow
   168  	// copied above) so we only copy it if necessary.
   169  	copiedHeaders := false
   170  	for _, h := range hopHeaders {
   171  		if outreq.Header.Get(h) != "" {
   172  			if !copiedHeaders {
   173  				outreq.Header = make(http.Header)
   174  				copyHeader(outreq.Header, req.Header)
   175  				copiedHeaders = true
   176  			}
   177  			outreq.Header.Del(h)
   178  		}
   179  	}
   180  
   181  	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
   182  		// If we aren't the first proxy retain prior
   183  		// X-Forwarded-For information as a comma+space
   184  		// separated list and fold multiple headers into one.
   185  		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
   186  			clientIP = strings.Join(prior, ", ") + ", " + clientIP
   187  		}
   188  		outreq.Header.Set("X-Forwarded-For", clientIP)
   189  	}
   190  
   191  	res, err := transport.RoundTrip(outreq)
   192  	if err != nil {
   193  		p.logf("http: proxy error: %v", err)
   194  		rw.WriteHeader(http.StatusInternalServerError)
   195  		return
   196  	}
   197  	defer res.Body.Close()
   198  
   199  	for _, h := range hopHeaders {
   200  		res.Header.Del(h)
   201  	}
   202  
   203  	copyHeader(rw.Header(), res.Header)
   204  
   205  	rw.WriteHeader(res.StatusCode)
   206  	p.copyResponse(rw, res.Body)
   207  }
   208  
   209  func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
   210  	if p.FlushInterval != 0 {
   211  		if wf, ok := dst.(writeFlusher); ok {
   212  			mlw := &maxLatencyWriter{
   213  				dst:     wf,
   214  				latency: p.FlushInterval,
   215  				done:    make(chan bool),
   216  			}
   217  			go mlw.flushLoop()
   218  			defer mlw.stop()
   219  			dst = mlw
   220  		}
   221  	}
   222  
   223  	io.Copy(dst, src)
   224  }
   225  
   226  func (p *ReverseProxy) logf(format string, args ...interface{}) {
   227  	if p.ErrorLog != nil {
   228  		p.ErrorLog.Printf(format, args...)
   229  	} else {
   230  		log.Printf(format, args...)
   231  	}
   232  }
   233  
// writeFlusher is a writer that can also be flushed. copyResponse
// requires the destination to satisfy it before enabling periodic
// flushing.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
   238  
// maxLatencyWriter forwards writes to dst while its flushLoop
// goroutine flushes dst every latency interval, until stop is called.
type maxLatencyWriter struct {
	dst     writeFlusher  // destination for writes and flushes
	latency time.Duration // interval between flushes in flushLoop

	lk   sync.Mutex // protects Write + Flush
	done chan bool  // stop sends on this to end flushLoop
}
   246  
   247  func (m *maxLatencyWriter) Write(p []byte) (int, error) {
   248  	m.lk.Lock()
   249  	defer m.lk.Unlock()
   250  	return m.dst.Write(p)
   251  }
   252  
   253  func (m *maxLatencyWriter) flushLoop() {
   254  	t := time.NewTicker(m.latency)
   255  	defer t.Stop()
   256  	for {
   257  		select {
   258  		case <-m.done:
   259  			if onExitFlushLoop != nil {
   260  				onExitFlushLoop()
   261  			}
   262  			return
   263  		case <-t.C:
   264  			m.lk.Lock()
   265  			m.dst.Flush()
   266  			m.lk.Unlock()
   267  		}
   268  	}
   269  }
   270  
   271  func (m *maxLatencyWriter) stop() { m.done <- true }