// Source: github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/net/http/httputil/reverseproxy.go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// HTTP reverse proxy handler

package httputil

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/textproto"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/http/httpguts"
)

// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// ReverseProxy by default sets the client IP as the value of the
// X-Forwarded-For header.
//
// If an X-Forwarded-For header already exists, the client IP is
// appended to the existing values. As a special case, if the header
// exists in the Request.Header map but has a nil value (such as when
// set by the Director func), the X-Forwarded-For header is
// not modified.
//
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. Its response is then copied
	// back to the original client unmodified.
	// Director must not access the provided Request
	// after returning.
	//
	// Director must be non-nil: ServeHTTP calls it on every
	// request without a nil check.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	// A negative value means to flush immediately
	// after each write to the client.
	// The FlushInterval is ignored when ReverseProxy
	// recognizes a response as a streaming response, or
	// if its ContentLength is -1; for such responses, writes
	// are flushed to the client immediately.
	FlushInterval time.Duration

	// ErrorLog specifies an optional logger for errors
	// that occur when attempting to proxy the request.
	// If nil, logging is done via the log package's standard logger.
	ErrorLog *log.Logger

	// BufferPool optionally specifies a buffer pool to
	// get byte slices for use by io.CopyBuffer when
	// copying HTTP response bodies.
	// If nil, or if the pool hands back an empty slice, a fresh
	// 32 KiB buffer is allocated for each response body copy.
	BufferPool BufferPool

	// ModifyResponse is an optional function that modifies the
	// Response from the backend. It is called if the backend
	// returns a response at all, with any HTTP status code.
	// If the backend is unreachable, the optional ErrorHandler is
	// called without any call to ModifyResponse.
	//
	// If ModifyResponse returns an error, ErrorHandler is called
	// with its error value. If ErrorHandler is nil, its default
	// implementation is used.
	ModifyResponse func(*http.Response) error

	// ErrorHandler is an optional function that handles errors
	// reaching the backend or errors from ModifyResponse.
	//
	// If nil, the default is to log the provided error and return
	// a 502 Status Bad Gateway response.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}

// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
type BufferPool interface {
	// Get hands out a scratch byte slice.
	Get() []byte
	// Put takes back a slice previously obtained from Get.
	Put([]byte)
}

// singleJoiningSlash concatenates a and b so that exactly one "/"
// separates them: a doubled slash is collapsed and a missing one is
// inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		// Both sides contribute a slash; drop the one from b.
		return a + b[1:]
	}
	if trailing || leading {
		// Exactly one side contributes the slash already.
		return a + b
	}
	return a + "/" + b
}

// joinURLPath joins the paths of a and b with a single slash and
// returns both the decoded path and the escaped raw path. When neither
// URL carries a RawPath the raw result is empty.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	// Mirror singleJoiningSlash, but decide on the escaped forms so
	// that an escaped slash (%2F) is not mistaken for a separator.
	escA, escB := a.EscapedPath(), b.EscapedPath()
	trailing := strings.HasSuffix(escA, "/")
	leading := strings.HasPrefix(escB, "/")

	if trailing && leading {
		return a.Path + b.Path[1:], escA + escB[1:]
	}
	if !trailing && !leading {
		return a.Path + "/" + b.Path, escA + "/" + escB
	}
	return a.Path + b.Path, escA + escB
}

// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// NewSingleHostReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
142 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { 143 targetQuery := target.RawQuery 144 director := func(req *http.Request) { 145 req.URL.Scheme = target.Scheme 146 req.URL.Host = target.Host 147 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) 148 if targetQuery == "" || req.URL.RawQuery == "" { 149 req.URL.RawQuery = targetQuery + req.URL.RawQuery 150 } else { 151 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 152 } 153 if _, ok := req.Header["User-Agent"]; !ok { 154 // explicitly disable User-Agent so it's not set to default value 155 req.Header.Set("User-Agent", "") 156 } 157 } 158 return &ReverseProxy{Director: director} 159 } 160 161 func copyHeader(dst, src http.Header) { 162 for k, vv := range src { 163 for _, v := range vv { 164 dst.Add(k, v) 165 } 166 } 167 } 168 169 // Hop-by-hop headers. These are removed when sent to the backend. 170 // As of RFC 7230, hop-by-hop headers are required to appear in the 171 // Connection header field. These are the headers defined by the 172 // obsoleted RFC 2616 (section 13.5.1) and are used for backward 173 // compatibility. 174 var hopHeaders = []string{ 175 "Connection", 176 "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. 
google 177 "Keep-Alive", 178 "Proxy-Authenticate", 179 "Proxy-Authorization", 180 "Te", // canonicalized version of "TE" 181 "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 182 "Transfer-Encoding", 183 "Upgrade", 184 } 185 186 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { 187 p.logf("http: proxy error: %v", err) 188 rw.WriteHeader(http.StatusBadGateway) 189 } 190 191 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { 192 if p.ErrorHandler != nil { 193 return p.ErrorHandler 194 } 195 return p.defaultErrorHandler 196 } 197 198 // modifyResponse conditionally runs the optional ModifyResponse hook 199 // and reports whether the request should proceed. 200 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { 201 if p.ModifyResponse == nil { 202 return true 203 } 204 if err := p.ModifyResponse(res); err != nil { 205 res.Body.Close() 206 p.getErrorHandler()(rw, req, err) 207 return false 208 } 209 return true 210 } 211 212 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 213 transport := p.Transport 214 if transport == nil { 215 transport = http.DefaultTransport 216 } 217 218 ctx := req.Context() 219 if cn, ok := rw.(http.CloseNotifier); ok { 220 var cancel context.CancelFunc 221 ctx, cancel = context.WithCancel(ctx) 222 defer cancel() 223 notifyChan := cn.CloseNotify() 224 go func() { 225 select { 226 case <-notifyChan: 227 cancel() 228 case <-ctx.Done(): 229 } 230 }() 231 } 232 233 outreq := req.Clone(ctx) 234 if req.ContentLength == 0 { 235 outreq.Body = nil // Issue 16036: nil Body for http.Transport retries 236 } 237 if outreq.Header == nil { 238 outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate 239 } 240 241 p.Director(outreq) 242 outreq.Close = false 243 244 reqUpType := upgradeType(outreq.Header) 245 
removeConnectionHeaders(outreq.Header) 246 247 // Remove hop-by-hop headers to the backend. Especially 248 // important is "Connection" because we want a persistent 249 // connection, regardless of what the client sent to us. 250 for _, h := range hopHeaders { 251 hv := outreq.Header.Get(h) 252 if hv == "" { 253 continue 254 } 255 if h == "Te" && hv == "trailers" { 256 // Issue 21096: tell backend applications that 257 // care about trailer support that we support 258 // trailers. (We do, but we don't go out of 259 // our way to advertise that unless the 260 // incoming client request thought it was 261 // worth mentioning) 262 continue 263 } 264 outreq.Header.Del(h) 265 } 266 267 // After stripping all the hop-by-hop connection headers above, add back any 268 // necessary for protocol upgrades, such as for websockets. 269 if reqUpType != "" { 270 outreq.Header.Set("Connection", "Upgrade") 271 outreq.Header.Set("Upgrade", reqUpType) 272 } 273 274 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 275 // If we aren't the first proxy retain prior 276 // X-Forwarded-For information as a comma+space 277 // separated list and fold multiple headers into one. 
278 prior, ok := outreq.Header["X-Forwarded-For"] 279 omit := ok && prior == nil // Issue 38079: nil now means don't populate the header 280 if len(prior) > 0 { 281 clientIP = strings.Join(prior, ", ") + ", " + clientIP 282 } 283 if !omit { 284 outreq.Header.Set("X-Forwarded-For", clientIP) 285 } 286 } 287 288 res, err := transport.RoundTrip(outreq) 289 if err != nil { 290 p.getErrorHandler()(rw, outreq, err) 291 return 292 } 293 294 // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) 295 if res.StatusCode == http.StatusSwitchingProtocols { 296 if !p.modifyResponse(rw, res, outreq) { 297 return 298 } 299 p.handleUpgradeResponse(rw, outreq, res) 300 return 301 } 302 303 removeConnectionHeaders(res.Header) 304 305 for _, h := range hopHeaders { 306 res.Header.Del(h) 307 } 308 309 if !p.modifyResponse(rw, res, outreq) { 310 return 311 } 312 313 copyHeader(rw.Header(), res.Header) 314 315 // The "Trailer" header isn't included in the Transport's response, 316 // at least for *http.Transport. Build it up from Trailer. 317 announcedTrailers := len(res.Trailer) 318 if announcedTrailers > 0 { 319 trailerKeys := make([]string, 0, len(res.Trailer)) 320 for k := range res.Trailer { 321 trailerKeys = append(trailerKeys, k) 322 } 323 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) 324 } 325 326 rw.WriteHeader(res.StatusCode) 327 328 err = p.copyResponse(rw, res.Body, p.flushInterval(res)) 329 if err != nil { 330 defer res.Body.Close() 331 // Since we're streaming the response, if we run into an error all we can do 332 // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler 333 // on read error while copying body. 
334 if !shouldPanicOnCopyError(req) { 335 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) 336 return 337 } 338 panic(http.ErrAbortHandler) 339 } 340 res.Body.Close() // close now, instead of defer, to populate res.Trailer 341 342 if len(res.Trailer) > 0 { 343 // Force chunking if we saw a response trailer. 344 // This prevents net/http from calculating the length for short 345 // bodies and adding a Content-Length. 346 if fl, ok := rw.(http.Flusher); ok { 347 fl.Flush() 348 } 349 } 350 351 if len(res.Trailer) == announcedTrailers { 352 copyHeader(rw.Header(), res.Trailer) 353 return 354 } 355 356 for k, vv := range res.Trailer { 357 k = http.TrailerPrefix + k 358 for _, v := range vv { 359 rw.Header().Add(k, v) 360 } 361 } 362 } 363 364 var inOurTests bool // whether we're in our own tests 365 366 // shouldPanicOnCopyError reports whether the reverse proxy should 367 // panic with http.ErrAbortHandler. This is the right thing to do by 368 // default, but Go 1.10 and earlier did not, so existing unit tests 369 // weren't expecting panics. Only panic in our own tests, or when 370 // running under the HTTP server. 371 func shouldPanicOnCopyError(req *http.Request) bool { 372 if inOurTests { 373 // Our tests know to handle this panic. 374 return true 375 } 376 if req.Context().Value(http.ServerContextKey) != nil { 377 // We seem to be running under an HTTP server, so 378 // it'll recover the panic. 379 return true 380 } 381 // Otherwise act like Go 1.10 and earlier to not break 382 // existing tests. 383 return false 384 } 385 386 // removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. 
387 // See RFC 7230, section 6.1 388 func removeConnectionHeaders(h http.Header) { 389 for _, f := range h["Connection"] { 390 for _, sf := range strings.Split(f, ",") { 391 if sf = textproto.TrimString(sf); sf != "" { 392 h.Del(sf) 393 } 394 } 395 } 396 } 397 398 // flushInterval returns the p.FlushInterval value, conditionally 399 // overriding its value for a specific request/response. 400 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { 401 resCT := res.Header.Get("Content-Type") 402 403 // For Server-Sent Events responses, flush immediately. 404 // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream 405 if resCT == "text/event-stream" { 406 return -1 // negative means immediately 407 } 408 409 // We might have the case of streaming for which Content-Length might be unset. 410 if res.ContentLength == -1 { 411 return -1 412 } 413 414 return p.FlushInterval 415 } 416 417 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { 418 if flushInterval != 0 { 419 if wf, ok := dst.(writeFlusher); ok { 420 mlw := &maxLatencyWriter{ 421 dst: wf, 422 latency: flushInterval, 423 } 424 defer mlw.stop() 425 426 // set up initial timer so headers get flushed even if body writes are delayed 427 mlw.flushPending = true 428 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) 429 430 dst = mlw 431 } 432 } 433 434 var buf []byte 435 if p.BufferPool != nil { 436 buf = p.BufferPool.Get() 437 defer p.BufferPool.Put(buf) 438 } 439 _, err := p.copyBuffer(dst, src, buf) 440 return err 441 } 442 443 // copyBuffer returns any write errors or non-EOF read errors, and the amount 444 // of bytes written. 
445 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { 446 if len(buf) == 0 { 447 buf = make([]byte, 32*1024) 448 } 449 var written int64 450 for { 451 nr, rerr := src.Read(buf) 452 if rerr != nil && rerr != io.EOF && rerr != context.Canceled { 453 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) 454 } 455 if nr > 0 { 456 nw, werr := dst.Write(buf[:nr]) 457 if nw > 0 { 458 written += int64(nw) 459 } 460 if werr != nil { 461 return written, werr 462 } 463 if nr != nw { 464 return written, io.ErrShortWrite 465 } 466 } 467 if rerr != nil { 468 if rerr == io.EOF { 469 rerr = nil 470 } 471 return written, rerr 472 } 473 } 474 } 475 476 func (p *ReverseProxy) logf(format string, args ...interface{}) { 477 if p.ErrorLog != nil { 478 p.ErrorLog.Printf(format, args...) 479 } else { 480 log.Printf(format, args...) 481 } 482 } 483 484 type writeFlusher interface { 485 io.Writer 486 http.Flusher 487 } 488 489 type maxLatencyWriter struct { 490 dst writeFlusher 491 latency time.Duration // non-zero; negative means to flush immediately 492 493 mu sync.Mutex // protects t, flushPending, and dst.Flush 494 t *time.Timer 495 flushPending bool 496 } 497 498 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { 499 m.mu.Lock() 500 defer m.mu.Unlock() 501 n, err = m.dst.Write(p) 502 if m.latency < 0 { 503 m.dst.Flush() 504 return 505 } 506 if m.flushPending { 507 return 508 } 509 if m.t == nil { 510 m.t = time.AfterFunc(m.latency, m.delayedFlush) 511 } else { 512 m.t.Reset(m.latency) 513 } 514 m.flushPending = true 515 return 516 } 517 518 func (m *maxLatencyWriter) delayedFlush() { 519 m.mu.Lock() 520 defer m.mu.Unlock() 521 if !m.flushPending { // if stop was called but AfterFunc already started this goroutine 522 return 523 } 524 m.dst.Flush() 525 m.flushPending = false 526 } 527 528 func (m *maxLatencyWriter) stop() { 529 m.mu.Lock() 530 defer m.mu.Unlock() 531 m.flushPending = false 532 if m.t != nil { 533 
m.t.Stop() 534 } 535 } 536 537 func upgradeType(h http.Header) string { 538 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { 539 return "" 540 } 541 return strings.ToLower(h.Get("Upgrade")) 542 } 543 544 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { 545 reqUpType := upgradeType(req.Header) 546 resUpType := upgradeType(res.Header) 547 if reqUpType != resUpType { 548 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) 549 return 550 } 551 552 hj, ok := rw.(http.Hijacker) 553 if !ok { 554 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) 555 return 556 } 557 backConn, ok := res.Body.(io.ReadWriteCloser) 558 if !ok { 559 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) 560 return 561 } 562 563 backConnCloseCh := make(chan bool) 564 go func() { 565 // Ensure that the cancelation of a request closes the backend. 566 // See issue https://golang.org/issue/35559. 
567 select { 568 case <-req.Context().Done(): 569 case <-backConnCloseCh: 570 } 571 backConn.Close() 572 }() 573 574 defer close(backConnCloseCh) 575 576 conn, brw, err := hj.Hijack() 577 if err != nil { 578 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) 579 return 580 } 581 defer conn.Close() 582 583 copyHeader(rw.Header(), res.Header) 584 585 res.Header = rw.Header() 586 res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above 587 if err := res.Write(brw); err != nil { 588 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) 589 return 590 } 591 if err := brw.Flush(); err != nil { 592 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) 593 return 594 } 595 errc := make(chan error, 1) 596 spc := switchProtocolCopier{user: conn, backend: backConn} 597 go spc.copyToBackend(errc) 598 go spc.copyFromBackend(errc) 599 <-errc 600 return 601 } 602 603 // switchProtocolCopier exists so goroutines proxying data back and 604 // forth have nice names in stacks. 605 type switchProtocolCopier struct { 606 user, backend io.ReadWriter 607 } 608 609 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { 610 _, err := io.Copy(c.user, c.backend) 611 errc <- err 612 } 613 614 func (c switchProtocolCopier) copyToBackend(errc chan<- error) { 615 _, err := io.Copy(c.backend, c.user) 616 errc <- err 617 }