github.com/megatontech/mynoteforgo@v0.0.0-20200507084910-5d0c6ea6e890/源码/net/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)
    28  
    29  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    30  
    31  func init() {
    32  	inOurTests = true
    33  	hopHeaders = append(hopHeaders, fakeHopHeader)
    34  }
    35  
    36  func TestReverseProxy(t *testing.T) {
    37  	const backendResponse = "I am the backend"
    38  	const backendStatus = 404
    39  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    40  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    41  			c, _, _ := w.(http.Hijacker).Hijack()
    42  			c.Close()
    43  			return
    44  		}
    45  		if len(r.TransferEncoding) > 0 {
    46  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    47  		}
    48  		if r.Header.Get("X-Forwarded-For") == "" {
    49  			t.Errorf("didn't get X-Forwarded-For header")
    50  		}
    51  		if c := r.Header.Get("Connection"); c != "" {
    52  			t.Errorf("handler got Connection header value %q", c)
    53  		}
    54  		if c := r.Header.Get("Te"); c != "trailers" {
    55  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    56  		}
    57  		if c := r.Header.Get("Upgrade"); c != "" {
    58  			t.Errorf("handler got Upgrade header value %q", c)
    59  		}
    60  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    61  			t.Errorf("handler got Proxy-Connection header value %q", c)
    62  		}
    63  		if g, e := r.Host, "some-name"; g != e {
    64  			t.Errorf("backend got Host header %q, want %q", g, e)
    65  		}
    66  		w.Header().Set("Trailers", "not a special header field name")
    67  		w.Header().Set("Trailer", "X-Trailer")
    68  		w.Header().Set("X-Foo", "bar")
    69  		w.Header().Set("Upgrade", "foo")
    70  		w.Header().Set(fakeHopHeader, "foo")
    71  		w.Header().Add("X-Multi-Value", "foo")
    72  		w.Header().Add("X-Multi-Value", "bar")
    73  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    74  		w.WriteHeader(backendStatus)
    75  		w.Write([]byte(backendResponse))
    76  		w.Header().Set("X-Trailer", "trailer_value")
    77  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    78  	}))
    79  	defer backend.Close()
    80  	backendURL, err := url.Parse(backend.URL)
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    85  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
    86  	frontend := httptest.NewServer(proxyHandler)
    87  	defer frontend.Close()
    88  	frontendClient := frontend.Client()
    89  
    90  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    91  	getReq.Host = "some-name"
    92  	getReq.Header.Set("Connection", "close")
    93  	getReq.Header.Set("Te", "trailers")
    94  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    95  	getReq.Header.Set("Upgrade", "foo")
    96  	getReq.Close = true
    97  	res, err := frontendClient.Do(getReq)
    98  	if err != nil {
    99  		t.Fatalf("Get: %v", err)
   100  	}
   101  	if g, e := res.StatusCode, backendStatus; g != e {
   102  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   103  	}
   104  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   105  		t.Errorf("got X-Foo %q; expected %q", g, e)
   106  	}
   107  	if c := res.Header.Get(fakeHopHeader); c != "" {
   108  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   109  	}
   110  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   111  		t.Errorf("header Trailers = %q; want %q", g, e)
   112  	}
   113  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   114  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   115  	}
   116  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   117  		t.Fatalf("got %d SetCookies, want %d", g, e)
   118  	}
   119  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   120  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   121  	}
   122  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   123  		t.Errorf("unexpected cookie %q", cookie.Name)
   124  	}
   125  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   126  	if g, e := string(bodyBytes), backendResponse; g != e {
   127  		t.Errorf("got body %q; expected %q", g, e)
   128  	}
   129  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   130  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   131  	}
   132  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   133  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   134  	}
   135  
   136  	// Test that a backend failing to be reached or one which doesn't return
   137  	// a response results in a StatusBadGateway.
   138  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   139  	getReq.Close = true
   140  	res, err = frontendClient.Do(getReq)
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  	res.Body.Close()
   145  	if res.StatusCode != http.StatusBadGateway {
   146  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   147  	}
   148  
   149  }
   150  
   151  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   152  // header value.
   153  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   154  	const fakeConnectionToken = "X-Fake-Connection-Token"
   155  	const backendResponse = "I am the backend"
   156  
   157  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   158  	// in the Request's Connection header.
   159  	const someConnHeader = "X-Some-Conn-Header"
   160  
   161  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   162  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   163  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   164  		}
   165  		if c := r.Header.Get(someConnHeader); c != "" {
   166  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   167  		}
   168  		w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
   169  		w.Header().Set(someConnHeader, "should be deleted")
   170  		w.Header().Set(fakeConnectionToken, "should be deleted")
   171  		io.WriteString(w, backendResponse)
   172  	}))
   173  	defer backend.Close()
   174  	backendURL, err := url.Parse(backend.URL)
   175  	if err != nil {
   176  		t.Fatal(err)
   177  	}
   178  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   179  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   180  		proxyHandler.ServeHTTP(w, r)
   181  		if c := r.Header.Get(someConnHeader); c != "original value" {
   182  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
   183  		}
   184  	}))
   185  	defer frontend.Close()
   186  
   187  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   188  	getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
   189  	getReq.Header.Set(someConnHeader, "original value")
   190  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   191  	res, err := frontend.Client().Do(getReq)
   192  	if err != nil {
   193  		t.Fatalf("Get: %v", err)
   194  	}
   195  	defer res.Body.Close()
   196  	bodyBytes, err := ioutil.ReadAll(res.Body)
   197  	if err != nil {
   198  		t.Fatalf("reading body: %v", err)
   199  	}
   200  	if got, want := string(bodyBytes), backendResponse; got != want {
   201  		t.Errorf("got body %q; want %q", got, want)
   202  	}
   203  	if c := res.Header.Get(someConnHeader); c != "" {
   204  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   205  	}
   206  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   207  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   208  	}
   209  }
   210  
   211  func TestXForwardedFor(t *testing.T) {
   212  	const prevForwardedFor = "client ip"
   213  	const backendResponse = "I am the backend"
   214  	const backendStatus = 404
   215  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   216  		if r.Header.Get("X-Forwarded-For") == "" {
   217  			t.Errorf("didn't get X-Forwarded-For header")
   218  		}
   219  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   220  			t.Errorf("X-Forwarded-For didn't contain prior data")
   221  		}
   222  		w.WriteHeader(backendStatus)
   223  		w.Write([]byte(backendResponse))
   224  	}))
   225  	defer backend.Close()
   226  	backendURL, err := url.Parse(backend.URL)
   227  	if err != nil {
   228  		t.Fatal(err)
   229  	}
   230  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   231  	frontend := httptest.NewServer(proxyHandler)
   232  	defer frontend.Close()
   233  
   234  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   235  	getReq.Host = "some-name"
   236  	getReq.Header.Set("Connection", "close")
   237  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   238  	getReq.Close = true
   239  	res, err := frontend.Client().Do(getReq)
   240  	if err != nil {
   241  		t.Fatalf("Get: %v", err)
   242  	}
   243  	if g, e := res.StatusCode, backendStatus; g != e {
   244  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   245  	}
   246  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   247  	if g, e := string(bodyBytes), backendResponse; g != e {
   248  		t.Errorf("got body %q; expected %q", g, e)
   249  	}
   250  }
   251  
   252  var proxyQueryTests = []struct {
   253  	baseSuffix string // suffix to add to backend URL
   254  	reqSuffix  string // suffix to add to frontend's request URL
   255  	want       string // what backend should see for final request URL (without ?)
   256  }{
   257  	{"", "", ""},
   258  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   259  	{"", "?us=er", "us=er"},
   260  	{"?sta=tic", "", "sta=tic"},
   261  }
   262  
   263  func TestReverseProxyQuery(t *testing.T) {
   264  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   265  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   266  		w.Write([]byte("hi"))
   267  	}))
   268  	defer backend.Close()
   269  
   270  	for i, tt := range proxyQueryTests {
   271  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   272  		if err != nil {
   273  			t.Fatal(err)
   274  		}
   275  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   276  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   277  		req.Close = true
   278  		res, err := frontend.Client().Do(req)
   279  		if err != nil {
   280  			t.Fatalf("%d. Get: %v", i, err)
   281  		}
   282  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   283  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   284  		}
   285  		res.Body.Close()
   286  		frontend.Close()
   287  	}
   288  }
   289  
   290  func TestReverseProxyFlushInterval(t *testing.T) {
   291  	const expected = "hi"
   292  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   293  		w.Write([]byte(expected))
   294  	}))
   295  	defer backend.Close()
   296  
   297  	backendURL, err := url.Parse(backend.URL)
   298  	if err != nil {
   299  		t.Fatal(err)
   300  	}
   301  
   302  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   303  	proxyHandler.FlushInterval = time.Microsecond
   304  
   305  	frontend := httptest.NewServer(proxyHandler)
   306  	defer frontend.Close()
   307  
   308  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   309  	req.Close = true
   310  	res, err := frontend.Client().Do(req)
   311  	if err != nil {
   312  		t.Fatalf("Get: %v", err)
   313  	}
   314  	defer res.Body.Close()
   315  	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
   316  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   317  	}
   318  }
   319  
   320  func TestReverseProxyCancelation(t *testing.T) {
   321  	const backendResponse = "I am the backend"
   322  
   323  	reqInFlight := make(chan struct{})
   324  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   325  		close(reqInFlight) // cause the client to cancel its request
   326  
   327  		select {
   328  		case <-time.After(10 * time.Second):
   329  			// Note: this should only happen in broken implementations, and the
   330  			// closenotify case should be instantaneous.
   331  			t.Error("Handler never saw CloseNotify")
   332  			return
   333  		case <-w.(http.CloseNotifier).CloseNotify():
   334  		}
   335  
   336  		w.WriteHeader(http.StatusOK)
   337  		w.Write([]byte(backendResponse))
   338  	}))
   339  
   340  	defer backend.Close()
   341  
   342  	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
   343  
   344  	backendURL, err := url.Parse(backend.URL)
   345  	if err != nil {
   346  		t.Fatal(err)
   347  	}
   348  
   349  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   350  
   351  	// Discards errors of the form:
   352  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   353  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
   354  
   355  	frontend := httptest.NewServer(proxyHandler)
   356  	defer frontend.Close()
   357  	frontendClient := frontend.Client()
   358  
   359  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   360  	go func() {
   361  		<-reqInFlight
   362  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   363  	}()
   364  	res, err := frontendClient.Do(getReq)
   365  	if res != nil {
   366  		t.Errorf("got response %v; want nil", res.Status)
   367  	}
   368  	if err == nil {
   369  		// This should be an error like:
   370  		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
   371  		//    use of closed network connection
   372  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   373  	}
   374  }
   375  
   376  func req(t *testing.T, v string) *http.Request {
   377  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	return req
   382  }
   383  
   384  // Issue 12344
   385  func TestNilBody(t *testing.T) {
   386  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   387  		w.Write([]byte("hi"))
   388  	}))
   389  	defer backend.Close()
   390  
   391  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   392  		backURL, _ := url.Parse(backend.URL)
   393  		rp := NewSingleHostReverseProxy(backURL)
   394  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   395  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   396  		rp.ServeHTTP(w, r)
   397  	}))
   398  	defer frontend.Close()
   399  
   400  	res, err := http.Get(frontend.URL)
   401  	if err != nil {
   402  		t.Fatal(err)
   403  	}
   404  	defer res.Body.Close()
   405  	slurp, err := ioutil.ReadAll(res.Body)
   406  	if err != nil {
   407  		t.Fatal(err)
   408  	}
   409  	if string(slurp) != "hi" {
   410  		t.Errorf("Got %q; want %q", slurp, "hi")
   411  	}
   412  }
   413  
   414  // Issue 15524
   415  func TestUserAgentHeader(t *testing.T) {
   416  	const explicitUA = "explicit UA"
   417  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   418  		if r.URL.Path == "/noua" {
   419  			if c := r.Header.Get("User-Agent"); c != "" {
   420  				t.Errorf("handler got non-empty User-Agent header %q", c)
   421  			}
   422  			return
   423  		}
   424  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   425  			t.Errorf("handler got unexpected User-Agent header %q", c)
   426  		}
   427  	}))
   428  	defer backend.Close()
   429  	backendURL, err := url.Parse(backend.URL)
   430  	if err != nil {
   431  		t.Fatal(err)
   432  	}
   433  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   434  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   435  	frontend := httptest.NewServer(proxyHandler)
   436  	defer frontend.Close()
   437  	frontendClient := frontend.Client()
   438  
   439  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   440  	getReq.Header.Set("User-Agent", explicitUA)
   441  	getReq.Close = true
   442  	res, err := frontendClient.Do(getReq)
   443  	if err != nil {
   444  		t.Fatalf("Get: %v", err)
   445  	}
   446  	res.Body.Close()
   447  
   448  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   449  	getReq.Header.Set("User-Agent", "")
   450  	getReq.Close = true
   451  	res, err = frontendClient.Do(getReq)
   452  	if err != nil {
   453  		t.Fatalf("Get: %v", err)
   454  	}
   455  	res.Body.Close()
   456  }
   457  
   458  type bufferPool struct {
   459  	get func() []byte
   460  	put func([]byte)
   461  }
   462  
   463  func (bp bufferPool) Get() []byte  { return bp.get() }
   464  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   465  
   466  func TestReverseProxyGetPutBuffer(t *testing.T) {
   467  	const msg = "hi"
   468  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   469  		io.WriteString(w, msg)
   470  	}))
   471  	defer backend.Close()
   472  
   473  	backendURL, err := url.Parse(backend.URL)
   474  	if err != nil {
   475  		t.Fatal(err)
   476  	}
   477  
   478  	var (
   479  		mu  sync.Mutex
   480  		log []string
   481  	)
   482  	addLog := func(event string) {
   483  		mu.Lock()
   484  		defer mu.Unlock()
   485  		log = append(log, event)
   486  	}
   487  	rp := NewSingleHostReverseProxy(backendURL)
   488  	const size = 1234
   489  	rp.BufferPool = bufferPool{
   490  		get: func() []byte {
   491  			addLog("getBuf")
   492  			return make([]byte, size)
   493  		},
   494  		put: func(p []byte) {
   495  			addLog("putBuf-" + strconv.Itoa(len(p)))
   496  		},
   497  	}
   498  	frontend := httptest.NewServer(rp)
   499  	defer frontend.Close()
   500  
   501  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   502  	req.Close = true
   503  	res, err := frontend.Client().Do(req)
   504  	if err != nil {
   505  		t.Fatalf("Get: %v", err)
   506  	}
   507  	slurp, err := ioutil.ReadAll(res.Body)
   508  	res.Body.Close()
   509  	if err != nil {
   510  		t.Fatalf("reading body: %v", err)
   511  	}
   512  	if string(slurp) != msg {
   513  		t.Errorf("msg = %q; want %q", slurp, msg)
   514  	}
   515  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   516  	mu.Lock()
   517  	defer mu.Unlock()
   518  	if !reflect.DeepEqual(log, wantLog) {
   519  		t.Errorf("Log events = %q; want %q", log, wantLog)
   520  	}
   521  }
   522  
   523  func TestReverseProxy_Post(t *testing.T) {
   524  	const backendResponse = "I am the backend"
   525  	const backendStatus = 200
   526  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   527  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   528  		slurp, err := ioutil.ReadAll(r.Body)
   529  		if err != nil {
   530  			t.Errorf("Backend body read = %v", err)
   531  		}
   532  		if len(slurp) != len(requestBody) {
   533  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   534  		}
   535  		if !bytes.Equal(slurp, requestBody) {
   536  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   537  		}
   538  		w.Write([]byte(backendResponse))
   539  	}))
   540  	defer backend.Close()
   541  	backendURL, err := url.Parse(backend.URL)
   542  	if err != nil {
   543  		t.Fatal(err)
   544  	}
   545  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   546  	frontend := httptest.NewServer(proxyHandler)
   547  	defer frontend.Close()
   548  
   549  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   550  	res, err := frontend.Client().Do(postReq)
   551  	if err != nil {
   552  		t.Fatalf("Do: %v", err)
   553  	}
   554  	if g, e := res.StatusCode, backendStatus; g != e {
   555  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   556  	}
   557  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   558  	if g, e := string(bodyBytes), backendResponse; g != e {
   559  		t.Errorf("got body %q; expected %q", g, e)
   560  	}
   561  }
   562  
   563  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   564  
   565  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   566  	return fn(req)
   567  }
   568  
   569  // Issue 16036: send a Request with a nil Body when possible
   570  func TestReverseProxy_NilBody(t *testing.T) {
   571  	backendURL, _ := url.Parse("http://fake.tld/")
   572  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   573  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   574  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   575  		if req.Body != nil {
   576  			t.Error("Body != nil; want a nil Body")
   577  		}
   578  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   579  	})
   580  	frontend := httptest.NewServer(proxyHandler)
   581  	defer frontend.Close()
   582  
   583  	res, err := frontend.Client().Get(frontend.URL)
   584  	if err != nil {
   585  		t.Fatal(err)
   586  	}
   587  	defer res.Body.Close()
   588  	if res.StatusCode != 502 {
   589  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   590  	}
   591  }
   592  
   593  // Issue 14237. Test ModifyResponse and that an error from it
   594  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   595  func TestReverseProxyModifyResponse(t *testing.T) {
   596  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   597  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   598  	}))
   599  	defer backendServer.Close()
   600  
   601  	rpURL, _ := url.Parse(backendServer.URL)
   602  	rproxy := NewSingleHostReverseProxy(rpURL)
   603  	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   604  	rproxy.ModifyResponse = func(resp *http.Response) error {
   605  		if resp.Header.Get("X-Hit-Mod") != "true" {
   606  			return fmt.Errorf("tried to by-pass proxy")
   607  		}
   608  		return nil
   609  	}
   610  
   611  	frontendProxy := httptest.NewServer(rproxy)
   612  	defer frontendProxy.Close()
   613  
   614  	tests := []struct {
   615  		url      string
   616  		wantCode int
   617  	}{
   618  		{frontendProxy.URL + "/mod", http.StatusOK},
   619  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   620  	}
   621  
   622  	for i, tt := range tests {
   623  		resp, err := http.Get(tt.url)
   624  		if err != nil {
   625  			t.Fatalf("failed to reach proxy: %v", err)
   626  		}
   627  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   628  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   629  		}
   630  		resp.Body.Close()
   631  	}
   632  }
   633  
   634  type failingRoundTripper struct{}
   635  
   636  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   637  	return nil, errors.New("some error")
   638  }
   639  
   640  type staticResponseRoundTripper struct{ res *http.Response }
   641  
   642  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   643  	return rt.res, nil
   644  }
   645  
   646  func TestReverseProxyErrorHandler(t *testing.T) {
   647  	tests := []struct {
   648  		name           string
   649  		wantCode       int
   650  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   651  		transport      http.RoundTripper // defaults to failingRoundTripper
   652  		modifyResponse func(*http.Response) error
   653  	}{
   654  		{
   655  			name:     "default",
   656  			wantCode: http.StatusBadGateway,
   657  		},
   658  		{
   659  			name:         "errorhandler",
   660  			wantCode:     http.StatusTeapot,
   661  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   662  		},
   663  		{
   664  			name: "modifyresponse_noerr",
   665  			transport: staticResponseRoundTripper{
   666  				&http.Response{StatusCode: 345, Body: http.NoBody},
   667  			},
   668  			modifyResponse: func(res *http.Response) error {
   669  				res.StatusCode++
   670  				return nil
   671  			},
   672  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   673  			wantCode:     346,
   674  		},
   675  		{
   676  			name: "modifyresponse_err",
   677  			transport: staticResponseRoundTripper{
   678  				&http.Response{StatusCode: 345, Body: http.NoBody},
   679  			},
   680  			modifyResponse: func(res *http.Response) error {
   681  				res.StatusCode++
   682  				return errors.New("some error to trigger errorHandler")
   683  			},
   684  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   685  			wantCode:     http.StatusTeapot,
   686  		},
   687  	}
   688  
   689  	for _, tt := range tests {
   690  		t.Run(tt.name, func(t *testing.T) {
   691  			target := &url.URL{
   692  				Scheme: "http",
   693  				Host:   "dummy.tld",
   694  				Path:   "/",
   695  			}
   696  			rproxy := NewSingleHostReverseProxy(target)
   697  			rproxy.Transport = tt.transport
   698  			rproxy.ModifyResponse = tt.modifyResponse
   699  			if rproxy.Transport == nil {
   700  				rproxy.Transport = failingRoundTripper{}
   701  			}
   702  			rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   703  			if tt.errorHandler != nil {
   704  				rproxy.ErrorHandler = tt.errorHandler
   705  			}
   706  			frontendProxy := httptest.NewServer(rproxy)
   707  			defer frontendProxy.Close()
   708  
   709  			resp, err := http.Get(frontendProxy.URL + "/test")
   710  			if err != nil {
   711  				t.Fatalf("failed to reach proxy: %v", err)
   712  			}
   713  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   714  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   715  			}
   716  			resp.Body.Close()
   717  		})
   718  	}
   719  }
   720  
   721  // Issue 16659: log errors from short read
   722  func TestReverseProxy_CopyBuffer(t *testing.T) {
   723  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   724  		out := "this call was relayed by the reverse proxy"
   725  		// Coerce a wrong content length to induce io.UnexpectedEOF
   726  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   727  		fmt.Fprintln(w, out)
   728  	}))
   729  	defer backendServer.Close()
   730  
   731  	rpURL, err := url.Parse(backendServer.URL)
   732  	if err != nil {
   733  		t.Fatal(err)
   734  	}
   735  
   736  	var proxyLog bytes.Buffer
   737  	rproxy := NewSingleHostReverseProxy(rpURL)
   738  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   739  	donec := make(chan bool, 1)
   740  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   741  		defer func() { donec <- true }()
   742  		rproxy.ServeHTTP(w, r)
   743  	}))
   744  	defer frontendProxy.Close()
   745  
   746  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   747  		t.Fatalf("want non-nil error")
   748  	}
   749  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   750  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   751  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   752  	// continue after Get.
   753  	<-donec
   754  
   755  	expected := []string{
   756  		"EOF",
   757  		"read",
   758  	}
   759  	for _, phrase := range expected {
   760  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   761  			t.Errorf("expected log to contain phrase %q", phrase)
   762  		}
   763  	}
   764  }
   765  
   766  type staticTransport struct {
   767  	res *http.Response
   768  }
   769  
   770  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   771  	return t.res, nil
   772  }
   773  
   774  func BenchmarkServeHTTP(b *testing.B) {
   775  	res := &http.Response{
   776  		StatusCode: 200,
   777  		Body:       ioutil.NopCloser(strings.NewReader("")),
   778  	}
   779  	proxy := &ReverseProxy{
   780  		Director:  func(*http.Request) {},
   781  		Transport: &staticTransport{res},
   782  	}
   783  
   784  	w := httptest.NewRecorder()
   785  	r := httptest.NewRequest("GET", "/", nil)
   786  
   787  	b.ReportAllocs()
   788  	for i := 0; i < b.N; i++ {
   789  		proxy.ServeHTTP(w, r)
   790  	}
   791  }
   792  
   793  func TestServeHTTPDeepCopy(t *testing.T) {
   794  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   795  		w.Write([]byte("Hello Gopher!"))
   796  	}))
   797  	defer backend.Close()
   798  	backendURL, err := url.Parse(backend.URL)
   799  	if err != nil {
   800  		t.Fatal(err)
   801  	}
   802  
   803  	type result struct {
   804  		before, after string
   805  	}
   806  
   807  	resultChan := make(chan result, 1)
   808  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   809  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   810  		before := r.URL.String()
   811  		proxyHandler.ServeHTTP(w, r)
   812  		after := r.URL.String()
   813  		resultChan <- result{before: before, after: after}
   814  	}))
   815  	defer frontend.Close()
   816  
   817  	want := result{before: "/", after: "/"}
   818  
   819  	res, err := frontend.Client().Get(frontend.URL)
   820  	if err != nil {
   821  		t.Fatalf("Do: %v", err)
   822  	}
   823  	res.Body.Close()
   824  
   825  	got := <-resultChan
   826  	if got != want {
   827  		t.Errorf("got = %+v; want = %+v", got, want)
   828  	}
   829  }
   830  
   831  // Issue 18327: verify we always do a deep copy of the Request.Header map
   832  // before any mutations.
   833  func TestClonesRequestHeaders(t *testing.T) {
   834  	log.SetOutput(ioutil.Discard)
   835  	defer log.SetOutput(os.Stderr)
   836  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   837  	req.RemoteAddr = "1.2.3.4:56789"
   838  	rp := &ReverseProxy{
   839  		Director: func(req *http.Request) {
   840  			req.Header.Set("From-Director", "1")
   841  		},
   842  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   843  			if v := req.Header.Get("From-Director"); v != "1" {
   844  				t.Errorf("From-Directory value = %q; want 1", v)
   845  			}
   846  			return nil, io.EOF
   847  		}),
   848  	}
   849  	rp.ServeHTTP(httptest.NewRecorder(), req)
   850  
   851  	if req.Header.Get("From-Director") == "1" {
   852  		t.Error("Director header mutation modified caller's request")
   853  	}
   854  	if req.Header.Get("X-Forwarded-For") != "" {
   855  		t.Error("X-Forward-For header mutation modified caller's request")
   856  	}
   857  
   858  }
   859  
   860  type roundTripperFunc func(req *http.Request) (*http.Response, error)
   861  
   862  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   863  	return fn(req)
   864  }
   865  
   866  func TestModifyResponseClosesBody(t *testing.T) {
   867  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   868  	req.RemoteAddr = "1.2.3.4:56789"
   869  	closeCheck := new(checkCloser)
   870  	logBuf := new(bytes.Buffer)
   871  	outErr := errors.New("ModifyResponse error")
   872  	rp := &ReverseProxy{
   873  		Director: func(req *http.Request) {},
   874  		Transport: &staticTransport{&http.Response{
   875  			StatusCode: 200,
   876  			Body:       closeCheck,
   877  		}},
   878  		ErrorLog: log.New(logBuf, "", 0),
   879  		ModifyResponse: func(*http.Response) error {
   880  			return outErr
   881  		},
   882  	}
   883  	rec := httptest.NewRecorder()
   884  	rp.ServeHTTP(rec, req)
   885  	res := rec.Result()
   886  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
   887  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   888  	}
   889  	if !closeCheck.closed {
   890  		t.Errorf("body should have been closed")
   891  	}
   892  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
   893  		t.Errorf("ErrorLog %q does not contain %q", g, e)
   894  	}
   895  }
   896  
   897  type checkCloser struct {
   898  	closed bool
   899  }
   900  
   901  func (cc *checkCloser) Close() error {
   902  	cc.closed = true
   903  	return nil
   904  }
   905  
   906  func (cc *checkCloser) Read(b []byte) (int, error) {
   907  	return len(b), nil
   908  }
   909  
   910  // Issue 23643: panic on body copy error
   911  func TestReverseProxy_PanicBodyError(t *testing.T) {
   912  	log.SetOutput(ioutil.Discard)
   913  	defer log.SetOutput(os.Stderr)
   914  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   915  		out := "this call was relayed by the reverse proxy"
   916  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
   917  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   918  		fmt.Fprintln(w, out)
   919  	}))
   920  	defer backendServer.Close()
   921  
   922  	rpURL, err := url.Parse(backendServer.URL)
   923  	if err != nil {
   924  		t.Fatal(err)
   925  	}
   926  
   927  	rproxy := NewSingleHostReverseProxy(rpURL)
   928  
   929  	// Ensure that the handler panics when the body read encounters an
   930  	// io.ErrUnexpectedEOF
   931  	defer func() {
   932  		err := recover()
   933  		if err == nil {
   934  			t.Fatal("handler should have panicked")
   935  		}
   936  		if err != http.ErrAbortHandler {
   937  			t.Fatal("expected ErrAbortHandler, got", err)
   938  		}
   939  	}()
   940  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   941  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
   942  }
   943  
   944  func TestSelectFlushInterval(t *testing.T) {
   945  	tests := []struct {
   946  		name string
   947  		p    *ReverseProxy
   948  		req  *http.Request
   949  		res  *http.Response
   950  		want time.Duration
   951  	}{
   952  		{
   953  			name: "default",
   954  			res:  &http.Response{},
   955  			p:    &ReverseProxy{FlushInterval: 123},
   956  			want: 123,
   957  		},
   958  		{
   959  			name: "server-sent events overrides non-zero",
   960  			res: &http.Response{
   961  				Header: http.Header{
   962  					"Content-Type": {"text/event-stream"},
   963  				},
   964  			},
   965  			p:    &ReverseProxy{FlushInterval: 123},
   966  			want: -1,
   967  		},
   968  		{
   969  			name: "server-sent events overrides zero",
   970  			res: &http.Response{
   971  				Header: http.Header{
   972  					"Content-Type": {"text/event-stream"},
   973  				},
   974  			},
   975  			p:    &ReverseProxy{FlushInterval: 0},
   976  			want: -1,
   977  		},
   978  	}
   979  	for _, tt := range tests {
   980  		t.Run(tt.name, func(t *testing.T) {
   981  			got := tt.p.flushInterval(tt.req, tt.res)
   982  			if got != tt.want {
   983  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
   984  			}
   985  		})
   986  	}
   987  }
   988  
   989  func TestReverseProxyWebSocket(t *testing.T) {
   990  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   991  		if upgradeType(r.Header) != "websocket" {
   992  			t.Error("unexpected backend request")
   993  			http.Error(w, "unexpected request", 400)
   994  			return
   995  		}
   996  		c, _, err := w.(http.Hijacker).Hijack()
   997  		if err != nil {
   998  			t.Error(err)
   999  			return
  1000  		}
  1001  		defer c.Close()
  1002  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1003  		bs := bufio.NewScanner(c)
  1004  		if !bs.Scan() {
  1005  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1006  			return
  1007  		}
  1008  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1009  	}))
  1010  	defer backendServer.Close()
  1011  
  1012  	backURL, _ := url.Parse(backendServer.URL)
  1013  	rproxy := NewSingleHostReverseProxy(backURL)
  1014  	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
  1015  	rproxy.ModifyResponse = func(res *http.Response) error {
  1016  		res.Header.Add("X-Modified", "true")
  1017  		return nil
  1018  	}
  1019  
  1020  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1021  		rw.Header().Set("X-Header", "X-Value")
  1022  		rproxy.ServeHTTP(rw, req)
  1023  	})
  1024  
  1025  	frontendProxy := httptest.NewServer(handler)
  1026  	defer frontendProxy.Close()
  1027  
  1028  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1029  	req.Header.Set("Connection", "Upgrade")
  1030  	req.Header.Set("Upgrade", "websocket")
  1031  
  1032  	c := frontendProxy.Client()
  1033  	res, err := c.Do(req)
  1034  	if err != nil {
  1035  		t.Fatal(err)
  1036  	}
  1037  	if res.StatusCode != 101 {
  1038  		t.Fatalf("status = %v; want 101", res.Status)
  1039  	}
  1040  
  1041  	got := res.Header.Get("X-Header")
  1042  	want := "X-Value"
  1043  	if got != want {
  1044  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1045  	}
  1046  
  1047  	if upgradeType(res.Header) != "websocket" {
  1048  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1049  	}
  1050  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1051  	if !ok {
  1052  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1053  	}
  1054  	defer rwc.Close()
  1055  
  1056  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1057  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1058  	}
  1059  
  1060  	io.WriteString(rwc, "Hello\n")
  1061  	bs := bufio.NewScanner(rwc)
  1062  	if !bs.Scan() {
  1063  		t.Fatalf("Scan: %v", bs.Err())
  1064  	}
  1065  	got = bs.Text()
  1066  	want = `backend got "Hello"`
  1067  	if got != want {
  1068  		t.Errorf("got %#q, want %#q", got, want)
  1069  	}
  1070  }
  1071  
  1072  func TestUnannouncedTrailer(t *testing.T) {
  1073  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1074  		w.WriteHeader(http.StatusOK)
  1075  		w.(http.Flusher).Flush()
  1076  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1077  	}))
  1078  	defer backend.Close()
  1079  	backendURL, err := url.Parse(backend.URL)
  1080  	if err != nil {
  1081  		t.Fatal(err)
  1082  	}
  1083  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1084  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
  1085  	frontend := httptest.NewServer(proxyHandler)
  1086  	defer frontend.Close()
  1087  	frontendClient := frontend.Client()
  1088  
  1089  	res, err := frontendClient.Get(frontend.URL)
  1090  	if err != nil {
  1091  		t.Fatalf("Get: %v", err)
  1092  	}
  1093  
  1094  	ioutil.ReadAll(res.Body)
  1095  
  1096  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1097  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1098  	}
  1099  
  1100  }
  1101  
  1102  func TestSingleJoinSlash(t *testing.T) {
  1103  	tests := []struct {
  1104  		slasha   string
  1105  		slashb   string
  1106  		expected string
  1107  	}{
  1108  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1109  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1110  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1111  		{"https://www.google.com", "", "https://www.google.com/"},
  1112  		{"", "favicon.ico", "/favicon.ico"},
  1113  	}
  1114  	for _, tt := range tests {
  1115  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1116  			t.Errorf("singleJoiningSlash(%s,%s) want %s got %s",
  1117  				tt.slasha,
  1118  				tt.slashb,
  1119  				tt.expected,
  1120  				got)
  1121  		}
  1122  	}
  1123  }