// Source: github.com/useflyent/fhttp@v0.0.0-20211004035111-333f430cfbbf/httputil/reverseproxy.go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// HTTP reverse proxy handler

package httputil

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/textproto"
	"net/url"
	"strings"
	"sync"
	"time"

	http "github.com/useflyent/fhttp"

	"golang.org/x/net/http/httpguts"
)

// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// ReverseProxy by default sets the client IP as the value of the
// X-Forwarded-For header.
//
// If an X-Forwarded-For header already exists, the client IP is
// appended to the existing values. As a special case, if the header
// exists in the Request.Header map but has a nil value (such as when
// set by the Director func), the X-Forwarded-For header is
// not modified.
//
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. It is handed a clone of the
	// incoming request, and ServeHTTP calls it
	// unconditionally, so it must be set.
	// The backend's response is then copied back to
	// the original client, minus hop-by-hop headers
	// and after any ModifyResponse hook has run.
	// Director must not access the provided Request
	// after returning.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	// A negative value means to flush immediately
	// after each write to the client.
	// The FlushInterval is ignored when ReverseProxy
	// recognizes a response as a streaming response, or
	// if its ContentLength is -1; for such responses, writes
	// are flushed to the client immediately.
	FlushInterval time.Duration

	// ErrorLog specifies an optional logger for errors
	// that occur when attempting to proxy the request.
	// If nil, logging is done via the log package's standard logger.
	ErrorLog *log.Logger

	// BufferPool optionally specifies a buffer pool to
	// get byte slices used as the copy buffer when
	// copying HTTP response bodies. If nil, a fresh
	// buffer is allocated for each response.
	BufferPool BufferPool

	// ModifyResponse is an optional function that modifies the
	// Response from the backend. It is called if the backend
	// returns a response at all, with any HTTP status code.
	// If the backend is unreachable, the optional ErrorHandler is
	// called without any call to ModifyResponse.
	//
	// If ModifyResponse returns an error, the response body is
	// closed and ErrorHandler is called with its error value.
	// If ErrorHandler is nil, its default implementation is used.
	ModifyResponse func(*http.Response) error

	// ErrorHandler is an optional function that handles errors
	// reaching the backend or errors from ModifyResponse.
	//
	// If nil, the default is to log the provided error and return
	// a 502 Status Bad Gateway response.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}

// A BufferPool is an interface for getting and returning temporary
// byte slices used as copy buffers for HTTP response bodies.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}

// singleJoiningSlash concatenates a and b with a single slash at the
// seam: when a ends with "/" and b starts with "/" one of the two is
// dropped, and when neither side supplies one a "/" is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		// Both sides contribute a slash; keep the one from a.
		return a + b[1:]
	}
	if !trailing && !leading {
		return a + "/" + b
	}
	return a + b
}

// joinURLPath joins the paths of a and b and returns the decoded path
// together with its escaped form. rawpath is empty when neither URL
// carries a RawPath.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	// At least one URL has an explicit escaped form, so the decision
	// about the joining slash is made on the escaped paths: an escaped
	// slash (%2F) at the seam must not be mistaken for a separator.
	escA, escB := a.EscapedPath(), b.EscapedPath()
	trailing := strings.HasSuffix(escA, "/")
	leading := strings.HasPrefix(escB, "/")

	if trailing && leading {
		return a.Path + b.Path[1:], escA + escB[1:]
	}
	if !trailing && !leading {
		return a.Path + "/" + b.Path, escA + "/" + escB
	}
	return a.Path + b.Path, escA + escB
}

// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// NewSingleHostReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
143 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { 144 targetQuery := target.RawQuery 145 director := func(req *http.Request) { 146 req.URL.Scheme = target.Scheme 147 req.URL.Host = target.Host 148 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) 149 if targetQuery == "" || req.URL.RawQuery == "" { 150 req.URL.RawQuery = targetQuery + req.URL.RawQuery 151 } else { 152 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 153 } 154 if _, ok := req.Header["User-Agent"]; !ok { 155 // explicitly disable User-Agent so it's not set to default value 156 req.Header.Set("User-Agent", "") 157 } 158 } 159 return &ReverseProxy{Director: director} 160 } 161 162 func copyHeader(dst, src http.Header) { 163 for k, vv := range src { 164 for _, v := range vv { 165 dst.Add(k, v) 166 } 167 } 168 } 169 170 // Hop-by-hop headers. These are removed when sent to the backend. 171 // As of RFC 7230, hop-by-hop headers are required to appear in the 172 // Connection header field. These are the headers defined by the 173 // obsoleted RFC 2616 (section 13.5.1) and are used for backward 174 // compatibility. 175 var hopHeaders = []string{ 176 "Connection", 177 "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. 
google 178 "Keep-Alive", 179 "Proxy-Authenticate", 180 "Proxy-Authorization", 181 "Te", // canonicalized version of "TE" 182 "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 183 "Transfer-Encoding", 184 "Upgrade", 185 } 186 187 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { 188 p.logf("http: proxy error: %v", err) 189 rw.WriteHeader(http.StatusBadGateway) 190 } 191 192 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { 193 if p.ErrorHandler != nil { 194 return p.ErrorHandler 195 } 196 return p.defaultErrorHandler 197 } 198 199 // modifyResponse conditionally runs the optional ModifyResponse hook 200 // and reports whether the request should proceed. 201 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { 202 if p.ModifyResponse == nil { 203 return true 204 } 205 if err := p.ModifyResponse(res); err != nil { 206 res.Body.Close() 207 p.getErrorHandler()(rw, req, err) 208 return false 209 } 210 return true 211 } 212 213 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 214 transport := p.Transport 215 if transport == nil { 216 transport = http.DefaultTransport 217 } 218 219 ctx := req.Context() 220 if cn, ok := rw.(http.CloseNotifier); ok { 221 var cancel context.CancelFunc 222 ctx, cancel = context.WithCancel(ctx) 223 defer cancel() 224 notifyChan := cn.CloseNotify() 225 go func() { 226 select { 227 case <-notifyChan: 228 cancel() 229 case <-ctx.Done(): 230 } 231 }() 232 } 233 234 outreq := req.Clone(ctx) 235 if req.ContentLength == 0 { 236 outreq.Body = nil // Issue 16036: nil Body for http.Transport retries 237 } 238 if outreq.Header == nil { 239 outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate 240 } 241 242 p.Director(outreq) 243 outreq.Close = false 244 245 reqUpType := upgradeType(outreq.Header) 246 
removeConnectionHeaders(outreq.Header) 247 248 // Remove hop-by-hop headers to the backend. Especially 249 // important is "Connection" because we want a persistent 250 // connection, regardless of what the client sent to us. 251 for _, h := range hopHeaders { 252 hv := outreq.Header.Get(h) 253 if hv == "" { 254 continue 255 } 256 if h == "Te" && hv == "trailers" { 257 // Issue 21096: tell backend applications that 258 // care about trailer support that we support 259 // trailers. (We do, but we don't go out of 260 // our way to advertise that unless the 261 // incoming client request thought it was 262 // worth mentioning) 263 continue 264 } 265 outreq.Header.Del(h) 266 } 267 268 // After stripping all the hop-by-hop connection headers above, add back any 269 // necessary for protocol upgrades, such as for websockets. 270 if reqUpType != "" { 271 outreq.Header.Set("Connection", "Upgrade") 272 outreq.Header.Set("Upgrade", reqUpType) 273 } 274 275 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 276 // If we aren't the first proxy retain prior 277 // X-Forwarded-For information as a comma+space 278 // separated list and fold multiple headers into one. 
279 prior, ok := outreq.Header["X-Forwarded-For"] 280 omit := ok && prior == nil // Issue 38079: nil now means don't populate the header 281 if len(prior) > 0 { 282 clientIP = strings.Join(prior, ", ") + ", " + clientIP 283 } 284 if !omit { 285 outreq.Header.Set("X-Forwarded-For", clientIP) 286 } 287 } 288 289 res, err := transport.RoundTrip(outreq) 290 if err != nil { 291 p.getErrorHandler()(rw, outreq, err) 292 return 293 } 294 295 // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) 296 if res.StatusCode == http.StatusSwitchingProtocols { 297 if !p.modifyResponse(rw, res, outreq) { 298 return 299 } 300 p.handleUpgradeResponse(rw, outreq, res) 301 return 302 } 303 304 removeConnectionHeaders(res.Header) 305 306 for _, h := range hopHeaders { 307 res.Header.Del(h) 308 } 309 310 if !p.modifyResponse(rw, res, outreq) { 311 return 312 } 313 314 copyHeader(rw.Header(), res.Header) 315 316 // The "Trailer" header isn't included in the Transport's response, 317 // at least for *http.Transport. Build it up from Trailer. 318 announcedTrailers := len(res.Trailer) 319 if announcedTrailers > 0 { 320 trailerKeys := make([]string, 0, len(res.Trailer)) 321 for k := range res.Trailer { 322 trailerKeys = append(trailerKeys, k) 323 } 324 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) 325 } 326 327 rw.WriteHeader(res.StatusCode) 328 329 err = p.copyResponse(rw, res.Body, p.flushInterval(res)) 330 if err != nil { 331 defer res.Body.Close() 332 // Since we're streaming the response, if we run into an error all we can do 333 // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler 334 // on read error while copying body. 
335 if !shouldPanicOnCopyError(req) { 336 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) 337 return 338 } 339 panic(http.ErrAbortHandler) 340 } 341 res.Body.Close() // close now, instead of defer, to populate res.Trailer 342 343 if len(res.Trailer) > 0 { 344 // Force chunking if we saw a response trailer. 345 // This prevents net/http from calculating the length for short 346 // bodies and adding a Content-Length. 347 if fl, ok := rw.(http.Flusher); ok { 348 fl.Flush() 349 } 350 } 351 352 if len(res.Trailer) == announcedTrailers { 353 copyHeader(rw.Header(), res.Trailer) 354 return 355 } 356 357 for k, vv := range res.Trailer { 358 k = http.TrailerPrefix + k 359 for _, v := range vv { 360 rw.Header().Add(k, v) 361 } 362 } 363 } 364 365 var inOurTests bool // whether we're in our own tests 366 367 // shouldPanicOnCopyError reports whether the reverse proxy should 368 // panic with http.ErrAbortHandler. This is the right thing to do by 369 // default, but Go 1.10 and earlier did not, so existing unit tests 370 // weren't expecting panics. Only panic in our own tests, or when 371 // running under the HTTP server. 372 func shouldPanicOnCopyError(req *http.Request) bool { 373 if inOurTests { 374 // Our tests know to handle this panic. 375 return true 376 } 377 if req.Context().Value(http.ServerContextKey) != nil { 378 // We seem to be running under an HTTP server, so 379 // it'll recover the panic. 380 return true 381 } 382 // Otherwise act like Go 1.10 and earlier to not break 383 // existing tests. 384 return false 385 } 386 387 // removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. 
388 // See RFC 7230, section 6.1 389 func removeConnectionHeaders(h http.Header) { 390 for _, f := range h["Connection"] { 391 for _, sf := range strings.Split(f, ",") { 392 if sf = textproto.TrimString(sf); sf != "" { 393 h.Del(sf) 394 } 395 } 396 } 397 } 398 399 // flushInterval returns the p.FlushInterval value, conditionally 400 // overriding its value for a specific request/response. 401 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { 402 resCT := res.Header.Get("Content-Type") 403 404 // For Server-Sent Events responses, flush immediately. 405 // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream 406 if resCT == "text/event-stream" { 407 return -1 // negative means immediately 408 } 409 410 // We might have the case of streaming for which Content-Length might be unset. 411 if res.ContentLength == -1 { 412 return -1 413 } 414 415 return p.FlushInterval 416 } 417 418 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { 419 if flushInterval != 0 { 420 if wf, ok := dst.(writeFlusher); ok { 421 mlw := &maxLatencyWriter{ 422 dst: wf, 423 latency: flushInterval, 424 } 425 defer mlw.stop() 426 427 // set up initial timer so headers get flushed even if body writes are delayed 428 mlw.flushPending = true 429 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) 430 431 dst = mlw 432 } 433 } 434 435 var buf []byte 436 if p.BufferPool != nil { 437 buf = p.BufferPool.Get() 438 defer p.BufferPool.Put(buf) 439 } 440 _, err := p.copyBuffer(dst, src, buf) 441 return err 442 } 443 444 // copyBuffer returns any write errors or non-EOF read errors, and the amount 445 // of bytes written. 
446 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { 447 if len(buf) == 0 { 448 buf = make([]byte, 32*1024) 449 } 450 var written int64 451 for { 452 nr, rerr := src.Read(buf) 453 if rerr != nil && rerr != io.EOF && rerr != context.Canceled { 454 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) 455 } 456 if nr > 0 { 457 nw, werr := dst.Write(buf[:nr]) 458 if nw > 0 { 459 written += int64(nw) 460 } 461 if werr != nil { 462 return written, werr 463 } 464 if nr != nw { 465 return written, io.ErrShortWrite 466 } 467 } 468 if rerr != nil { 469 if rerr == io.EOF { 470 rerr = nil 471 } 472 return written, rerr 473 } 474 } 475 } 476 477 func (p *ReverseProxy) logf(format string, args ...interface{}) { 478 if p.ErrorLog != nil { 479 p.ErrorLog.Printf(format, args...) 480 } else { 481 log.Printf(format, args...) 482 } 483 } 484 485 type writeFlusher interface { 486 io.Writer 487 http.Flusher 488 } 489 490 type maxLatencyWriter struct { 491 dst writeFlusher 492 latency time.Duration // non-zero; negative means to flush immediately 493 494 mu sync.Mutex // protects t, flushPending, and dst.Flush 495 t *time.Timer 496 flushPending bool 497 } 498 499 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { 500 m.mu.Lock() 501 defer m.mu.Unlock() 502 n, err = m.dst.Write(p) 503 if m.latency < 0 { 504 m.dst.Flush() 505 return 506 } 507 if m.flushPending { 508 return 509 } 510 if m.t == nil { 511 m.t = time.AfterFunc(m.latency, m.delayedFlush) 512 } else { 513 m.t.Reset(m.latency) 514 } 515 m.flushPending = true 516 return 517 } 518 519 func (m *maxLatencyWriter) delayedFlush() { 520 m.mu.Lock() 521 defer m.mu.Unlock() 522 if !m.flushPending { // if stop was called but AfterFunc already started this goroutine 523 return 524 } 525 m.dst.Flush() 526 m.flushPending = false 527 } 528 529 func (m *maxLatencyWriter) stop() { 530 m.mu.Lock() 531 defer m.mu.Unlock() 532 m.flushPending = false 533 if m.t != nil { 534 
m.t.Stop() 535 } 536 } 537 538 func upgradeType(h http.Header) string { 539 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { 540 return "" 541 } 542 return strings.ToLower(h.Get("Upgrade")) 543 } 544 545 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { 546 reqUpType := upgradeType(req.Header) 547 resUpType := upgradeType(res.Header) 548 if reqUpType != resUpType { 549 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) 550 return 551 } 552 553 hj, ok := rw.(http.Hijacker) 554 if !ok { 555 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) 556 return 557 } 558 backConn, ok := res.Body.(io.ReadWriteCloser) 559 if !ok { 560 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) 561 return 562 } 563 564 backConnCloseCh := make(chan bool) 565 go func() { 566 // Ensure that the cancelation of a request closes the backend. 567 // See issue https://golang.org/issue/35559. 
568 select { 569 case <-req.Context().Done(): 570 case <-backConnCloseCh: 571 } 572 backConn.Close() 573 }() 574 575 defer close(backConnCloseCh) 576 577 conn, brw, err := hj.Hijack() 578 if err != nil { 579 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) 580 return 581 } 582 defer conn.Close() 583 584 copyHeader(rw.Header(), res.Header) 585 586 res.Header = rw.Header() 587 res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above 588 if err := res.Write(brw); err != nil { 589 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) 590 return 591 } 592 if err := brw.Flush(); err != nil { 593 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) 594 return 595 } 596 errc := make(chan error, 1) 597 spc := switchProtocolCopier{user: conn, backend: backConn} 598 go spc.copyToBackend(errc) 599 go spc.copyFromBackend(errc) 600 <-errc 601 return 602 } 603 604 // switchProtocolCopier exists so goroutines proxying data back and 605 // forth have nice names in stacks. 606 type switchProtocolCopier struct { 607 user, backend io.ReadWriter 608 } 609 610 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { 611 _, err := io.Copy(c.user, c.backend) 612 errc <- err 613 } 614 615 func (c switchProtocolCopier) copyToBackend(errc chan<- error) { 616 _, err := io.Copy(c.backend, c.user) 617 errc <- err 618 }