github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/net/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/url"
    20  	"os"
    21  	"reflect"
    22  	"sort"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  	"testing"
    27  	"time"
    28  )
    29  
    30  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    31  
    32  func init() {
    33  	inOurTests = true
    34  	hopHeaders = append(hopHeaders, fakeHopHeader)
    35  }
    36  
    37  func TestReverseProxy(t *testing.T) {
    38  	const backendResponse = "I am the backend"
    39  	const backendStatus = 404
    40  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    41  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    42  			c, _, _ := w.(http.Hijacker).Hijack()
    43  			c.Close()
    44  			return
    45  		}
    46  		if len(r.TransferEncoding) > 0 {
    47  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    48  		}
    49  		if r.Header.Get("X-Forwarded-For") == "" {
    50  			t.Errorf("didn't get X-Forwarded-For header")
    51  		}
    52  		if c := r.Header.Get("Connection"); c != "" {
    53  			t.Errorf("handler got Connection header value %q", c)
    54  		}
    55  		if c := r.Header.Get("Te"); c != "trailers" {
    56  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    57  		}
    58  		if c := r.Header.Get("Upgrade"); c != "" {
    59  			t.Errorf("handler got Upgrade header value %q", c)
    60  		}
    61  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    62  			t.Errorf("handler got Proxy-Connection header value %q", c)
    63  		}
    64  		if g, e := r.Host, "some-name"; g != e {
    65  			t.Errorf("backend got Host header %q, want %q", g, e)
    66  		}
    67  		w.Header().Set("Trailers", "not a special header field name")
    68  		w.Header().Set("Trailer", "X-Trailer")
    69  		w.Header().Set("X-Foo", "bar")
    70  		w.Header().Set("Upgrade", "foo")
    71  		w.Header().Set(fakeHopHeader, "foo")
    72  		w.Header().Add("X-Multi-Value", "foo")
    73  		w.Header().Add("X-Multi-Value", "bar")
    74  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    75  		w.WriteHeader(backendStatus)
    76  		w.Write([]byte(backendResponse))
    77  		w.Header().Set("X-Trailer", "trailer_value")
    78  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    79  	}))
    80  	defer backend.Close()
    81  	backendURL, err := url.Parse(backend.URL)
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    86  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    87  	frontend := httptest.NewServer(proxyHandler)
    88  	defer frontend.Close()
    89  	frontendClient := frontend.Client()
    90  
    91  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    92  	getReq.Host = "some-name"
    93  	getReq.Header.Set("Connection", "close")
    94  	getReq.Header.Set("Te", "trailers")
    95  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    96  	getReq.Header.Set("Upgrade", "foo")
    97  	getReq.Close = true
    98  	res, err := frontendClient.Do(getReq)
    99  	if err != nil {
   100  		t.Fatalf("Get: %v", err)
   101  	}
   102  	if g, e := res.StatusCode, backendStatus; g != e {
   103  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   104  	}
   105  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   106  		t.Errorf("got X-Foo %q; expected %q", g, e)
   107  	}
   108  	if c := res.Header.Get(fakeHopHeader); c != "" {
   109  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   110  	}
   111  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   112  		t.Errorf("header Trailers = %q; want %q", g, e)
   113  	}
   114  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   115  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   116  	}
   117  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   118  		t.Fatalf("got %d SetCookies, want %d", g, e)
   119  	}
   120  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   121  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   122  	}
   123  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   124  		t.Errorf("unexpected cookie %q", cookie.Name)
   125  	}
   126  	bodyBytes, _ := io.ReadAll(res.Body)
   127  	if g, e := string(bodyBytes), backendResponse; g != e {
   128  		t.Errorf("got body %q; expected %q", g, e)
   129  	}
   130  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   131  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   132  	}
   133  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   134  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   135  	}
   136  
   137  	// Test that a backend failing to be reached or one which doesn't return
   138  	// a response results in a StatusBadGateway.
   139  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   140  	getReq.Close = true
   141  	res, err = frontendClient.Do(getReq)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	res.Body.Close()
   146  	if res.StatusCode != http.StatusBadGateway {
   147  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   148  	}
   149  
   150  }
   151  
   152  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   153  // header value.
   154  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   155  	const fakeConnectionToken = "X-Fake-Connection-Token"
   156  	const backendResponse = "I am the backend"
   157  
   158  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   159  	// in the Request's Connection header.
   160  	const someConnHeader = "X-Some-Conn-Header"
   161  
   162  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   163  		if c := r.Header.Get("Connection"); c != "" {
   164  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   165  		}
   166  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   167  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   168  		}
   169  		if c := r.Header.Get(someConnHeader); c != "" {
   170  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   171  		}
   172  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   173  		w.Header().Add("Connection", someConnHeader)
   174  		w.Header().Set(someConnHeader, "should be deleted")
   175  		w.Header().Set(fakeConnectionToken, "should be deleted")
   176  		io.WriteString(w, backendResponse)
   177  	}))
   178  	defer backend.Close()
   179  	backendURL, err := url.Parse(backend.URL)
   180  	if err != nil {
   181  		t.Fatal(err)
   182  	}
   183  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   184  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   185  		proxyHandler.ServeHTTP(w, r)
   186  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   187  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   188  		}
   189  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   190  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   191  		}
   192  		c := r.Header["Connection"]
   193  		var cf []string
   194  		for _, f := range c {
   195  			for _, sf := range strings.Split(f, ",") {
   196  				if sf = strings.TrimSpace(sf); sf != "" {
   197  					cf = append(cf, sf)
   198  				}
   199  			}
   200  		}
   201  		sort.Strings(cf)
   202  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   203  		sort.Strings(expectedValues)
   204  		if !reflect.DeepEqual(cf, expectedValues) {
   205  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   206  		}
   207  	}))
   208  	defer frontend.Close()
   209  
   210  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   211  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   212  	getReq.Header.Add("Connection", someConnHeader)
   213  	getReq.Header.Set(someConnHeader, "should be deleted")
   214  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   215  	res, err := frontend.Client().Do(getReq)
   216  	if err != nil {
   217  		t.Fatalf("Get: %v", err)
   218  	}
   219  	defer res.Body.Close()
   220  	bodyBytes, err := io.ReadAll(res.Body)
   221  	if err != nil {
   222  		t.Fatalf("reading body: %v", err)
   223  	}
   224  	if got, want := string(bodyBytes), backendResponse; got != want {
   225  		t.Errorf("got body %q; want %q", got, want)
   226  	}
   227  	if c := res.Header.Get("Connection"); c != "" {
   228  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   229  	}
   230  	if c := res.Header.Get(someConnHeader); c != "" {
   231  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   232  	}
   233  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   234  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   235  	}
   236  }
   237  
   238  func TestXForwardedFor(t *testing.T) {
   239  	const prevForwardedFor = "client ip"
   240  	const backendResponse = "I am the backend"
   241  	const backendStatus = 404
   242  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   243  		if r.Header.Get("X-Forwarded-For") == "" {
   244  			t.Errorf("didn't get X-Forwarded-For header")
   245  		}
   246  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   247  			t.Errorf("X-Forwarded-For didn't contain prior data")
   248  		}
   249  		w.WriteHeader(backendStatus)
   250  		w.Write([]byte(backendResponse))
   251  	}))
   252  	defer backend.Close()
   253  	backendURL, err := url.Parse(backend.URL)
   254  	if err != nil {
   255  		t.Fatal(err)
   256  	}
   257  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   258  	frontend := httptest.NewServer(proxyHandler)
   259  	defer frontend.Close()
   260  
   261  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   262  	getReq.Host = "some-name"
   263  	getReq.Header.Set("Connection", "close")
   264  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   265  	getReq.Close = true
   266  	res, err := frontend.Client().Do(getReq)
   267  	if err != nil {
   268  		t.Fatalf("Get: %v", err)
   269  	}
   270  	if g, e := res.StatusCode, backendStatus; g != e {
   271  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   272  	}
   273  	bodyBytes, _ := io.ReadAll(res.Body)
   274  	if g, e := string(bodyBytes), backendResponse; g != e {
   275  		t.Errorf("got body %q; expected %q", g, e)
   276  	}
   277  }
   278  
   279  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   280  func TestXForwardedFor_Omit(t *testing.T) {
   281  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   282  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   283  			t.Errorf("got X-Forwarded-For header: %q", v)
   284  		}
   285  		w.Write([]byte("hi"))
   286  	}))
   287  	defer backend.Close()
   288  	backendURL, err := url.Parse(backend.URL)
   289  	if err != nil {
   290  		t.Fatal(err)
   291  	}
   292  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   293  	frontend := httptest.NewServer(proxyHandler)
   294  	defer frontend.Close()
   295  
   296  	oldDirector := proxyHandler.Director
   297  	proxyHandler.Director = func(r *http.Request) {
   298  		r.Header["X-Forwarded-For"] = nil
   299  		oldDirector(r)
   300  	}
   301  
   302  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   303  	getReq.Host = "some-name"
   304  	getReq.Close = true
   305  	res, err := frontend.Client().Do(getReq)
   306  	if err != nil {
   307  		t.Fatalf("Get: %v", err)
   308  	}
   309  	res.Body.Close()
   310  }
   311  
   312  var proxyQueryTests = []struct {
   313  	baseSuffix string // suffix to add to backend URL
   314  	reqSuffix  string // suffix to add to frontend's request URL
   315  	want       string // what backend should see for final request URL (without ?)
   316  }{
   317  	{"", "", ""},
   318  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   319  	{"", "?us=er", "us=er"},
   320  	{"?sta=tic", "", "sta=tic"},
   321  }
   322  
   323  func TestReverseProxyQuery(t *testing.T) {
   324  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   325  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   326  		w.Write([]byte("hi"))
   327  	}))
   328  	defer backend.Close()
   329  
   330  	for i, tt := range proxyQueryTests {
   331  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   332  		if err != nil {
   333  			t.Fatal(err)
   334  		}
   335  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   336  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   337  		req.Close = true
   338  		res, err := frontend.Client().Do(req)
   339  		if err != nil {
   340  			t.Fatalf("%d. Get: %v", i, err)
   341  		}
   342  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   343  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   344  		}
   345  		res.Body.Close()
   346  		frontend.Close()
   347  	}
   348  }
   349  
   350  func TestReverseProxyFlushInterval(t *testing.T) {
   351  	const expected = "hi"
   352  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   353  		w.Write([]byte(expected))
   354  	}))
   355  	defer backend.Close()
   356  
   357  	backendURL, err := url.Parse(backend.URL)
   358  	if err != nil {
   359  		t.Fatal(err)
   360  	}
   361  
   362  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   363  	proxyHandler.FlushInterval = time.Microsecond
   364  
   365  	frontend := httptest.NewServer(proxyHandler)
   366  	defer frontend.Close()
   367  
   368  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   369  	req.Close = true
   370  	res, err := frontend.Client().Do(req)
   371  	if err != nil {
   372  		t.Fatalf("Get: %v", err)
   373  	}
   374  	defer res.Body.Close()
   375  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   376  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   377  	}
   378  }
   379  
   380  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   381  	const expected = "hi"
   382  	stopCh := make(chan struct{})
   383  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   384  		w.Header().Add("MyHeader", expected)
   385  		w.WriteHeader(200)
   386  		w.(http.Flusher).Flush()
   387  		<-stopCh
   388  	}))
   389  	defer backend.Close()
   390  	defer close(stopCh)
   391  
   392  	backendURL, err := url.Parse(backend.URL)
   393  	if err != nil {
   394  		t.Fatal(err)
   395  	}
   396  
   397  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   398  	proxyHandler.FlushInterval = time.Microsecond
   399  
   400  	frontend := httptest.NewServer(proxyHandler)
   401  	defer frontend.Close()
   402  
   403  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   404  	req.Close = true
   405  
   406  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   407  	defer cancel()
   408  	req = req.WithContext(ctx)
   409  
   410  	res, err := frontend.Client().Do(req)
   411  	if err != nil {
   412  		t.Fatalf("Get: %v", err)
   413  	}
   414  	defer res.Body.Close()
   415  
   416  	if res.Header.Get("MyHeader") != expected {
   417  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   418  	}
   419  }
   420  
   421  func TestReverseProxyCancellation(t *testing.T) {
   422  	const backendResponse = "I am the backend"
   423  
   424  	reqInFlight := make(chan struct{})
   425  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   426  		close(reqInFlight) // cause the client to cancel its request
   427  
   428  		select {
   429  		case <-time.After(10 * time.Second):
   430  			// Note: this should only happen in broken implementations, and the
   431  			// closenotify case should be instantaneous.
   432  			t.Error("Handler never saw CloseNotify")
   433  			return
   434  		case <-w.(http.CloseNotifier).CloseNotify():
   435  		}
   436  
   437  		w.WriteHeader(http.StatusOK)
   438  		w.Write([]byte(backendResponse))
   439  	}))
   440  
   441  	defer backend.Close()
   442  
   443  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   444  
   445  	backendURL, err := url.Parse(backend.URL)
   446  	if err != nil {
   447  		t.Fatal(err)
   448  	}
   449  
   450  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   451  
   452  	// Discards errors of the form:
   453  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   454  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   455  
   456  	frontend := httptest.NewServer(proxyHandler)
   457  	defer frontend.Close()
   458  	frontendClient := frontend.Client()
   459  
   460  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   461  	go func() {
   462  		<-reqInFlight
   463  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   464  	}()
   465  	res, err := frontendClient.Do(getReq)
   466  	if res != nil {
   467  		t.Errorf("got response %v; want nil", res.Status)
   468  	}
   469  	if err == nil {
   470  		// This should be an error like:
   471  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   472  		//    use of closed network connection
   473  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   474  	}
   475  }
   476  
   477  func req(t *testing.T, v string) *http.Request {
   478  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   479  	if err != nil {
   480  		t.Fatal(err)
   481  	}
   482  	return req
   483  }
   484  
   485  // Issue 12344
   486  func TestNilBody(t *testing.T) {
   487  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   488  		w.Write([]byte("hi"))
   489  	}))
   490  	defer backend.Close()
   491  
   492  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   493  		backURL, _ := url.Parse(backend.URL)
   494  		rp := NewSingleHostReverseProxy(backURL)
   495  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   496  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   497  		rp.ServeHTTP(w, r)
   498  	}))
   499  	defer frontend.Close()
   500  
   501  	res, err := http.Get(frontend.URL)
   502  	if err != nil {
   503  		t.Fatal(err)
   504  	}
   505  	defer res.Body.Close()
   506  	slurp, err := io.ReadAll(res.Body)
   507  	if err != nil {
   508  		t.Fatal(err)
   509  	}
   510  	if string(slurp) != "hi" {
   511  		t.Errorf("Got %q; want %q", slurp, "hi")
   512  	}
   513  }
   514  
   515  // Issue 15524
   516  func TestUserAgentHeader(t *testing.T) {
   517  	const explicitUA = "explicit UA"
   518  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   519  		if r.URL.Path == "/noua" {
   520  			if c := r.Header.Get("User-Agent"); c != "" {
   521  				t.Errorf("handler got non-empty User-Agent header %q", c)
   522  			}
   523  			return
   524  		}
   525  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   526  			t.Errorf("handler got unexpected User-Agent header %q", c)
   527  		}
   528  	}))
   529  	defer backend.Close()
   530  	backendURL, err := url.Parse(backend.URL)
   531  	if err != nil {
   532  		t.Fatal(err)
   533  	}
   534  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   535  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   536  	frontend := httptest.NewServer(proxyHandler)
   537  	defer frontend.Close()
   538  	frontendClient := frontend.Client()
   539  
   540  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   541  	getReq.Header.Set("User-Agent", explicitUA)
   542  	getReq.Close = true
   543  	res, err := frontendClient.Do(getReq)
   544  	if err != nil {
   545  		t.Fatalf("Get: %v", err)
   546  	}
   547  	res.Body.Close()
   548  
   549  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   550  	getReq.Header.Set("User-Agent", "")
   551  	getReq.Close = true
   552  	res, err = frontendClient.Do(getReq)
   553  	if err != nil {
   554  		t.Fatalf("Get: %v", err)
   555  	}
   556  	res.Body.Close()
   557  }
   558  
   559  type bufferPool struct {
   560  	get func() []byte
   561  	put func([]byte)
   562  }
   563  
   564  func (bp bufferPool) Get() []byte  { return bp.get() }
   565  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   566  
   567  func TestReverseProxyGetPutBuffer(t *testing.T) {
   568  	const msg = "hi"
   569  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   570  		io.WriteString(w, msg)
   571  	}))
   572  	defer backend.Close()
   573  
   574  	backendURL, err := url.Parse(backend.URL)
   575  	if err != nil {
   576  		t.Fatal(err)
   577  	}
   578  
   579  	var (
   580  		mu  sync.Mutex
   581  		log []string
   582  	)
   583  	addLog := func(event string) {
   584  		mu.Lock()
   585  		defer mu.Unlock()
   586  		log = append(log, event)
   587  	}
   588  	rp := NewSingleHostReverseProxy(backendURL)
   589  	const size = 1234
   590  	rp.BufferPool = bufferPool{
   591  		get: func() []byte {
   592  			addLog("getBuf")
   593  			return make([]byte, size)
   594  		},
   595  		put: func(p []byte) {
   596  			addLog("putBuf-" + strconv.Itoa(len(p)))
   597  		},
   598  	}
   599  	frontend := httptest.NewServer(rp)
   600  	defer frontend.Close()
   601  
   602  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   603  	req.Close = true
   604  	res, err := frontend.Client().Do(req)
   605  	if err != nil {
   606  		t.Fatalf("Get: %v", err)
   607  	}
   608  	slurp, err := io.ReadAll(res.Body)
   609  	res.Body.Close()
   610  	if err != nil {
   611  		t.Fatalf("reading body: %v", err)
   612  	}
   613  	if string(slurp) != msg {
   614  		t.Errorf("msg = %q; want %q", slurp, msg)
   615  	}
   616  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   617  	mu.Lock()
   618  	defer mu.Unlock()
   619  	if !reflect.DeepEqual(log, wantLog) {
   620  		t.Errorf("Log events = %q; want %q", log, wantLog)
   621  	}
   622  }
   623  
   624  func TestReverseProxy_Post(t *testing.T) {
   625  	const backendResponse = "I am the backend"
   626  	const backendStatus = 200
   627  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   628  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   629  		slurp, err := io.ReadAll(r.Body)
   630  		if err != nil {
   631  			t.Errorf("Backend body read = %v", err)
   632  		}
   633  		if len(slurp) != len(requestBody) {
   634  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   635  		}
   636  		if !bytes.Equal(slurp, requestBody) {
   637  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   638  		}
   639  		w.Write([]byte(backendResponse))
   640  	}))
   641  	defer backend.Close()
   642  	backendURL, err := url.Parse(backend.URL)
   643  	if err != nil {
   644  		t.Fatal(err)
   645  	}
   646  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   647  	frontend := httptest.NewServer(proxyHandler)
   648  	defer frontend.Close()
   649  
   650  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   651  	res, err := frontend.Client().Do(postReq)
   652  	if err != nil {
   653  		t.Fatalf("Do: %v", err)
   654  	}
   655  	if g, e := res.StatusCode, backendStatus; g != e {
   656  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   657  	}
   658  	bodyBytes, _ := io.ReadAll(res.Body)
   659  	if g, e := string(bodyBytes), backendResponse; g != e {
   660  		t.Errorf("got body %q; expected %q", g, e)
   661  	}
   662  }
   663  
   664  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   665  
   666  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   667  	return fn(req)
   668  }
   669  
   670  // Issue 16036: send a Request with a nil Body when possible
   671  func TestReverseProxy_NilBody(t *testing.T) {
   672  	backendURL, _ := url.Parse("http://fake.tld/")
   673  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   674  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   675  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   676  		if req.Body != nil {
   677  			t.Error("Body != nil; want a nil Body")
   678  		}
   679  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   680  	})
   681  	frontend := httptest.NewServer(proxyHandler)
   682  	defer frontend.Close()
   683  
   684  	res, err := frontend.Client().Get(frontend.URL)
   685  	if err != nil {
   686  		t.Fatal(err)
   687  	}
   688  	defer res.Body.Close()
   689  	if res.StatusCode != 502 {
   690  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   691  	}
   692  }
   693  
   694  // Issue 33142: always allocate the request headers
   695  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   696  	proxyHandler := new(ReverseProxy)
   697  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   698  	proxyHandler.Director = func(*http.Request) {}     // noop
   699  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   700  		if req.Header == nil {
   701  			t.Error("Header == nil; want a non-nil Header")
   702  		}
   703  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   704  	})
   705  
   706  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   707  		Method:     "GET",
   708  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   709  		Proto:      "HTTP/1.0",
   710  		ProtoMajor: 1,
   711  	})
   712  }
   713  
   714  // Issue 14237. Test ModifyResponse and that an error from it
   715  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   716  func TestReverseProxyModifyResponse(t *testing.T) {
   717  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   718  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   719  	}))
   720  	defer backendServer.Close()
   721  
   722  	rpURL, _ := url.Parse(backendServer.URL)
   723  	rproxy := NewSingleHostReverseProxy(rpURL)
   724  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   725  	rproxy.ModifyResponse = func(resp *http.Response) error {
   726  		if resp.Header.Get("X-Hit-Mod") != "true" {
   727  			return fmt.Errorf("tried to by-pass proxy")
   728  		}
   729  		return nil
   730  	}
   731  
   732  	frontendProxy := httptest.NewServer(rproxy)
   733  	defer frontendProxy.Close()
   734  
   735  	tests := []struct {
   736  		url      string
   737  		wantCode int
   738  	}{
   739  		{frontendProxy.URL + "/mod", http.StatusOK},
   740  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   741  	}
   742  
   743  	for i, tt := range tests {
   744  		resp, err := http.Get(tt.url)
   745  		if err != nil {
   746  			t.Fatalf("failed to reach proxy: %v", err)
   747  		}
   748  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   749  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   750  		}
   751  		resp.Body.Close()
   752  	}
   753  }
   754  
   755  type failingRoundTripper struct{}
   756  
   757  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   758  	return nil, errors.New("some error")
   759  }
   760  
   761  type staticResponseRoundTripper struct{ res *http.Response }
   762  
   763  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   764  	return rt.res, nil
   765  }
   766  
   767  func TestReverseProxyErrorHandler(t *testing.T) {
   768  	tests := []struct {
   769  		name           string
   770  		wantCode       int
   771  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   772  		transport      http.RoundTripper // defaults to failingRoundTripper
   773  		modifyResponse func(*http.Response) error
   774  	}{
   775  		{
   776  			name:     "default",
   777  			wantCode: http.StatusBadGateway,
   778  		},
   779  		{
   780  			name:         "errorhandler",
   781  			wantCode:     http.StatusTeapot,
   782  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   783  		},
   784  		{
   785  			name: "modifyresponse_noerr",
   786  			transport: staticResponseRoundTripper{
   787  				&http.Response{StatusCode: 345, Body: http.NoBody},
   788  			},
   789  			modifyResponse: func(res *http.Response) error {
   790  				res.StatusCode++
   791  				return nil
   792  			},
   793  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   794  			wantCode:     346,
   795  		},
   796  		{
   797  			name: "modifyresponse_err",
   798  			transport: staticResponseRoundTripper{
   799  				&http.Response{StatusCode: 345, Body: http.NoBody},
   800  			},
   801  			modifyResponse: func(res *http.Response) error {
   802  				res.StatusCode++
   803  				return errors.New("some error to trigger errorHandler")
   804  			},
   805  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   806  			wantCode:     http.StatusTeapot,
   807  		},
   808  	}
   809  
   810  	for _, tt := range tests {
   811  		t.Run(tt.name, func(t *testing.T) {
   812  			target := &url.URL{
   813  				Scheme: "http",
   814  				Host:   "dummy.tld",
   815  				Path:   "/",
   816  			}
   817  			rproxy := NewSingleHostReverseProxy(target)
   818  			rproxy.Transport = tt.transport
   819  			rproxy.ModifyResponse = tt.modifyResponse
   820  			if rproxy.Transport == nil {
   821  				rproxy.Transport = failingRoundTripper{}
   822  			}
   823  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   824  			if tt.errorHandler != nil {
   825  				rproxy.ErrorHandler = tt.errorHandler
   826  			}
   827  			frontendProxy := httptest.NewServer(rproxy)
   828  			defer frontendProxy.Close()
   829  
   830  			resp, err := http.Get(frontendProxy.URL + "/test")
   831  			if err != nil {
   832  				t.Fatalf("failed to reach proxy: %v", err)
   833  			}
   834  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   835  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   836  			}
   837  			resp.Body.Close()
   838  		})
   839  	}
   840  }
   841  
   842  // Issue 16659: log errors from short read
   843  func TestReverseProxy_CopyBuffer(t *testing.T) {
   844  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   845  		out := "this call was relayed by the reverse proxy"
   846  		// Coerce a wrong content length to induce io.UnexpectedEOF
   847  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   848  		fmt.Fprintln(w, out)
   849  	}))
   850  	defer backendServer.Close()
   851  
   852  	rpURL, err := url.Parse(backendServer.URL)
   853  	if err != nil {
   854  		t.Fatal(err)
   855  	}
   856  
   857  	var proxyLog bytes.Buffer
   858  	rproxy := NewSingleHostReverseProxy(rpURL)
   859  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   860  	donec := make(chan bool, 1)
   861  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   862  		defer func() { donec <- true }()
   863  		rproxy.ServeHTTP(w, r)
   864  	}))
   865  	defer frontendProxy.Close()
   866  
   867  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   868  		t.Fatalf("want non-nil error")
   869  	}
   870  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   871  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   872  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   873  	// continue after Get.
   874  	<-donec
   875  
   876  	expected := []string{
   877  		"EOF",
   878  		"read",
   879  	}
   880  	for _, phrase := range expected {
   881  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   882  			t.Errorf("expected log to contain phrase %q", phrase)
   883  		}
   884  	}
   885  }
   886  
   887  type staticTransport struct {
   888  	res *http.Response
   889  }
   890  
   891  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   892  	return t.res, nil
   893  }
   894  
   895  func BenchmarkServeHTTP(b *testing.B) {
   896  	res := &http.Response{
   897  		StatusCode: 200,
   898  		Body:       io.NopCloser(strings.NewReader("")),
   899  	}
   900  	proxy := &ReverseProxy{
   901  		Director:  func(*http.Request) {},
   902  		Transport: &staticTransport{res},
   903  	}
   904  
   905  	w := httptest.NewRecorder()
   906  	r := httptest.NewRequest("GET", "/", nil)
   907  
   908  	b.ReportAllocs()
   909  	for i := 0; i < b.N; i++ {
   910  		proxy.ServeHTTP(w, r)
   911  	}
   912  }
   913  
   914  func TestServeHTTPDeepCopy(t *testing.T) {
   915  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   916  		w.Write([]byte("Hello Gopher!"))
   917  	}))
   918  	defer backend.Close()
   919  	backendURL, err := url.Parse(backend.URL)
   920  	if err != nil {
   921  		t.Fatal(err)
   922  	}
   923  
   924  	type result struct {
   925  		before, after string
   926  	}
   927  
   928  	resultChan := make(chan result, 1)
   929  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   930  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   931  		before := r.URL.String()
   932  		proxyHandler.ServeHTTP(w, r)
   933  		after := r.URL.String()
   934  		resultChan <- result{before: before, after: after}
   935  	}))
   936  	defer frontend.Close()
   937  
   938  	want := result{before: "/", after: "/"}
   939  
   940  	res, err := frontend.Client().Get(frontend.URL)
   941  	if err != nil {
   942  		t.Fatalf("Do: %v", err)
   943  	}
   944  	res.Body.Close()
   945  
   946  	got := <-resultChan
   947  	if got != want {
   948  		t.Errorf("got = %+v; want = %+v", got, want)
   949  	}
   950  }
   951  
   952  // Issue 18327: verify we always do a deep copy of the Request.Header map
   953  // before any mutations.
   954  func TestClonesRequestHeaders(t *testing.T) {
   955  	log.SetOutput(io.Discard)
   956  	defer log.SetOutput(os.Stderr)
   957  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   958  	req.RemoteAddr = "1.2.3.4:56789"
   959  	rp := &ReverseProxy{
   960  		Director: func(req *http.Request) {
   961  			req.Header.Set("From-Director", "1")
   962  		},
   963  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   964  			if v := req.Header.Get("From-Director"); v != "1" {
   965  				t.Errorf("From-Directory value = %q; want 1", v)
   966  			}
   967  			return nil, io.EOF
   968  		}),
   969  	}
   970  	rp.ServeHTTP(httptest.NewRecorder(), req)
   971  
   972  	if req.Header.Get("From-Director") == "1" {
   973  		t.Error("Director header mutation modified caller's request")
   974  	}
   975  	if req.Header.Get("X-Forwarded-For") != "" {
   976  		t.Error("X-Forward-For header mutation modified caller's request")
   977  	}
   978  
   979  }
   980  
   981  type roundTripperFunc func(req *http.Request) (*http.Response, error)
   982  
   983  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   984  	return fn(req)
   985  }
   986  
   987  func TestModifyResponseClosesBody(t *testing.T) {
   988  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   989  	req.RemoteAddr = "1.2.3.4:56789"
   990  	closeCheck := new(checkCloser)
   991  	logBuf := new(bytes.Buffer)
   992  	outErr := errors.New("ModifyResponse error")
   993  	rp := &ReverseProxy{
   994  		Director: func(req *http.Request) {},
   995  		Transport: &staticTransport{&http.Response{
   996  			StatusCode: 200,
   997  			Body:       closeCheck,
   998  		}},
   999  		ErrorLog: log.New(logBuf, "", 0),
  1000  		ModifyResponse: func(*http.Response) error {
  1001  			return outErr
  1002  		},
  1003  	}
  1004  	rec := httptest.NewRecorder()
  1005  	rp.ServeHTTP(rec, req)
  1006  	res := rec.Result()
  1007  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1008  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1009  	}
  1010  	if !closeCheck.closed {
  1011  		t.Errorf("body should have been closed")
  1012  	}
  1013  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1014  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1015  	}
  1016  }
  1017  
  1018  type checkCloser struct {
  1019  	closed bool
  1020  }
  1021  
  1022  func (cc *checkCloser) Close() error {
  1023  	cc.closed = true
  1024  	return nil
  1025  }
  1026  
  1027  func (cc *checkCloser) Read(b []byte) (int, error) {
  1028  	return len(b), nil
  1029  }
  1030  
  1031  // Issue 23643: panic on body copy error
  1032  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1033  	log.SetOutput(io.Discard)
  1034  	defer log.SetOutput(os.Stderr)
  1035  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1036  		out := "this call was relayed by the reverse proxy"
  1037  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1038  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1039  		fmt.Fprintln(w, out)
  1040  	}))
  1041  	defer backendServer.Close()
  1042  
  1043  	rpURL, err := url.Parse(backendServer.URL)
  1044  	if err != nil {
  1045  		t.Fatal(err)
  1046  	}
  1047  
  1048  	rproxy := NewSingleHostReverseProxy(rpURL)
  1049  
  1050  	// Ensure that the handler panics when the body read encounters an
  1051  	// io.ErrUnexpectedEOF
  1052  	defer func() {
  1053  		err := recover()
  1054  		if err == nil {
  1055  			t.Fatal("handler should have panicked")
  1056  		}
  1057  		if err != http.ErrAbortHandler {
  1058  			t.Fatal("expected ErrAbortHandler, got", err)
  1059  		}
  1060  	}()
  1061  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1062  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1063  }
  1064  
  1065  func TestSelectFlushInterval(t *testing.T) {
  1066  	tests := []struct {
  1067  		name string
  1068  		p    *ReverseProxy
  1069  		res  *http.Response
  1070  		want time.Duration
  1071  	}{
  1072  		{
  1073  			name: "default",
  1074  			res:  &http.Response{},
  1075  			p:    &ReverseProxy{FlushInterval: 123},
  1076  			want: 123,
  1077  		},
  1078  		{
  1079  			name: "server-sent events overrides non-zero",
  1080  			res: &http.Response{
  1081  				Header: http.Header{
  1082  					"Content-Type": {"text/event-stream"},
  1083  				},
  1084  			},
  1085  			p:    &ReverseProxy{FlushInterval: 123},
  1086  			want: -1,
  1087  		},
  1088  		{
  1089  			name: "server-sent events overrides zero",
  1090  			res: &http.Response{
  1091  				Header: http.Header{
  1092  					"Content-Type": {"text/event-stream"},
  1093  				},
  1094  			},
  1095  			p:    &ReverseProxy{FlushInterval: 0},
  1096  			want: -1,
  1097  		},
  1098  		{
  1099  			name: "Content-Length: -1, overrides non-zero",
  1100  			res: &http.Response{
  1101  				ContentLength: -1,
  1102  			},
  1103  			p:    &ReverseProxy{FlushInterval: 123},
  1104  			want: -1,
  1105  		},
  1106  		{
  1107  			name: "Content-Length: -1, overrides zero",
  1108  			res: &http.Response{
  1109  				ContentLength: -1,
  1110  			},
  1111  			p:    &ReverseProxy{FlushInterval: 0},
  1112  			want: -1,
  1113  		},
  1114  	}
  1115  	for _, tt := range tests {
  1116  		t.Run(tt.name, func(t *testing.T) {
  1117  			got := tt.p.flushInterval(tt.res)
  1118  			if got != tt.want {
  1119  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1120  			}
  1121  		})
  1122  	}
  1123  }
  1124  
  1125  func TestReverseProxyWebSocket(t *testing.T) {
  1126  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1127  		if upgradeType(r.Header) != "websocket" {
  1128  			t.Error("unexpected backend request")
  1129  			http.Error(w, "unexpected request", 400)
  1130  			return
  1131  		}
  1132  		c, _, err := w.(http.Hijacker).Hijack()
  1133  		if err != nil {
  1134  			t.Error(err)
  1135  			return
  1136  		}
  1137  		defer c.Close()
  1138  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1139  		bs := bufio.NewScanner(c)
  1140  		if !bs.Scan() {
  1141  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1142  			return
  1143  		}
  1144  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1145  	}))
  1146  	defer backendServer.Close()
  1147  
  1148  	backURL, _ := url.Parse(backendServer.URL)
  1149  	rproxy := NewSingleHostReverseProxy(backURL)
  1150  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1151  	rproxy.ModifyResponse = func(res *http.Response) error {
  1152  		res.Header.Add("X-Modified", "true")
  1153  		return nil
  1154  	}
  1155  
  1156  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1157  		rw.Header().Set("X-Header", "X-Value")
  1158  		rproxy.ServeHTTP(rw, req)
  1159  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1160  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1161  		}
  1162  	})
  1163  
  1164  	frontendProxy := httptest.NewServer(handler)
  1165  	defer frontendProxy.Close()
  1166  
  1167  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1168  	req.Header.Set("Connection", "Upgrade")
  1169  	req.Header.Set("Upgrade", "websocket")
  1170  
  1171  	c := frontendProxy.Client()
  1172  	res, err := c.Do(req)
  1173  	if err != nil {
  1174  		t.Fatal(err)
  1175  	}
  1176  	if res.StatusCode != 101 {
  1177  		t.Fatalf("status = %v; want 101", res.Status)
  1178  	}
  1179  
  1180  	got := res.Header.Get("X-Header")
  1181  	want := "X-Value"
  1182  	if got != want {
  1183  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1184  	}
  1185  
  1186  	if upgradeType(res.Header) != "websocket" {
  1187  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1188  	}
  1189  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1190  	if !ok {
  1191  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1192  	}
  1193  	defer rwc.Close()
  1194  
  1195  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1196  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1197  	}
  1198  
  1199  	io.WriteString(rwc, "Hello\n")
  1200  	bs := bufio.NewScanner(rwc)
  1201  	if !bs.Scan() {
  1202  		t.Fatalf("Scan: %v", bs.Err())
  1203  	}
  1204  	got = bs.Text()
  1205  	want = `backend got "Hello"`
  1206  	if got != want {
  1207  		t.Errorf("got %#q, want %#q", got, want)
  1208  	}
  1209  }
  1210  
  1211  func TestReverseProxyWebSocketCancelation(t *testing.T) {
  1212  	n := 5
  1213  	triggerCancelCh := make(chan bool, n)
  1214  	nthResponse := func(i int) string {
  1215  		return fmt.Sprintf("backend response #%d\n", i)
  1216  	}
  1217  	terminalMsg := "final message"
  1218  
  1219  	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1220  		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
  1221  			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
  1222  			http.Error(w, "Unexpected request", 400)
  1223  			return
  1224  		}
  1225  		conn, bufrw, err := w.(http.Hijacker).Hijack()
  1226  		if err != nil {
  1227  			t.Error(err)
  1228  			return
  1229  		}
  1230  		defer conn.Close()
  1231  
  1232  		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
  1233  		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
  1234  			t.Error(err)
  1235  			return
  1236  		}
  1237  		if _, _, err := bufrw.ReadLine(); err != nil {
  1238  			t.Errorf("Failed to read line from client: %v", err)
  1239  			return
  1240  		}
  1241  
  1242  		for i := 0; i < n; i++ {
  1243  			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
  1244  				select {
  1245  				case <-triggerCancelCh:
  1246  				default:
  1247  					t.Errorf("Writing response #%d failed: %v", i, err)
  1248  				}
  1249  				return
  1250  			}
  1251  			bufrw.Flush()
  1252  			time.Sleep(time.Second)
  1253  		}
  1254  		if _, err := bufrw.WriteString(terminalMsg); err != nil {
  1255  			select {
  1256  			case <-triggerCancelCh:
  1257  			default:
  1258  				t.Errorf("Failed to write terminal message: %v", err)
  1259  			}
  1260  		}
  1261  		bufrw.Flush()
  1262  	}))
  1263  	defer cst.Close()
  1264  
  1265  	backendURL, _ := url.Parse(cst.URL)
  1266  	rproxy := NewSingleHostReverseProxy(backendURL)
  1267  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1268  	rproxy.ModifyResponse = func(res *http.Response) error {
  1269  		res.Header.Add("X-Modified", "true")
  1270  		return nil
  1271  	}
  1272  
  1273  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1274  		rw.Header().Set("X-Header", "X-Value")
  1275  		ctx, cancel := context.WithCancel(req.Context())
  1276  		go func() {
  1277  			<-triggerCancelCh
  1278  			cancel()
  1279  		}()
  1280  		rproxy.ServeHTTP(rw, req.WithContext(ctx))
  1281  	})
  1282  
  1283  	frontendProxy := httptest.NewServer(handler)
  1284  	defer frontendProxy.Close()
  1285  
  1286  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1287  	req.Header.Set("Connection", "Upgrade")
  1288  	req.Header.Set("Upgrade", "websocket")
  1289  
  1290  	res, err := frontendProxy.Client().Do(req)
  1291  	if err != nil {
  1292  		t.Fatalf("Dialing to frontend proxy: %v", err)
  1293  	}
  1294  	defer res.Body.Close()
  1295  	if g, w := res.StatusCode, 101; g != w {
  1296  		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
  1297  	}
  1298  
  1299  	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
  1300  		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1301  	}
  1302  
  1303  	if g, w := upgradeType(res.Header), "websocket"; g != w {
  1304  		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1305  	}
  1306  
  1307  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1308  	if !ok {
  1309  		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
  1310  	}
  1311  
  1312  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1313  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1314  	}
  1315  
  1316  	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
  1317  		t.Fatalf("Failed to write first message: %v", err)
  1318  	}
  1319  
  1320  	// Read loop.
  1321  
  1322  	br := bufio.NewReader(rwc)
  1323  	for {
  1324  		line, err := br.ReadString('\n')
  1325  		switch {
  1326  		case line == terminalMsg: // this case before "err == io.EOF"
  1327  			t.Fatalf("The websocket request was not canceled, unfortunately!")
  1328  
  1329  		case err == io.EOF:
  1330  			return
  1331  
  1332  		case err != nil:
  1333  			t.Fatalf("Unexpected error: %v", err)
  1334  
  1335  		case line == nthResponse(0): // We've gotten the first response back
  1336  			// Let's trigger a cancel.
  1337  			close(triggerCancelCh)
  1338  		}
  1339  	}
  1340  }
  1341  
  1342  func TestUnannouncedTrailer(t *testing.T) {
  1343  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1344  		w.WriteHeader(http.StatusOK)
  1345  		w.(http.Flusher).Flush()
  1346  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1347  	}))
  1348  	defer backend.Close()
  1349  	backendURL, err := url.Parse(backend.URL)
  1350  	if err != nil {
  1351  		t.Fatal(err)
  1352  	}
  1353  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1354  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1355  	frontend := httptest.NewServer(proxyHandler)
  1356  	defer frontend.Close()
  1357  	frontendClient := frontend.Client()
  1358  
  1359  	res, err := frontendClient.Get(frontend.URL)
  1360  	if err != nil {
  1361  		t.Fatalf("Get: %v", err)
  1362  	}
  1363  
  1364  	io.ReadAll(res.Body)
  1365  
  1366  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1367  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1368  	}
  1369  
  1370  }
  1371  
  1372  func TestSingleJoinSlash(t *testing.T) {
  1373  	tests := []struct {
  1374  		slasha   string
  1375  		slashb   string
  1376  		expected string
  1377  	}{
  1378  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1379  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1380  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1381  		{"https://www.google.com", "", "https://www.google.com/"},
  1382  		{"", "favicon.ico", "/favicon.ico"},
  1383  	}
  1384  	for _, tt := range tests {
  1385  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1386  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1387  				tt.slasha,
  1388  				tt.slashb,
  1389  				tt.expected,
  1390  				got)
  1391  		}
  1392  	}
  1393  }
  1394  
  1395  func TestJoinURLPath(t *testing.T) {
  1396  	tests := []struct {
  1397  		a        *url.URL
  1398  		b        *url.URL
  1399  		wantPath string
  1400  		wantRaw  string
  1401  	}{
  1402  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1403  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1404  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1405  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1406  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1407  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1408  	}
  1409  
  1410  	for _, tt := range tests {
  1411  		p, rp := joinURLPath(tt.a, tt.b)
  1412  		if p != tt.wantPath || rp != tt.wantRaw {
  1413  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1414  				tt.a.Path, tt.a.RawPath,
  1415  				tt.b.Path, tt.b.RawPath,
  1416  				tt.wantPath, tt.wantRaw,
  1417  				p, rp)
  1418  		}
  1419  	}
  1420  }