github.com/ader1990/go@v0.0.0-20140630135419-8c24447fa791/src/pkg/net/http/httputil/reverseproxy.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // HTTP reverse proxy handler
     6  
     7  package httputil
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/url"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine. When non-nil, flushLoop calls it right before it
// returns in response to a stop signal. It stays nil unless a test sets it.
var onExitFlushLoop func()
    23  
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. It is called with a shallow copy of the
	// incoming request. The backend's response is then copied back
	// to the original client, minus its hop-by-hop headers.
	// Director must not be nil: ServeHTTP calls it unconditionally.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body. Periodic flushing only happens when the
	// client's http.ResponseWriter also implements http.Flusher.
	// If zero, no periodic flushing is done.
	FlushInterval time.Duration
}
    44  
    45  func singleJoiningSlash(a, b string) string {
    46  	aslash := strings.HasSuffix(a, "/")
    47  	bslash := strings.HasPrefix(b, "/")
    48  	switch {
    49  	case aslash && bslash:
    50  		return a + b[1:]
    51  	case !aslash && !bslash:
    52  		return a + "/" + b
    53  	}
    54  	return a + b
    55  }
    56  
    57  // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
    58  // URLs to the scheme, host, and base path provided in target. If the
    59  // target's path is "/base" and the incoming request was for "/dir",
    60  // the target request will be for /base/dir.
    61  func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
    62  	targetQuery := target.RawQuery
    63  	director := func(req *http.Request) {
    64  		req.URL.Scheme = target.Scheme
    65  		req.URL.Host = target.Host
    66  		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
    67  		if targetQuery == "" || req.URL.RawQuery == "" {
    68  			req.URL.RawQuery = targetQuery + req.URL.RawQuery
    69  		} else {
    70  			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
    71  		}
    72  	}
    73  	return &ReverseProxy{Director: director}
    74  }
    75  
    76  func copyHeader(dst, src http.Header) {
    77  	for k, vv := range src {
    78  		for _, v := range vv {
    79  			dst.Add(k, v)
    80  		}
    81  	}
    82  }
    83  
    84  // Hop-by-hop headers. These are removed when sent to the backend.
    85  // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
    86  var hopHeaders = []string{
    87  	"Connection",
    88  	"Keep-Alive",
    89  	"Proxy-Authenticate",
    90  	"Proxy-Authorization",
    91  	"Te", // canonicalized version of "TE"
    92  	"Trailers",
    93  	"Transfer-Encoding",
    94  	"Upgrade",
    95  }
    96  
    97  func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    98  	transport := p.Transport
    99  	if transport == nil {
   100  		transport = http.DefaultTransport
   101  	}
   102  
   103  	outreq := new(http.Request)
   104  	*outreq = *req // includes shallow copies of maps, but okay
   105  
   106  	p.Director(outreq)
   107  	outreq.Proto = "HTTP/1.1"
   108  	outreq.ProtoMajor = 1
   109  	outreq.ProtoMinor = 1
   110  	outreq.Close = false
   111  
   112  	// Remove hop-by-hop headers to the backend.  Especially
   113  	// important is "Connection" because we want a persistent
   114  	// connection, regardless of what the client sent to us.  This
   115  	// is modifying the same underlying map from req (shallow
   116  	// copied above) so we only copy it if necessary.
   117  	copiedHeaders := false
   118  	for _, h := range hopHeaders {
   119  		if outreq.Header.Get(h) != "" {
   120  			if !copiedHeaders {
   121  				outreq.Header = make(http.Header)
   122  				copyHeader(outreq.Header, req.Header)
   123  				copiedHeaders = true
   124  			}
   125  			outreq.Header.Del(h)
   126  		}
   127  	}
   128  
   129  	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
   130  		// If we aren't the first proxy retain prior
   131  		// X-Forwarded-For information as a comma+space
   132  		// separated list and fold multiple headers into one.
   133  		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
   134  			clientIP = strings.Join(prior, ", ") + ", " + clientIP
   135  		}
   136  		outreq.Header.Set("X-Forwarded-For", clientIP)
   137  	}
   138  
   139  	res, err := transport.RoundTrip(outreq)
   140  	if err != nil {
   141  		log.Printf("http: proxy error: %v", err)
   142  		rw.WriteHeader(http.StatusInternalServerError)
   143  		return
   144  	}
   145  	defer res.Body.Close()
   146  
   147  	for _, h := range hopHeaders {
   148  		res.Header.Del(h)
   149  	}
   150  
   151  	copyHeader(rw.Header(), res.Header)
   152  
   153  	rw.WriteHeader(res.StatusCode)
   154  	p.copyResponse(rw, res.Body)
   155  }
   156  
   157  func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
   158  	if p.FlushInterval != 0 {
   159  		if wf, ok := dst.(writeFlusher); ok {
   160  			mlw := &maxLatencyWriter{
   161  				dst:     wf,
   162  				latency: p.FlushInterval,
   163  				done:    make(chan bool),
   164  			}
   165  			go mlw.flushLoop()
   166  			defer mlw.stop()
   167  			dst = mlw
   168  		}
   169  	}
   170  
   171  	io.Copy(dst, src)
   172  }
   173  
// writeFlusher is an io.Writer that can also be flushed. copyResponse
// type-asserts its destination to this interface to decide whether
// periodic flushing via maxLatencyWriter is possible.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
   178  
// maxLatencyWriter wraps a writeFlusher. Its flushLoop goroutine flushes
// dst every latency until stop is called. Write and the periodic Flush
// are serialized by lk.
type maxLatencyWriter struct {
	dst     writeFlusher  // destination for writes and periodic flushes
	latency time.Duration // interval between flushes in flushLoop

	lk   sync.Mutex // protects Write + Flush
	done chan bool  // stop sends on it to make flushLoop return
}
   186  
   187  func (m *maxLatencyWriter) Write(p []byte) (int, error) {
   188  	m.lk.Lock()
   189  	defer m.lk.Unlock()
   190  	return m.dst.Write(p)
   191  }
   192  
   193  func (m *maxLatencyWriter) flushLoop() {
   194  	t := time.NewTicker(m.latency)
   195  	defer t.Stop()
   196  	for {
   197  		select {
   198  		case <-m.done:
   199  			if onExitFlushLoop != nil {
   200  				onExitFlushLoop()
   201  			}
   202  			return
   203  		case <-t.C:
   204  			m.lk.Lock()
   205  			m.dst.Flush()
   206  			m.lk.Unlock()
   207  		}
   208  	}
   209  }
   210  
   211  func (m *maxLatencyWriter) stop() { m.done <- true }