github.com/cortesi/devd@v0.0.0-20200427000907-c1a3bfba27d8/reverseproxy/reverseproxy.go (about) 1 // Package reverseproxy is a reverse proxy implementation based on the built-in 2 // httuptil.Reverseproxy. Extensions include better logging and support for 3 // injection. 4 package reverseproxy 5 6 import ( 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "net/url" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16 17 "golang.org/x/net/context" 18 19 "github.com/cortesi/devd/inject" 20 "github.com/cortesi/termlog" 21 humanize "github.com/dustin/go-humanize" 22 ) 23 24 // onExitFlushLoop is a callback set by tests to detect the state of the 25 // flushLoop() goroutine. 26 var onExitFlushLoop func() 27 28 // ReverseProxy is an HTTP Handler that takes an incoming request and 29 // sends it to another server, proxying the response back to the 30 // client. 31 type ReverseProxy struct { 32 // Director must be a function which modifies 33 // the request into a new request to be sent 34 // using Transport. Its response is then copied 35 // back to the original client unmodified. 36 Director func(*http.Request) 37 38 // The transport used to perform proxy requests. 39 // If nil, http.DefaultTransport is used. 40 Transport http.RoundTripper 41 42 // FlushInterval specifies the flush interval 43 // to flush to the client while copying the 44 // response body. 45 // If zero, no periodic flushing is done. 46 FlushInterval time.Duration 47 48 Inject inject.CopyInject 49 } 50 51 func singleJoiningSlash(a, b string) string { 52 if b == "" { 53 return a 54 } 55 56 aslash := strings.HasSuffix(a, "/") 57 bslash := strings.HasPrefix(b, "/") 58 switch { 59 case aslash && bslash: 60 return a + b[1:] 61 case !aslash && !bslash: 62 return a + "/" + b 63 } 64 return a + b 65 } 66 67 // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites 68 // URLs to the scheme, host, and base path provided in target. If the 69 // target's path is "/base" and the incoming request was for "/dir", 70 // the target request will be for /base/dir. 71 func NewSingleHostReverseProxy(target *url.URL, ci inject.CopyInject) *ReverseProxy { 72 targetQuery := target.RawQuery 73 director := func(req *http.Request) { 74 req.URL.Host = target.Host 75 req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) 76 if req.Header.Get("X-Forwarded-Host") == "" { 77 req.Header.Set("X-Forwarded-Host", req.Host) 78 } 79 if req.Header.Get("X-Forwarded-Proto") == "" { 80 req.Header.Set("X-Forwarded-Proto", req.URL.Scheme) 81 } 82 req.URL.Scheme = target.Scheme 83 84 // Set "identity"-only content encoding, in order for injector to 85 // work on text response 86 req.Header.Set("Accept-Encoding", "identity") 87 88 req.Host = req.URL.Host 89 if targetQuery == "" || req.URL.RawQuery == "" { 90 req.URL.RawQuery = targetQuery + req.URL.RawQuery 91 } else { 92 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 93 } 94 } 95 return &ReverseProxy{Director: director, Inject: ci} 96 } 97 98 func copyHeader(dst, src http.Header) { 99 for k, vv := range src { 100 for _, v := range vv { 101 dst.Add(k, v) 102 } 103 } 104 } 105 106 // Hop-by-hop headers. These are removed when sent to the backend. 107 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 108 var hopHeaders = []string{ 109 "Connection", 110 "Keep-Alive", 111 "Proxy-Authenticate", 112 "Proxy-Authorization", 113 "Te", // canonicalized version of "TE" 114 "Trailers", 115 "Transfer-Encoding", 116 "Upgrade", 117 } 118 119 // ServeHTTPContext serves HTTP with a context 120 func (p *ReverseProxy) ServeHTTPContext( 121 ctx context.Context, rw http.ResponseWriter, req *http.Request, 122 ) { 123 log := termlog.FromContext(ctx) 124 transport := p.Transport 125 if transport == nil { 126 transport = http.DefaultTransport 127 } 128 129 outreq := new(http.Request) 130 *outreq = *req // includes shallow copies of maps, but okay 131 132 p.Director(outreq) 133 outreq.Proto = "HTTP/1.1" 134 outreq.ProtoMajor = 1 135 outreq.ProtoMinor = 1 136 outreq.Close = false 137 138 // Remove hop-by-hop headers to the backend. Especially 139 // important is "Connection" because we want a persistent 140 // connection, regardless of what the client sent to us. This 141 // is modifying the same underlying map from req (shallow 142 // copied above) so we only copy it if necessary. 143 copiedHeaders := false 144 for _, h := range hopHeaders { 145 if outreq.Header.Get(h) != "" { 146 if !copiedHeaders { 147 outreq.Header = make(http.Header) 148 copyHeader(outreq.Header, req.Header) 149 copiedHeaders = true 150 } 151 outreq.Header.Del(h) 152 } 153 } 154 155 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 156 // If we aren't the first proxy retain prior 157 // X-Forwarded-For information as a comma+space 158 // separated list and fold multiple headers into one. 159 if prior, ok := outreq.Header["X-Forwarded-For"]; ok { 160 clientIP = strings.Join(prior, ", ") + ", " + clientIP 161 } 162 outreq.Header.Set("X-Forwarded-For", clientIP) 163 } 164 165 res, err := transport.RoundTrip(outreq) 166 if err != nil { 167 log.Shout("reverse proxy error: %v", err) 168 rw.WriteHeader(http.StatusInternalServerError) 169 return 170 } 171 defer res.Body.Close() 172 if req.ContentLength > 0 { 173 log.Say(fmt.Sprintf("%s uploaded", humanize.Bytes(uint64(req.ContentLength)))) 174 } 175 176 inject, err := p.Inject.Sniff(res.Body, res.Header.Get("Content-Type")) 177 if err != nil { 178 log.Shout("reverse proxy error: %v", err) 179 rw.WriteHeader(http.StatusInternalServerError) 180 return 181 } 182 183 if inject.Found() { 184 cl, err := strconv.ParseInt(res.Header.Get("Content-Length"), 10, 32) 185 if err == nil { 186 cl = cl + int64(inject.Extra()) 187 res.Header.Set("Content-Length", strconv.FormatInt(cl, 10)) 188 } 189 } 190 copyHeader(rw.Header(), res.Header) 191 rw.WriteHeader(res.StatusCode) 192 p.copyResponse(ctx, rw, inject) 193 } 194 195 func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 196 p.ServeHTTPContext(context.Background(), w, r) 197 } 198 199 func (p *ReverseProxy) copyResponse(ctx context.Context, dst io.Writer, inject inject.Injector) { 200 log := termlog.FromContext(ctx) 201 if p.FlushInterval != 0 { 202 if wf, ok := dst.(writeFlusher); ok { 203 mlw := &maxLatencyWriter{ 204 dst: wf, 205 latency: p.FlushInterval, 206 done: make(chan bool), 207 } 208 go mlw.flushLoop() 209 defer mlw.stop() 210 dst = mlw 211 } 212 } 213 _, err := inject.Copy(dst) 214 if err != nil { 215 log.Shout("Error forwarding data: %s", err) 216 } 217 } 218 219 type writeFlusher interface { 220 io.Writer 221 http.Flusher 222 } 223 224 type maxLatencyWriter struct { 225 sync.Mutex // protects Write + Flush 226 227 dst writeFlusher 228 latency time.Duration 229 230 done chan bool 231 } 232 233 func (m *maxLatencyWriter) Write(p []byte) (int, error) { 234 m.Lock() 235 defer m.Unlock() 236 return m.dst.Write(p) 237 } 238 239 func (m *maxLatencyWriter) flushLoop() { 240 t := time.NewTicker(m.latency) 241 defer t.Stop() 242 for { 243 select { 244 case <-m.done: 245 if onExitFlushLoop != nil { 246 onExitFlushLoop() 247 } 248 return 249 case <-t.C: 250 m.Lock() 251 m.dst.Flush() 252 m.Unlock() 253 } 254 } 255 } 256 257 func (m *maxLatencyWriter) stop() { m.done <- true }