github.com/cortesi/devd@v0.0.0-20200427000907-c1a3bfba27d8/reverseproxy/reverseproxy.go (about)

     1  // Package reverseproxy is a reverse proxy implementation based on the built-in
     2  // httputil.ReverseProxy. Extensions include better logging and support for
     3  // injection.
     4  package reverseproxy
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"net/url"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"golang.org/x/net/context"
    18  
    19  	"github.com/cortesi/devd/inject"
    20  	"github.com/cortesi/termlog"
    21  	humanize "github.com/dustin/go-humanize"
    22  )
    23  
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine. When it is non-nil, flushLoop calls it once, right
// before the goroutine returns in response to a stop signal.
var onExitFlushLoop func()
    27  
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. Its response is then copied
	// back to the original client by way of Inject.
	// It is called unconditionally for every request,
	// so it must not be nil.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	FlushInterval time.Duration

	// Inject sniffs every backend response (body plus Content-Type
	// header) and yields the injector that copies the body to the
	// client. When the injector reports a match, the response's
	// Content-Length is increased by the injector's Extra() value.
	Inject inject.CopyInject
}
    50  
    51  func singleJoiningSlash(a, b string) string {
    52  	if b == "" {
    53  		return a
    54  	}
    55  
    56  	aslash := strings.HasSuffix(a, "/")
    57  	bslash := strings.HasPrefix(b, "/")
    58  	switch {
    59  	case aslash && bslash:
    60  		return a + b[1:]
    61  	case !aslash && !bslash:
    62  		return a + "/" + b
    63  	}
    64  	return a + b
    65  }
    66  
    67  // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
    68  // URLs to the scheme, host, and base path provided in target. If the
    69  // target's path is "/base" and the incoming request was for "/dir",
    70  // the target request will be for /base/dir.
    71  func NewSingleHostReverseProxy(target *url.URL, ci inject.CopyInject) *ReverseProxy {
    72  	targetQuery := target.RawQuery
    73  	director := func(req *http.Request) {
    74  		req.URL.Host = target.Host
    75  		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
    76  		if req.Header.Get("X-Forwarded-Host") == "" {
    77  			req.Header.Set("X-Forwarded-Host", req.Host)
    78  		}
    79  		if req.Header.Get("X-Forwarded-Proto") == "" {
    80  			req.Header.Set("X-Forwarded-Proto", req.URL.Scheme)
    81  		}
    82  		req.URL.Scheme = target.Scheme
    83  
    84  		// Set "identity"-only content encoding, in order for injector to
    85  		// work on text response
    86  		req.Header.Set("Accept-Encoding", "identity")
    87  
    88  		req.Host = req.URL.Host
    89  		if targetQuery == "" || req.URL.RawQuery == "" {
    90  			req.URL.RawQuery = targetQuery + req.URL.RawQuery
    91  		} else {
    92  			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
    93  		}
    94  	}
    95  	return &ReverseProxy{Director: director, Inject: ci}
    96  }
    97  
    98  func copyHeader(dst, src http.Header) {
    99  	for k, vv := range src {
   100  		for _, v := range vv {
   101  			dst.Add(k, v)
   102  		}
   103  	}
   104  }
   105  
   106  // Hop-by-hop headers. These are removed when sent to the backend.
   107  // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
   108  var hopHeaders = []string{
   109  	"Connection",
   110  	"Keep-Alive",
   111  	"Proxy-Authenticate",
   112  	"Proxy-Authorization",
   113  	"Te", // canonicalized version of "TE"
   114  	"Trailers",
   115  	"Transfer-Encoding",
   116  	"Upgrade",
   117  }
   118  
   119  // ServeHTTPContext serves HTTP with a context
   120  func (p *ReverseProxy) ServeHTTPContext(
   121  	ctx context.Context, rw http.ResponseWriter, req *http.Request,
   122  ) {
   123  	log := termlog.FromContext(ctx)
   124  	transport := p.Transport
   125  	if transport == nil {
   126  		transport = http.DefaultTransport
   127  	}
   128  
   129  	outreq := new(http.Request)
   130  	*outreq = *req // includes shallow copies of maps, but okay
   131  
   132  	p.Director(outreq)
   133  	outreq.Proto = "HTTP/1.1"
   134  	outreq.ProtoMajor = 1
   135  	outreq.ProtoMinor = 1
   136  	outreq.Close = false
   137  
   138  	// Remove hop-by-hop headers to the backend.  Especially
   139  	// important is "Connection" because we want a persistent
   140  	// connection, regardless of what the client sent to us.  This
   141  	// is modifying the same underlying map from req (shallow
   142  	// copied above) so we only copy it if necessary.
   143  	copiedHeaders := false
   144  	for _, h := range hopHeaders {
   145  		if outreq.Header.Get(h) != "" {
   146  			if !copiedHeaders {
   147  				outreq.Header = make(http.Header)
   148  				copyHeader(outreq.Header, req.Header)
   149  				copiedHeaders = true
   150  			}
   151  			outreq.Header.Del(h)
   152  		}
   153  	}
   154  
   155  	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
   156  		// If we aren't the first proxy retain prior
   157  		// X-Forwarded-For information as a comma+space
   158  		// separated list and fold multiple headers into one.
   159  		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
   160  			clientIP = strings.Join(prior, ", ") + ", " + clientIP
   161  		}
   162  		outreq.Header.Set("X-Forwarded-For", clientIP)
   163  	}
   164  
   165  	res, err := transport.RoundTrip(outreq)
   166  	if err != nil {
   167  		log.Shout("reverse proxy error: %v", err)
   168  		rw.WriteHeader(http.StatusInternalServerError)
   169  		return
   170  	}
   171  	defer res.Body.Close()
   172  	if req.ContentLength > 0 {
   173  		log.Say(fmt.Sprintf("%s uploaded", humanize.Bytes(uint64(req.ContentLength))))
   174  	}
   175  
   176  	inject, err := p.Inject.Sniff(res.Body, res.Header.Get("Content-Type"))
   177  	if err != nil {
   178  		log.Shout("reverse proxy error: %v", err)
   179  		rw.WriteHeader(http.StatusInternalServerError)
   180  		return
   181  	}
   182  
   183  	if inject.Found() {
   184  		cl, err := strconv.ParseInt(res.Header.Get("Content-Length"), 10, 32)
   185  		if err == nil {
   186  			cl = cl + int64(inject.Extra())
   187  			res.Header.Set("Content-Length", strconv.FormatInt(cl, 10))
   188  		}
   189  	}
   190  	copyHeader(rw.Header(), res.Header)
   191  	rw.WriteHeader(res.StatusCode)
   192  	p.copyResponse(ctx, rw, inject)
   193  }
   194  
   195  func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   196  	p.ServeHTTPContext(context.Background(), w, r)
   197  }
   198  
   199  func (p *ReverseProxy) copyResponse(ctx context.Context, dst io.Writer, inject inject.Injector) {
   200  	log := termlog.FromContext(ctx)
   201  	if p.FlushInterval != 0 {
   202  		if wf, ok := dst.(writeFlusher); ok {
   203  			mlw := &maxLatencyWriter{
   204  				dst:     wf,
   205  				latency: p.FlushInterval,
   206  				done:    make(chan bool),
   207  			}
   208  			go mlw.flushLoop()
   209  			defer mlw.stop()
   210  			dst = mlw
   211  		}
   212  	}
   213  	_, err := inject.Copy(dst)
   214  	if err != nil {
   215  		log.Shout("Error forwarding data: %s", err)
   216  	}
   217  }
   218  
// writeFlusher is an io.Writer that can also flush buffered data. copyResponse
// only enables periodic flushing when the destination satisfies it.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
   223  
// maxLatencyWriter forwards writes to dst and, while flushLoop is running,
// flushes dst once per latency interval. The embedded mutex serializes
// writes with those periodic flushes.
type maxLatencyWriter struct {
	sync.Mutex // protects Write + Flush

	dst     writeFlusher  // destination of all writes and flushes
	latency time.Duration // ticker period used by flushLoop

	done chan bool // stop sends on it to make flushLoop return
}
   232  
   233  func (m *maxLatencyWriter) Write(p []byte) (int, error) {
   234  	m.Lock()
   235  	defer m.Unlock()
   236  	return m.dst.Write(p)
   237  }
   238  
   239  func (m *maxLatencyWriter) flushLoop() {
   240  	t := time.NewTicker(m.latency)
   241  	defer t.Stop()
   242  	for {
   243  		select {
   244  		case <-m.done:
   245  			if onExitFlushLoop != nil {
   246  				onExitFlushLoop()
   247  			}
   248  			return
   249  		case <-t.C:
   250  			m.Lock()
   251  			m.dst.Flush()
   252  			m.Unlock()
   253  		}
   254  	}
   255  }
   256  
   257  func (m *maxLatencyWriter) stop() { m.done <- true }