github.com/code-reading/golang@v0.0.0-20220303082512-ba5bc0e589a3/go/src/net/http/httputil/reverseproxy.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // HTTP reverse proxy handler 6 7 package httputil 8 9 import ( 10 "context" 11 "fmt" 12 "io" 13 "log" 14 "net" 15 "net/http" 16 "net/http/internal/ascii" 17 "net/textproto" 18 "net/url" 19 "strings" 20 "sync" 21 "time" 22 23 "golang.org/x/net/http/httpguts" 24 ) 25 26 // ReverseProxy is an HTTP Handler that takes an incoming request and 27 // sends it to another server, proxying the response back to the 28 // client. 29 // 30 // ReverseProxy by default sets the client IP as the value of the 31 // X-Forwarded-For header. 32 // 33 // If an X-Forwarded-For header already exists, the client IP is 34 // appended to the existing values. As a special case, if the header 35 // exists in the Request.Header map but has a nil value (such as when 36 // set by the Director func), the X-Forwarded-For header is 37 // not modified. 38 // 39 // To prevent IP spoofing, be sure to delete any pre-existing 40 // X-Forwarded-For header coming from the client or 41 // an untrusted proxy. 42 type ReverseProxy struct { 43 // Director must be a function which modifies 44 // the request into a new request to be sent 45 // using Transport. Its response is then copied 46 // back to the original client unmodified. 47 // Director must not access the provided Request 48 // after returning. 49 Director func(*http.Request) 50 51 // The transport used to perform proxy requests. 52 // If nil, http.DefaultTransport is used. 53 Transport http.RoundTripper 54 55 // FlushInterval specifies the flush interval 56 // to flush to the client while copying the 57 // response body. 58 // If zero, no periodic flushing is done. 59 // A negative value means to flush immediately 60 // after each write to the client. 
61 // The FlushInterval is ignored when ReverseProxy 62 // recognizes a response as a streaming response, or 63 // if its ContentLength is -1; for such responses, writes 64 // are flushed to the client immediately. 65 FlushInterval time.Duration 66 67 // ErrorLog specifies an optional logger for errors 68 // that occur when attempting to proxy the request. 69 // If nil, logging is done via the log package's standard logger. 70 ErrorLog *log.Logger 71 72 // BufferPool optionally specifies a buffer pool to 73 // get byte slices for use by io.CopyBuffer when 74 // copying HTTP response bodies. 75 BufferPool BufferPool 76 77 // ModifyResponse is an optional function that modifies the 78 // Response from the backend. It is called if the backend 79 // returns a response at all, with any HTTP status code. 80 // If the backend is unreachable, the optional ErrorHandler is 81 // called without any call to ModifyResponse. 82 // 83 // If ModifyResponse returns an error, ErrorHandler is called 84 // with its error value. If ErrorHandler is nil, its default 85 // implementation is used. 86 ModifyResponse func(*http.Response) error 87 88 // ErrorHandler is an optional function that handles errors 89 // reaching the backend or errors from ModifyResponse. 90 // 91 // If nil, the default is to log the provided error and return 92 // a 502 Status Bad Gateway response. 93 ErrorHandler func(http.ResponseWriter, *http.Request, error) 94 } 95 96 // A BufferPool is an interface for getting and returning temporary 97 // byte slices for use by io.CopyBuffer. 
type BufferPool interface {
	// Get returns a byte slice to use as a copy buffer.
	Get() []byte
	// Put hands a slice obtained from Get back to the pool.
	Put([]byte)
}

// singleJoiningSlash concatenates a and b so that exactly one "/"
// separates them: a duplicate slash at the seam is collapsed and a
// missing one is inserted.
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}

// joinURLPath joins the paths of a and b with a single slash and
// returns both the decoded path and the escaped (raw) path.
//
// When neither URL has a RawPath the decoded paths are joined with
// singleJoiningSlash and rawpath is returned empty.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}
	// Same as singleJoiningSlash, but uses EscapedPath to determine
	// whether a slash should be added
	apath := a.EscapedPath()
	bpath := b.EscapedPath()

	aslash := strings.HasSuffix(apath, "/")
	bslash := strings.HasPrefix(bpath, "/")

	switch {
	case aslash && bslash:
		return a.Path + b.Path[1:], apath + bpath[1:]
	case !aslash && !bslash:
		return a.Path + "/" + b.Path, apath + "/" + bpath
	}
	return a.Path + b.Path, apath + bpath
}

// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// NewSingleHostReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
143 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { 144 targetQuery := target.RawQuery 145 director := func(req *http.Request) { 146 req.URL.Scheme = target.Scheme 147 req.URL.Host = target.Host 148 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) 149 if targetQuery == "" || req.URL.RawQuery == "" { 150 req.URL.RawQuery = targetQuery + req.URL.RawQuery 151 } else { 152 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 153 } 154 if _, ok := req.Header["User-Agent"]; !ok { 155 // explicitly disable User-Agent so it's not set to default value 156 req.Header.Set("User-Agent", "") 157 } 158 } 159 return &ReverseProxy{Director: director} 160 } 161 162 func copyHeader(dst, src http.Header) { 163 for k, vv := range src { 164 for _, v := range vv { 165 dst.Add(k, v) 166 } 167 } 168 } 169 170 // Hop-by-hop headers. These are removed when sent to the backend. 171 // As of RFC 7230, hop-by-hop headers are required to appear in the 172 // Connection header field. These are the headers defined by the 173 // obsoleted RFC 2616 (section 13.5.1) and are used for backward 174 // compatibility. 175 var hopHeaders = []string{ 176 "Connection", 177 "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. 
google 178 "Keep-Alive", 179 "Proxy-Authenticate", 180 "Proxy-Authorization", 181 "Te", // canonicalized version of "TE" 182 "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 183 "Transfer-Encoding", 184 "Upgrade", 185 } 186 187 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { 188 p.logf("http: proxy error: %v", err) 189 rw.WriteHeader(http.StatusBadGateway) 190 } 191 192 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { 193 if p.ErrorHandler != nil { 194 return p.ErrorHandler 195 } 196 return p.defaultErrorHandler 197 } 198 199 // modifyResponse conditionally runs the optional ModifyResponse hook 200 // and reports whether the request should proceed. 201 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { 202 if p.ModifyResponse == nil { 203 return true 204 } 205 if err := p.ModifyResponse(res); err != nil { 206 res.Body.Close() 207 p.getErrorHandler()(rw, req, err) 208 return false 209 } 210 return true 211 } 212 213 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 214 transport := p.Transport 215 if transport == nil { 216 transport = http.DefaultTransport 217 } 218 219 ctx := req.Context() 220 if cn, ok := rw.(http.CloseNotifier); ok { 221 var cancel context.CancelFunc 222 ctx, cancel = context.WithCancel(ctx) 223 defer cancel() 224 notifyChan := cn.CloseNotify() 225 go func() { 226 select { 227 case <-notifyChan: 228 cancel() 229 case <-ctx.Done(): 230 } 231 }() 232 } 233 234 outreq := req.Clone(ctx) 235 if req.ContentLength == 0 { 236 outreq.Body = nil // Issue 16036: nil Body for http.Transport retries 237 } 238 if outreq.Body != nil { 239 // Reading from the request body after returning from a handler is not 240 // allowed, and the RoundTrip goroutine that reads the Body can outlive 241 // this handler. 
This can lead to a crash if the handler panics (see 242 // Issue 46866). Although calling Close doesn't guarantee there isn't 243 // any Read in flight after the handle returns, in practice it's safe to 244 // read after closing it. 245 defer outreq.Body.Close() 246 } 247 if outreq.Header == nil { 248 outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate 249 } 250 251 p.Director(outreq) 252 outreq.Close = false 253 254 reqUpType := upgradeType(outreq.Header) 255 if !ascii.IsPrint(reqUpType) { 256 p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) 257 return 258 } 259 removeConnectionHeaders(outreq.Header) 260 261 // Remove hop-by-hop headers to the backend. Especially 262 // important is "Connection" because we want a persistent 263 // connection, regardless of what the client sent to us. 264 for _, h := range hopHeaders { 265 outreq.Header.Del(h) 266 } 267 268 // Issue 21096: tell backend applications that care about trailer support 269 // that we support trailers. (We do, but we don't go out of our way to 270 // advertise that unless the incoming client request thought it was worth 271 // mentioning.) Note that we look at req.Header, not outreq.Header, since 272 // the latter has passed through removeConnectionHeaders. 273 if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { 274 outreq.Header.Set("Te", "trailers") 275 } 276 277 // After stripping all the hop-by-hop connection headers above, add back any 278 // necessary for protocol upgrades, such as for websockets. 279 if reqUpType != "" { 280 outreq.Header.Set("Connection", "Upgrade") 281 outreq.Header.Set("Upgrade", reqUpType) 282 } 283 284 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 285 // If we aren't the first proxy retain prior 286 // X-Forwarded-For information as a comma+space 287 // separated list and fold multiple headers into one. 
288 prior, ok := outreq.Header["X-Forwarded-For"] 289 omit := ok && prior == nil // Issue 38079: nil now means don't populate the header 290 if len(prior) > 0 { 291 clientIP = strings.Join(prior, ", ") + ", " + clientIP 292 } 293 if !omit { 294 outreq.Header.Set("X-Forwarded-For", clientIP) 295 } 296 } 297 298 res, err := transport.RoundTrip(outreq) 299 if err != nil { 300 p.getErrorHandler()(rw, outreq, err) 301 return 302 } 303 304 // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) 305 if res.StatusCode == http.StatusSwitchingProtocols { 306 if !p.modifyResponse(rw, res, outreq) { 307 return 308 } 309 p.handleUpgradeResponse(rw, outreq, res) 310 return 311 } 312 313 removeConnectionHeaders(res.Header) 314 315 for _, h := range hopHeaders { 316 res.Header.Del(h) 317 } 318 319 if !p.modifyResponse(rw, res, outreq) { 320 return 321 } 322 323 copyHeader(rw.Header(), res.Header) 324 325 // The "Trailer" header isn't included in the Transport's response, 326 // at least for *http.Transport. Build it up from Trailer. 327 announcedTrailers := len(res.Trailer) 328 if announcedTrailers > 0 { 329 trailerKeys := make([]string, 0, len(res.Trailer)) 330 for k := range res.Trailer { 331 trailerKeys = append(trailerKeys, k) 332 } 333 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) 334 } 335 336 rw.WriteHeader(res.StatusCode) 337 338 err = p.copyResponse(rw, res.Body, p.flushInterval(res)) 339 if err != nil { 340 defer res.Body.Close() 341 // Since we're streaming the response, if we run into an error all we can do 342 // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler 343 // on read error while copying body. 
344 if !shouldPanicOnCopyError(req) { 345 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) 346 return 347 } 348 panic(http.ErrAbortHandler) 349 } 350 res.Body.Close() // close now, instead of defer, to populate res.Trailer 351 352 if len(res.Trailer) > 0 { 353 // Force chunking if we saw a response trailer. 354 // This prevents net/http from calculating the length for short 355 // bodies and adding a Content-Length. 356 if fl, ok := rw.(http.Flusher); ok { 357 fl.Flush() 358 } 359 } 360 361 if len(res.Trailer) == announcedTrailers { 362 copyHeader(rw.Header(), res.Trailer) 363 return 364 } 365 366 for k, vv := range res.Trailer { 367 k = http.TrailerPrefix + k 368 for _, v := range vv { 369 rw.Header().Add(k, v) 370 } 371 } 372 } 373 374 var inOurTests bool // whether we're in our own tests 375 376 // shouldPanicOnCopyError reports whether the reverse proxy should 377 // panic with http.ErrAbortHandler. This is the right thing to do by 378 // default, but Go 1.10 and earlier did not, so existing unit tests 379 // weren't expecting panics. Only panic in our own tests, or when 380 // running under the HTTP server. 381 func shouldPanicOnCopyError(req *http.Request) bool { 382 if inOurTests { 383 // Our tests know to handle this panic. 384 return true 385 } 386 if req.Context().Value(http.ServerContextKey) != nil { 387 // We seem to be running under an HTTP server, so 388 // it'll recover the panic. 389 return true 390 } 391 // Otherwise act like Go 1.10 and earlier to not break 392 // existing tests. 393 return false 394 } 395 396 // removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. 
397 // See RFC 7230, section 6.1 398 func removeConnectionHeaders(h http.Header) { 399 for _, f := range h["Connection"] { 400 for _, sf := range strings.Split(f, ",") { 401 if sf = textproto.TrimString(sf); sf != "" { 402 h.Del(sf) 403 } 404 } 405 } 406 } 407 408 // flushInterval returns the p.FlushInterval value, conditionally 409 // overriding its value for a specific request/response. 410 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { 411 resCT := res.Header.Get("Content-Type") 412 413 // For Server-Sent Events responses, flush immediately. 414 // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream 415 if resCT == "text/event-stream" { 416 return -1 // negative means immediately 417 } 418 419 // We might have the case of streaming for which Content-Length might be unset. 420 if res.ContentLength == -1 { 421 return -1 422 } 423 424 return p.FlushInterval 425 } 426 427 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { 428 if flushInterval != 0 { 429 if wf, ok := dst.(writeFlusher); ok { 430 mlw := &maxLatencyWriter{ 431 dst: wf, 432 latency: flushInterval, 433 } 434 defer mlw.stop() 435 436 // set up initial timer so headers get flushed even if body writes are delayed 437 mlw.flushPending = true 438 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) 439 440 dst = mlw 441 } 442 } 443 444 var buf []byte 445 if p.BufferPool != nil { 446 buf = p.BufferPool.Get() 447 defer p.BufferPool.Put(buf) 448 } 449 _, err := p.copyBuffer(dst, src, buf) 450 return err 451 } 452 453 // copyBuffer returns any write errors or non-EOF read errors, and the amount 454 // of bytes written. 
455 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { 456 if len(buf) == 0 { 457 buf = make([]byte, 32*1024) 458 } 459 var written int64 460 for { 461 nr, rerr := src.Read(buf) 462 if rerr != nil && rerr != io.EOF && rerr != context.Canceled { 463 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) 464 } 465 if nr > 0 { 466 nw, werr := dst.Write(buf[:nr]) 467 if nw > 0 { 468 written += int64(nw) 469 } 470 if werr != nil { 471 return written, werr 472 } 473 if nr != nw { 474 return written, io.ErrShortWrite 475 } 476 } 477 if rerr != nil { 478 if rerr == io.EOF { 479 rerr = nil 480 } 481 return written, rerr 482 } 483 } 484 } 485 486 func (p *ReverseProxy) logf(format string, args ...interface{}) { 487 if p.ErrorLog != nil { 488 p.ErrorLog.Printf(format, args...) 489 } else { 490 log.Printf(format, args...) 491 } 492 } 493 494 type writeFlusher interface { 495 io.Writer 496 http.Flusher 497 } 498 499 type maxLatencyWriter struct { 500 dst writeFlusher 501 latency time.Duration // non-zero; negative means to flush immediately 502 503 mu sync.Mutex // protects t, flushPending, and dst.Flush 504 t *time.Timer 505 flushPending bool 506 } 507 508 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { 509 m.mu.Lock() 510 defer m.mu.Unlock() 511 n, err = m.dst.Write(p) 512 if m.latency < 0 { 513 m.dst.Flush() 514 return 515 } 516 if m.flushPending { 517 return 518 } 519 if m.t == nil { 520 m.t = time.AfterFunc(m.latency, m.delayedFlush) 521 } else { 522 m.t.Reset(m.latency) 523 } 524 m.flushPending = true 525 return 526 } 527 528 func (m *maxLatencyWriter) delayedFlush() { 529 m.mu.Lock() 530 defer m.mu.Unlock() 531 if !m.flushPending { // if stop was called but AfterFunc already started this goroutine 532 return 533 } 534 m.dst.Flush() 535 m.flushPending = false 536 } 537 538 func (m *maxLatencyWriter) stop() { 539 m.mu.Lock() 540 defer m.mu.Unlock() 541 m.flushPending = false 542 if m.t != nil { 543 
m.t.Stop() 544 } 545 } 546 547 func upgradeType(h http.Header) string { 548 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { 549 return "" 550 } 551 return h.Get("Upgrade") 552 } 553 554 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { 555 reqUpType := upgradeType(req.Header) 556 resUpType := upgradeType(res.Header) 557 if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. 558 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) 559 } 560 if !ascii.EqualFold(reqUpType, resUpType) { 561 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) 562 return 563 } 564 565 hj, ok := rw.(http.Hijacker) 566 if !ok { 567 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) 568 return 569 } 570 backConn, ok := res.Body.(io.ReadWriteCloser) 571 if !ok { 572 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) 573 return 574 } 575 576 backConnCloseCh := make(chan bool) 577 go func() { 578 // Ensure that the cancellation of a request closes the backend. 579 // See issue https://golang.org/issue/35559. 
580 select { 581 case <-req.Context().Done(): 582 case <-backConnCloseCh: 583 } 584 backConn.Close() 585 }() 586 587 defer close(backConnCloseCh) 588 589 conn, brw, err := hj.Hijack() 590 if err != nil { 591 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) 592 return 593 } 594 defer conn.Close() 595 596 copyHeader(rw.Header(), res.Header) 597 598 res.Header = rw.Header() 599 res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above 600 if err := res.Write(brw); err != nil { 601 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) 602 return 603 } 604 if err := brw.Flush(); err != nil { 605 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) 606 return 607 } 608 errc := make(chan error, 1) 609 spc := switchProtocolCopier{user: conn, backend: backConn} 610 go spc.copyToBackend(errc) 611 go spc.copyFromBackend(errc) 612 <-errc 613 return 614 } 615 616 // switchProtocolCopier exists so goroutines proxying data back and 617 // forth have nice names in stacks. 618 type switchProtocolCopier struct { 619 user, backend io.ReadWriter 620 } 621 622 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { 623 _, err := io.Copy(c.user, c.backend) 624 errc <- err 625 } 626 627 func (c switchProtocolCopier) copyToBackend(errc chan<- error) { 628 _, err := io.Copy(c.backend, c.user) 629 errc <- err 630 }