github.com/useflyent/fhttp@v0.0.0-20211004035111-333f430cfbbf/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/url"
    18  	"os"
    19  	"reflect"
    20  	"sort"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	http "github.com/useflyent/fhttp"
    28  	"github.com/useflyent/fhttp/httptest"
    29  )
    30  
    31  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    32  
    33  func init() {
    34  	inOurTests = true
    35  	hopHeaders = append(hopHeaders, fakeHopHeader)
    36  }
    37  
    38  func TestReverseProxy(t *testing.T) {
    39  	const backendResponse = "I am the backend"
    40  	const backendStatus = 404
    41  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    42  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    43  			c, _, _ := w.(http.Hijacker).Hijack()
    44  			c.Close()
    45  			return
    46  		}
    47  		if len(r.TransferEncoding) > 0 {
    48  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    49  		}
    50  		if r.Header.Get("X-Forwarded-For") == "" {
    51  			t.Errorf("didn't get X-Forwarded-For header")
    52  		}
    53  		if c := r.Header.Get("Connection"); c != "" {
    54  			t.Errorf("handler got Connection header value %q", c)
    55  		}
    56  		if c := r.Header.Get("Te"); c != "trailers" {
    57  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    58  		}
    59  		if c := r.Header.Get("Upgrade"); c != "" {
    60  			t.Errorf("handler got Upgrade header value %q", c)
    61  		}
    62  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    63  			t.Errorf("handler got Proxy-Connection header value %q", c)
    64  		}
    65  		if g, e := r.Host, "some-name"; g != e {
    66  			t.Errorf("backend got Host header %q, want %q", g, e)
    67  		}
    68  		w.Header().Set("Trailers", "not a special header field name")
    69  		w.Header().Set("Trailer", "X-Trailer")
    70  		w.Header().Set("X-Foo", "bar")
    71  		w.Header().Set("Upgrade", "foo")
    72  		w.Header().Set(fakeHopHeader, "foo")
    73  		w.Header().Add("X-Multi-Value", "foo")
    74  		w.Header().Add("X-Multi-Value", "bar")
    75  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    76  		w.WriteHeader(backendStatus)
    77  		w.Write([]byte(backendResponse))
    78  		w.Header().Set("X-Trailer", "trailer_value")
    79  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    80  	}))
    81  	defer backend.Close()
    82  	backendURL, err := url.Parse(backend.URL)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    87  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    88  	frontend := httptest.NewServer(proxyHandler)
    89  	defer frontend.Close()
    90  	frontendClient := frontend.Client()
    91  
    92  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    93  	getReq.Host = "some-name"
    94  	getReq.Header.Set("Connection", "close")
    95  	getReq.Header.Set("Te", "trailers")
    96  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    97  	getReq.Header.Set("Upgrade", "foo")
    98  	getReq.Close = true
    99  	res, err := frontendClient.Do(getReq)
   100  	if err != nil {
   101  		t.Fatalf("Get: %v", err)
   102  	}
   103  	if g, e := res.StatusCode, backendStatus; g != e {
   104  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   105  	}
   106  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   107  		t.Errorf("got X-Foo %q; expected %q", g, e)
   108  	}
   109  	if c := res.Header.Get(fakeHopHeader); c != "" {
   110  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   111  	}
   112  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   113  		t.Errorf("header Trailers = %q; want %q", g, e)
   114  	}
   115  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   116  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   117  	}
   118  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   119  		t.Fatalf("got %d SetCookies, want %d", g, e)
   120  	}
   121  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   122  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   123  	}
   124  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   125  		t.Errorf("unexpected cookie %q", cookie.Name)
   126  	}
   127  	bodyBytes, _ := io.ReadAll(res.Body)
   128  	if g, e := string(bodyBytes), backendResponse; g != e {
   129  		t.Errorf("got body %q; expected %q", g, e)
   130  	}
   131  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   132  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   133  	}
   134  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   135  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   136  	}
   137  
   138  	// Test that a backend failing to be reached or one which doesn't return
   139  	// a response results in a StatusBadGateway.
   140  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   141  	getReq.Close = true
   142  	res, err = frontendClient.Do(getReq)
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	res.Body.Close()
   147  	if res.StatusCode != http.StatusBadGateway {
   148  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   149  	}
   150  
   151  }
   152  
   153  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   154  // header value.
   155  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   156  	const fakeConnectionToken = "X-Fake-Connection-Token"
   157  	const backendResponse = "I am the backend"
   158  
   159  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   160  	// in the Request's Connection header.
   161  	const someConnHeader = "X-Some-Conn-Header"
   162  
   163  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   164  		if c := r.Header.Get("Connection"); c != "" {
   165  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   166  		}
   167  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   168  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   169  		}
   170  		if c := r.Header.Get(someConnHeader); c != "" {
   171  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   172  		}
   173  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   174  		w.Header().Add("Connection", someConnHeader)
   175  		w.Header().Set(someConnHeader, "should be deleted")
   176  		w.Header().Set(fakeConnectionToken, "should be deleted")
   177  		io.WriteString(w, backendResponse)
   178  	}))
   179  	defer backend.Close()
   180  	backendURL, err := url.Parse(backend.URL)
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   185  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  		proxyHandler.ServeHTTP(w, r)
   187  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   188  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   189  		}
   190  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   191  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   192  		}
   193  		c := r.Header["Connection"]
   194  		var cf []string
   195  		for _, f := range c {
   196  			for _, sf := range strings.Split(f, ",") {
   197  				if sf = strings.TrimSpace(sf); sf != "" {
   198  					cf = append(cf, sf)
   199  				}
   200  			}
   201  		}
   202  		sort.Strings(cf)
   203  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   204  		sort.Strings(expectedValues)
   205  		if !reflect.DeepEqual(cf, expectedValues) {
   206  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   207  		}
   208  	}))
   209  	defer frontend.Close()
   210  
   211  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   212  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   213  	getReq.Header.Add("Connection", someConnHeader)
   214  	getReq.Header.Set(someConnHeader, "should be deleted")
   215  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   216  	res, err := frontend.Client().Do(getReq)
   217  	if err != nil {
   218  		t.Fatalf("Get: %v", err)
   219  	}
   220  	defer res.Body.Close()
   221  	bodyBytes, err := io.ReadAll(res.Body)
   222  	if err != nil {
   223  		t.Fatalf("reading body: %v", err)
   224  	}
   225  	if got, want := string(bodyBytes), backendResponse; got != want {
   226  		t.Errorf("got body %q; want %q", got, want)
   227  	}
   228  	if c := res.Header.Get("Connection"); c != "" {
   229  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   230  	}
   231  	if c := res.Header.Get(someConnHeader); c != "" {
   232  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   233  	}
   234  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   235  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   236  	}
   237  }
   238  
   239  func TestXForwardedFor(t *testing.T) {
   240  	const prevForwardedFor = "client ip"
   241  	const backendResponse = "I am the backend"
   242  	const backendStatus = 404
   243  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   244  		if r.Header.Get("X-Forwarded-For") == "" {
   245  			t.Errorf("didn't get X-Forwarded-For header")
   246  		}
   247  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   248  			t.Errorf("X-Forwarded-For didn't contain prior data")
   249  		}
   250  		w.WriteHeader(backendStatus)
   251  		w.Write([]byte(backendResponse))
   252  	}))
   253  	defer backend.Close()
   254  	backendURL, err := url.Parse(backend.URL)
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   259  	frontend := httptest.NewServer(proxyHandler)
   260  	defer frontend.Close()
   261  
   262  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   263  	getReq.Host = "some-name"
   264  	getReq.Header.Set("Connection", "close")
   265  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   266  	getReq.Close = true
   267  	res, err := frontend.Client().Do(getReq)
   268  	if err != nil {
   269  		t.Fatalf("Get: %v", err)
   270  	}
   271  	if g, e := res.StatusCode, backendStatus; g != e {
   272  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   273  	}
   274  	bodyBytes, _ := io.ReadAll(res.Body)
   275  	if g, e := string(bodyBytes), backendResponse; g != e {
   276  		t.Errorf("got body %q; expected %q", g, e)
   277  	}
   278  }
   279  
   280  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   281  func TestXForwardedFor_Omit(t *testing.T) {
   282  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   283  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   284  			t.Errorf("got X-Forwarded-For header: %q", v)
   285  		}
   286  		w.Write([]byte("hi"))
   287  	}))
   288  	defer backend.Close()
   289  	backendURL, err := url.Parse(backend.URL)
   290  	if err != nil {
   291  		t.Fatal(err)
   292  	}
   293  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   294  	frontend := httptest.NewServer(proxyHandler)
   295  	defer frontend.Close()
   296  
   297  	oldDirector := proxyHandler.Director
   298  	proxyHandler.Director = func(r *http.Request) {
   299  		r.Header["X-Forwarded-For"] = nil
   300  		oldDirector(r)
   301  	}
   302  
   303  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   304  	getReq.Host = "some-name"
   305  	getReq.Close = true
   306  	res, err := frontend.Client().Do(getReq)
   307  	if err != nil {
   308  		t.Fatalf("Get: %v", err)
   309  	}
   310  	res.Body.Close()
   311  }
   312  
   313  var proxyQueryTests = []struct {
   314  	baseSuffix string // suffix to add to backend URL
   315  	reqSuffix  string // suffix to add to frontend's request URL
   316  	want       string // what backend should see for final request URL (without ?)
   317  }{
   318  	{"", "", ""},
   319  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   320  	{"", "?us=er", "us=er"},
   321  	{"?sta=tic", "", "sta=tic"},
   322  }
   323  
   324  func TestReverseProxyQuery(t *testing.T) {
   325  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   326  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   327  		w.Write([]byte("hi"))
   328  	}))
   329  	defer backend.Close()
   330  
   331  	for i, tt := range proxyQueryTests {
   332  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   333  		if err != nil {
   334  			t.Fatal(err)
   335  		}
   336  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   337  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   338  		req.Close = true
   339  		res, err := frontend.Client().Do(req)
   340  		if err != nil {
   341  			t.Fatalf("%d. Get: %v", i, err)
   342  		}
   343  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   344  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   345  		}
   346  		res.Body.Close()
   347  		frontend.Close()
   348  	}
   349  }
   350  
   351  func TestReverseProxyFlushInterval(t *testing.T) {
   352  	const expected = "hi"
   353  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   354  		w.Write([]byte(expected))
   355  	}))
   356  	defer backend.Close()
   357  
   358  	backendURL, err := url.Parse(backend.URL)
   359  	if err != nil {
   360  		t.Fatal(err)
   361  	}
   362  
   363  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   364  	proxyHandler.FlushInterval = time.Microsecond
   365  
   366  	frontend := httptest.NewServer(proxyHandler)
   367  	defer frontend.Close()
   368  
   369  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   370  	req.Close = true
   371  	res, err := frontend.Client().Do(req)
   372  	if err != nil {
   373  		t.Fatalf("Get: %v", err)
   374  	}
   375  	defer res.Body.Close()
   376  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   377  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   378  	}
   379  }
   380  
   381  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   382  	const expected = "hi"
   383  	stopCh := make(chan struct{})
   384  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   385  		w.Header().Add("MyHeader", expected)
   386  		w.WriteHeader(200)
   387  		w.(http.Flusher).Flush()
   388  		<-stopCh
   389  	}))
   390  	defer backend.Close()
   391  	defer close(stopCh)
   392  
   393  	backendURL, err := url.Parse(backend.URL)
   394  	if err != nil {
   395  		t.Fatal(err)
   396  	}
   397  
   398  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   399  	proxyHandler.FlushInterval = time.Microsecond
   400  
   401  	frontend := httptest.NewServer(proxyHandler)
   402  	defer frontend.Close()
   403  
   404  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   405  	req.Close = true
   406  
   407  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   408  	defer cancel()
   409  	req = req.WithContext(ctx)
   410  
   411  	res, err := frontend.Client().Do(req)
   412  	if err != nil {
   413  		t.Fatalf("Get: %v", err)
   414  	}
   415  	defer res.Body.Close()
   416  
   417  	if res.Header.Get("MyHeader") != expected {
   418  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   419  	}
   420  }
   421  
   422  func TestReverseProxyCancellation(t *testing.T) {
   423  	const backendResponse = "I am the backend"
   424  
   425  	reqInFlight := make(chan struct{})
   426  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   427  		close(reqInFlight) // cause the client to cancel its request
   428  
   429  		select {
   430  		case <-time.After(10 * time.Second):
   431  			// Note: this should only happen in broken implementations, and the
   432  			// closenotify case should be instantaneous.
   433  			t.Error("Handler never saw CloseNotify")
   434  			return
   435  		case <-w.(http.CloseNotifier).CloseNotify():
   436  		}
   437  
   438  		w.WriteHeader(http.StatusOK)
   439  		w.Write([]byte(backendResponse))
   440  	}))
   441  
   442  	defer backend.Close()
   443  
   444  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   445  
   446  	backendURL, err := url.Parse(backend.URL)
   447  	if err != nil {
   448  		t.Fatal(err)
   449  	}
   450  
   451  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   452  
   453  	// Discards errors of the form:
   454  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   455  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   456  
   457  	frontend := httptest.NewServer(proxyHandler)
   458  	defer frontend.Close()
   459  	frontendClient := frontend.Client()
   460  
   461  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   462  	go func() {
   463  		<-reqInFlight
   464  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   465  	}()
   466  	res, err := frontendClient.Do(getReq)
   467  	if res != nil {
   468  		t.Errorf("got response %v; want nil", res.Status)
   469  	}
   470  	if err == nil {
   471  		// This should be an error like:
   472  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   473  		//    use of closed network connection
   474  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   475  	}
   476  }
   477  
   478  func req(t *testing.T, v string) *http.Request {
   479  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   480  	if err != nil {
   481  		t.Fatal(err)
   482  	}
   483  	return req
   484  }
   485  
   486  // Issue 12344
   487  func TestNilBody(t *testing.T) {
   488  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   489  		w.Write([]byte("hi"))
   490  	}))
   491  	defer backend.Close()
   492  
   493  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   494  		backURL, _ := url.Parse(backend.URL)
   495  		rp := NewSingleHostReverseProxy(backURL)
   496  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   497  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   498  		rp.ServeHTTP(w, r)
   499  	}))
   500  	defer frontend.Close()
   501  
   502  	res, err := http.Get(frontend.URL)
   503  	if err != nil {
   504  		t.Fatal(err)
   505  	}
   506  	defer res.Body.Close()
   507  	slurp, err := io.ReadAll(res.Body)
   508  	if err != nil {
   509  		t.Fatal(err)
   510  	}
   511  	if string(slurp) != "hi" {
   512  		t.Errorf("Got %q; want %q", slurp, "hi")
   513  	}
   514  }
   515  
   516  // Issue 15524
   517  func TestUserAgentHeader(t *testing.T) {
   518  	const explicitUA = "explicit UA"
   519  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   520  		if r.URL.Path == "/noua" {
   521  			if c := r.Header.Get("User-Agent"); c != "" {
   522  				t.Errorf("handler got non-empty User-Agent header %q", c)
   523  			}
   524  			return
   525  		}
   526  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   527  			t.Errorf("handler got unexpected User-Agent header %q", c)
   528  		}
   529  	}))
   530  	defer backend.Close()
   531  	backendURL, err := url.Parse(backend.URL)
   532  	if err != nil {
   533  		t.Fatal(err)
   534  	}
   535  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   536  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   537  	frontend := httptest.NewServer(proxyHandler)
   538  	defer frontend.Close()
   539  	frontendClient := frontend.Client()
   540  
   541  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   542  	getReq.Header.Set("User-Agent", explicitUA)
   543  	getReq.Close = true
   544  	res, err := frontendClient.Do(getReq)
   545  	if err != nil {
   546  		t.Fatalf("Get: %v", err)
   547  	}
   548  	res.Body.Close()
   549  
   550  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   551  	getReq.Header.Set("User-Agent", "")
   552  	getReq.Close = true
   553  	res, err = frontendClient.Do(getReq)
   554  	if err != nil {
   555  		t.Fatalf("Get: %v", err)
   556  	}
   557  	res.Body.Close()
   558  }
   559  
   560  type bufferPool struct {
   561  	get func() []byte
   562  	put func([]byte)
   563  }
   564  
   565  func (bp bufferPool) Get() []byte  { return bp.get() }
   566  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   567  
   568  func TestReverseProxyGetPutBuffer(t *testing.T) {
   569  	const msg = "hi"
   570  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   571  		io.WriteString(w, msg)
   572  	}))
   573  	defer backend.Close()
   574  
   575  	backendURL, err := url.Parse(backend.URL)
   576  	if err != nil {
   577  		t.Fatal(err)
   578  	}
   579  
   580  	var (
   581  		mu  sync.Mutex
   582  		log []string
   583  	)
   584  	addLog := func(event string) {
   585  		mu.Lock()
   586  		defer mu.Unlock()
   587  		log = append(log, event)
   588  	}
   589  	rp := NewSingleHostReverseProxy(backendURL)
   590  	const size = 1234
   591  	rp.BufferPool = bufferPool{
   592  		get: func() []byte {
   593  			addLog("getBuf")
   594  			return make([]byte, size)
   595  		},
   596  		put: func(p []byte) {
   597  			addLog("putBuf-" + strconv.Itoa(len(p)))
   598  		},
   599  	}
   600  	frontend := httptest.NewServer(rp)
   601  	defer frontend.Close()
   602  
   603  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   604  	req.Close = true
   605  	res, err := frontend.Client().Do(req)
   606  	if err != nil {
   607  		t.Fatalf("Get: %v", err)
   608  	}
   609  	slurp, err := io.ReadAll(res.Body)
   610  	res.Body.Close()
   611  	if err != nil {
   612  		t.Fatalf("reading body: %v", err)
   613  	}
   614  	if string(slurp) != msg {
   615  		t.Errorf("msg = %q; want %q", slurp, msg)
   616  	}
   617  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   618  	mu.Lock()
   619  	defer mu.Unlock()
   620  	if !reflect.DeepEqual(log, wantLog) {
   621  		t.Errorf("Log events = %q; want %q", log, wantLog)
   622  	}
   623  }
   624  
   625  func TestReverseProxy_Post(t *testing.T) {
   626  	const backendResponse = "I am the backend"
   627  	const backendStatus = 200
   628  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   629  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   630  		slurp, err := io.ReadAll(r.Body)
   631  		if err != nil {
   632  			t.Errorf("Backend body read = %v", err)
   633  		}
   634  		if len(slurp) != len(requestBody) {
   635  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   636  		}
   637  		if !bytes.Equal(slurp, requestBody) {
   638  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   639  		}
   640  		w.Write([]byte(backendResponse))
   641  	}))
   642  	defer backend.Close()
   643  	backendURL, err := url.Parse(backend.URL)
   644  	if err != nil {
   645  		t.Fatal(err)
   646  	}
   647  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   648  	frontend := httptest.NewServer(proxyHandler)
   649  	defer frontend.Close()
   650  
   651  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   652  	res, err := frontend.Client().Do(postReq)
   653  	if err != nil {
   654  		t.Fatalf("Do: %v", err)
   655  	}
   656  	if g, e := res.StatusCode, backendStatus; g != e {
   657  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   658  	}
   659  	bodyBytes, _ := io.ReadAll(res.Body)
   660  	if g, e := string(bodyBytes), backendResponse; g != e {
   661  		t.Errorf("got body %q; expected %q", g, e)
   662  	}
   663  }
   664  
   665  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   666  
   667  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   668  	return fn(req)
   669  }
   670  
   671  // Issue 16036: send a Request with a nil Body when possible
   672  func TestReverseProxy_NilBody(t *testing.T) {
   673  	backendURL, _ := url.Parse("http://fake.tld/")
   674  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   675  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   676  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   677  		if req.Body != nil {
   678  			t.Error("Body != nil; want a nil Body")
   679  		}
   680  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   681  	})
   682  	frontend := httptest.NewServer(proxyHandler)
   683  	defer frontend.Close()
   684  
   685  	res, err := frontend.Client().Get(frontend.URL)
   686  	if err != nil {
   687  		t.Fatal(err)
   688  	}
   689  	defer res.Body.Close()
   690  	if res.StatusCode != 502 {
   691  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   692  	}
   693  }
   694  
   695  // Issue 33142: always allocate the request headers
   696  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   697  	proxyHandler := new(ReverseProxy)
   698  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   699  	proxyHandler.Director = func(*http.Request) {}     // noop
   700  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   701  		if req.Header == nil {
   702  			t.Error("Header == nil; want a non-nil Header")
   703  		}
   704  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   705  	})
   706  
   707  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   708  		Method:     "GET",
   709  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   710  		Proto:      "HTTP/1.0",
   711  		ProtoMajor: 1,
   712  	})
   713  }
   714  
   715  // Issue 14237. Test ModifyResponse and that an error from it
   716  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   717  func TestReverseProxyModifyResponse(t *testing.T) {
   718  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   719  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   720  	}))
   721  	defer backendServer.Close()
   722  
   723  	rpURL, _ := url.Parse(backendServer.URL)
   724  	rproxy := NewSingleHostReverseProxy(rpURL)
   725  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   726  	rproxy.ModifyResponse = func(resp *http.Response) error {
   727  		if resp.Header.Get("X-Hit-Mod") != "true" {
   728  			return fmt.Errorf("tried to by-pass proxy")
   729  		}
   730  		return nil
   731  	}
   732  
   733  	frontendProxy := httptest.NewServer(rproxy)
   734  	defer frontendProxy.Close()
   735  
   736  	tests := []struct {
   737  		url      string
   738  		wantCode int
   739  	}{
   740  		{frontendProxy.URL + "/mod", http.StatusOK},
   741  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   742  	}
   743  
   744  	for i, tt := range tests {
   745  		resp, err := http.Get(tt.url)
   746  		if err != nil {
   747  			t.Fatalf("failed to reach proxy: %v", err)
   748  		}
   749  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   750  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   751  		}
   752  		resp.Body.Close()
   753  	}
   754  }
   755  
   756  type failingRoundTripper struct{}
   757  
   758  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   759  	return nil, errors.New("some error")
   760  }
   761  
   762  type staticResponseRoundTripper struct{ res *http.Response }
   763  
   764  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   765  	return rt.res, nil
   766  }
   767  
   768  func TestReverseProxyErrorHandler(t *testing.T) {
   769  	tests := []struct {
   770  		name           string
   771  		wantCode       int
   772  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   773  		transport      http.RoundTripper // defaults to failingRoundTripper
   774  		modifyResponse func(*http.Response) error
   775  	}{
   776  		{
   777  			name:     "default",
   778  			wantCode: http.StatusBadGateway,
   779  		},
   780  		{
   781  			name:         "errorhandler",
   782  			wantCode:     http.StatusTeapot,
   783  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   784  		},
   785  		{
   786  			name: "modifyresponse_noerr",
   787  			transport: staticResponseRoundTripper{
   788  				&http.Response{StatusCode: 345, Body: http.NoBody},
   789  			},
   790  			modifyResponse: func(res *http.Response) error {
   791  				res.StatusCode++
   792  				return nil
   793  			},
   794  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   795  			wantCode:     346,
   796  		},
   797  		{
   798  			name: "modifyresponse_err",
   799  			transport: staticResponseRoundTripper{
   800  				&http.Response{StatusCode: 345, Body: http.NoBody},
   801  			},
   802  			modifyResponse: func(res *http.Response) error {
   803  				res.StatusCode++
   804  				return errors.New("some error to trigger errorHandler")
   805  			},
   806  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   807  			wantCode:     http.StatusTeapot,
   808  		},
   809  	}
   810  
   811  	for _, tt := range tests {
   812  		t.Run(tt.name, func(t *testing.T) {
   813  			target := &url.URL{
   814  				Scheme: "http",
   815  				Host:   "dummy.tld",
   816  				Path:   "/",
   817  			}
   818  			rproxy := NewSingleHostReverseProxy(target)
   819  			rproxy.Transport = tt.transport
   820  			rproxy.ModifyResponse = tt.modifyResponse
   821  			if rproxy.Transport == nil {
   822  				rproxy.Transport = failingRoundTripper{}
   823  			}
   824  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   825  			if tt.errorHandler != nil {
   826  				rproxy.ErrorHandler = tt.errorHandler
   827  			}
   828  			frontendProxy := httptest.NewServer(rproxy)
   829  			defer frontendProxy.Close()
   830  
   831  			resp, err := http.Get(frontendProxy.URL + "/test")
   832  			if err != nil {
   833  				t.Fatalf("failed to reach proxy: %v", err)
   834  			}
   835  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   836  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   837  			}
   838  			resp.Body.Close()
   839  		})
   840  	}
   841  }
   842  
   843  // Issue 16659: log errors from short read
   844  func TestReverseProxy_CopyBuffer(t *testing.T) {
   845  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   846  		out := "this call was relayed by the reverse proxy"
   847  		// Coerce a wrong content length to induce io.UnexpectedEOF
   848  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   849  		fmt.Fprintln(w, out)
   850  	}))
   851  	defer backendServer.Close()
   852  
   853  	rpURL, err := url.Parse(backendServer.URL)
   854  	if err != nil {
   855  		t.Fatal(err)
   856  	}
   857  
   858  	var proxyLog bytes.Buffer
   859  	rproxy := NewSingleHostReverseProxy(rpURL)
   860  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   861  	donec := make(chan bool, 1)
   862  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   863  		defer func() { donec <- true }()
   864  		rproxy.ServeHTTP(w, r)
   865  	}))
   866  	defer frontendProxy.Close()
   867  
   868  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   869  		t.Fatalf("want non-nil error")
   870  	}
   871  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   872  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   873  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   874  	// continue after Get.
   875  	<-donec
   876  
   877  	expected := []string{
   878  		"EOF",
   879  		"read",
   880  	}
   881  	for _, phrase := range expected {
   882  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   883  			t.Errorf("expected log to contain phrase %q", phrase)
   884  		}
   885  	}
   886  }
   887  
   888  type staticTransport struct {
   889  	res *http.Response
   890  }
   891  
   892  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   893  	return t.res, nil
   894  }
   895  
   896  func BenchmarkServeHTTP(b *testing.B) {
   897  	res := &http.Response{
   898  		StatusCode: 200,
   899  		Body:       io.NopCloser(strings.NewReader("")),
   900  	}
   901  	proxy := &ReverseProxy{
   902  		Director:  func(*http.Request) {},
   903  		Transport: &staticTransport{res},
   904  	}
   905  
   906  	w := httptest.NewRecorder()
   907  	r := httptest.NewRequest("GET", "/", nil)
   908  
   909  	b.ReportAllocs()
   910  	for i := 0; i < b.N; i++ {
   911  		proxy.ServeHTTP(w, r)
   912  	}
   913  }
   914  
   915  func TestServeHTTPDeepCopy(t *testing.T) {
   916  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   917  		w.Write([]byte("Hello Gopher!"))
   918  	}))
   919  	defer backend.Close()
   920  	backendURL, err := url.Parse(backend.URL)
   921  	if err != nil {
   922  		t.Fatal(err)
   923  	}
   924  
   925  	type result struct {
   926  		before, after string
   927  	}
   928  
   929  	resultChan := make(chan result, 1)
   930  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   931  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   932  		before := r.URL.String()
   933  		proxyHandler.ServeHTTP(w, r)
   934  		after := r.URL.String()
   935  		resultChan <- result{before: before, after: after}
   936  	}))
   937  	defer frontend.Close()
   938  
   939  	want := result{before: "/", after: "/"}
   940  
   941  	res, err := frontend.Client().Get(frontend.URL)
   942  	if err != nil {
   943  		t.Fatalf("Do: %v", err)
   944  	}
   945  	res.Body.Close()
   946  
   947  	got := <-resultChan
   948  	if got != want {
   949  		t.Errorf("got = %+v; want = %+v", got, want)
   950  	}
   951  }
   952  
   953  // Issue 18327: verify we always do a deep copy of the Request.Header map
   954  // before any mutations.
   955  func TestClonesRequestHeaders(t *testing.T) {
   956  	log.SetOutput(io.Discard)
   957  	defer log.SetOutput(os.Stderr)
   958  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   959  	req.RemoteAddr = "1.2.3.4:56789"
   960  	rp := &ReverseProxy{
   961  		Director: func(req *http.Request) {
   962  			req.Header.Set("From-Director", "1")
   963  		},
   964  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   965  			if v := req.Header.Get("From-Director"); v != "1" {
   966  				t.Errorf("From-Directory value = %q; want 1", v)
   967  			}
   968  			return nil, io.EOF
   969  		}),
   970  	}
   971  	rp.ServeHTTP(httptest.NewRecorder(), req)
   972  
   973  	if req.Header.Get("From-Director") == "1" {
   974  		t.Error("Director header mutation modified caller's request")
   975  	}
   976  	if req.Header.Get("X-Forwarded-For") != "" {
   977  		t.Error("X-Forward-For header mutation modified caller's request")
   978  	}
   979  
   980  }
   981  
   982  type roundTripperFunc func(req *http.Request) (*http.Response, error)
   983  
   984  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   985  	return fn(req)
   986  }
   987  
   988  func TestModifyResponseClosesBody(t *testing.T) {
   989  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   990  	req.RemoteAddr = "1.2.3.4:56789"
   991  	closeCheck := new(checkCloser)
   992  	logBuf := new(bytes.Buffer)
   993  	outErr := errors.New("ModifyResponse error")
   994  	rp := &ReverseProxy{
   995  		Director: func(req *http.Request) {},
   996  		Transport: &staticTransport{&http.Response{
   997  			StatusCode: 200,
   998  			Body:       closeCheck,
   999  		}},
  1000  		ErrorLog: log.New(logBuf, "", 0),
  1001  		ModifyResponse: func(*http.Response) error {
  1002  			return outErr
  1003  		},
  1004  	}
  1005  	rec := httptest.NewRecorder()
  1006  	rp.ServeHTTP(rec, req)
  1007  	res := rec.Result()
  1008  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1009  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1010  	}
  1011  	if !closeCheck.closed {
  1012  		t.Errorf("body should have been closed")
  1013  	}
  1014  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1015  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1016  	}
  1017  }
  1018  
  1019  type checkCloser struct {
  1020  	closed bool
  1021  }
  1022  
  1023  func (cc *checkCloser) Close() error {
  1024  	cc.closed = true
  1025  	return nil
  1026  }
  1027  
  1028  func (cc *checkCloser) Read(b []byte) (int, error) {
  1029  	return len(b), nil
  1030  }
  1031  
  1032  // Issue 23643: panic on body copy error
  1033  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1034  	log.SetOutput(io.Discard)
  1035  	defer log.SetOutput(os.Stderr)
  1036  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1037  		out := "this call was relayed by the reverse proxy"
  1038  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1039  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1040  		fmt.Fprintln(w, out)
  1041  	}))
  1042  	defer backendServer.Close()
  1043  
  1044  	rpURL, err := url.Parse(backendServer.URL)
  1045  	if err != nil {
  1046  		t.Fatal(err)
  1047  	}
  1048  
  1049  	rproxy := NewSingleHostReverseProxy(rpURL)
  1050  
  1051  	// Ensure that the handler panics when the body read encounters an
  1052  	// io.ErrUnexpectedEOF
  1053  	defer func() {
  1054  		err := recover()
  1055  		if err == nil {
  1056  			t.Fatal("handler should have panicked")
  1057  		}
  1058  		if err != http.ErrAbortHandler {
  1059  			t.Fatal("expected ErrAbortHandler, got", err)
  1060  		}
  1061  	}()
  1062  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1063  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1064  }
  1065  
  1066  func TestSelectFlushInterval(t *testing.T) {
  1067  	tests := []struct {
  1068  		name string
  1069  		p    *ReverseProxy
  1070  		res  *http.Response
  1071  		want time.Duration
  1072  	}{
  1073  		{
  1074  			name: "default",
  1075  			res:  &http.Response{},
  1076  			p:    &ReverseProxy{FlushInterval: 123},
  1077  			want: 123,
  1078  		},
  1079  		{
  1080  			name: "server-sent events overrides non-zero",
  1081  			res: &http.Response{
  1082  				Header: http.Header{
  1083  					"Content-Type": {"text/event-stream"},
  1084  				},
  1085  			},
  1086  			p:    &ReverseProxy{FlushInterval: 123},
  1087  			want: -1,
  1088  		},
  1089  		{
  1090  			name: "server-sent events overrides zero",
  1091  			res: &http.Response{
  1092  				Header: http.Header{
  1093  					"Content-Type": {"text/event-stream"},
  1094  				},
  1095  			},
  1096  			p:    &ReverseProxy{FlushInterval: 0},
  1097  			want: -1,
  1098  		},
  1099  		{
  1100  			name: "Content-Length: -1, overrides non-zero",
  1101  			res: &http.Response{
  1102  				ContentLength: -1,
  1103  			},
  1104  			p:    &ReverseProxy{FlushInterval: 123},
  1105  			want: -1,
  1106  		},
  1107  		{
  1108  			name: "Content-Length: -1, overrides zero",
  1109  			res: &http.Response{
  1110  				ContentLength: -1,
  1111  			},
  1112  			p:    &ReverseProxy{FlushInterval: 0},
  1113  			want: -1,
  1114  		},
  1115  	}
  1116  	for _, tt := range tests {
  1117  		t.Run(tt.name, func(t *testing.T) {
  1118  			got := tt.p.flushInterval(tt.res)
  1119  			if got != tt.want {
  1120  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1121  			}
  1122  		})
  1123  	}
  1124  }
  1125  
  1126  func TestReverseProxyWebSocket(t *testing.T) {
  1127  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1128  		if upgradeType(r.Header) != "websocket" {
  1129  			t.Error("unexpected backend request")
  1130  			http.Error(w, "unexpected request", 400)
  1131  			return
  1132  		}
  1133  		c, _, err := w.(http.Hijacker).Hijack()
  1134  		if err != nil {
  1135  			t.Error(err)
  1136  			return
  1137  		}
  1138  		defer c.Close()
  1139  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1140  		bs := bufio.NewScanner(c)
  1141  		if !bs.Scan() {
  1142  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1143  			return
  1144  		}
  1145  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1146  	}))
  1147  	defer backendServer.Close()
  1148  
  1149  	backURL, _ := url.Parse(backendServer.URL)
  1150  	rproxy := NewSingleHostReverseProxy(backURL)
  1151  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1152  	rproxy.ModifyResponse = func(res *http.Response) error {
  1153  		res.Header.Add("X-Modified", "true")
  1154  		return nil
  1155  	}
  1156  
  1157  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1158  		rw.Header().Set("X-Header", "X-Value")
  1159  		rproxy.ServeHTTP(rw, req)
  1160  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1161  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1162  		}
  1163  	})
  1164  
  1165  	frontendProxy := httptest.NewServer(handler)
  1166  	defer frontendProxy.Close()
  1167  
  1168  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1169  	req.Header.Set("Connection", "Upgrade")
  1170  	req.Header.Set("Upgrade", "websocket")
  1171  
  1172  	c := frontendProxy.Client()
  1173  	res, err := c.Do(req)
  1174  	if err != nil {
  1175  		t.Fatal(err)
  1176  	}
  1177  	if res.StatusCode != 101 {
  1178  		t.Fatalf("status = %v; want 101", res.Status)
  1179  	}
  1180  
  1181  	got := res.Header.Get("X-Header")
  1182  	want := "X-Value"
  1183  	if got != want {
  1184  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1185  	}
  1186  
  1187  	if upgradeType(res.Header) != "websocket" {
  1188  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1189  	}
  1190  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1191  	if !ok {
  1192  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1193  	}
  1194  	defer rwc.Close()
  1195  
  1196  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1197  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1198  	}
  1199  
  1200  	io.WriteString(rwc, "Hello\n")
  1201  	bs := bufio.NewScanner(rwc)
  1202  	if !bs.Scan() {
  1203  		t.Fatalf("Scan: %v", bs.Err())
  1204  	}
  1205  	got = bs.Text()
  1206  	want = `backend got "Hello"`
  1207  	if got != want {
  1208  		t.Errorf("got %#q, want %#q", got, want)
  1209  	}
  1210  }
  1211  
// TestReverseProxyWebSocketCancelation checks that canceling the context of a
// proxied Upgrade request stops the tunnel.
//
// The backend streams n numbered lines, one per second, followed by
// terminalMsg. As soon as the client reads the first line, it closes
// triggerCancelCh, which cancels the request context used by the frontend
// handler. The test passes if the client's read loop reaches io.EOF before
// terminalMsg ever arrives.
func TestReverseProxyWebSocketCancelation(t *testing.T) {
	n := 5
	// triggerCancelCh is only ever closed (never sent on). Closing it acts as
	// a broadcast to both the frontend goroutine and the backend handler.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// terminalMsg has no trailing newline, so the client's ReadString('\n')
	// returns it together with io.EOF.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// Complete the upgrade by hand on the hijacked connection.
		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first line before streaming responses.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// Write errors are tolerated once cancellation has been
				// triggered (channel closed); otherwise they are reported.
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			// The Flush error is not checked here. A bufio.Writer keeps a
			// write error and reports it from the next WriteString above.
			bufrw.Flush()
			// Pace the stream; presumably this gives the cancellation time to
			// take effect before the later lines are written.
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	// NOTE(review): the url.Parse error is ignored here.
	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		ctx, cancel := context.WithCancel(req.Context())
		// Cancel the proxied request's context once the client closes
		// triggerCancelCh. NOTE(review): if the test fails before that close,
		// this goroutine stays blocked.
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; g != w {
		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	// This first line unblocks the backend's ReadLine.
	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read loop: trigger cancellation after the first response, then expect
	// the stream to end (io.EOF) before terminalMsg is delivered.

	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg: // must precede "err == io.EOF": terminalMsg arrives together with io.EOF
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0): // We've gotten the first response back
			// Let's trigger a cancel.
			close(triggerCancelCh)
		}
	}
}
  1342  
  1343  func TestUnannouncedTrailer(t *testing.T) {
  1344  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1345  		w.WriteHeader(http.StatusOK)
  1346  		w.(http.Flusher).Flush()
  1347  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1348  	}))
  1349  	defer backend.Close()
  1350  	backendURL, err := url.Parse(backend.URL)
  1351  	if err != nil {
  1352  		t.Fatal(err)
  1353  	}
  1354  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1355  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1356  	frontend := httptest.NewServer(proxyHandler)
  1357  	defer frontend.Close()
  1358  	frontendClient := frontend.Client()
  1359  
  1360  	res, err := frontendClient.Get(frontend.URL)
  1361  	if err != nil {
  1362  		t.Fatalf("Get: %v", err)
  1363  	}
  1364  
  1365  	io.ReadAll(res.Body)
  1366  
  1367  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1368  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1369  	}
  1370  
  1371  }
  1372  
  1373  func TestSingleJoinSlash(t *testing.T) {
  1374  	tests := []struct {
  1375  		slasha   string
  1376  		slashb   string
  1377  		expected string
  1378  	}{
  1379  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1380  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1381  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1382  		{"https://www.google.com", "", "https://www.google.com/"},
  1383  		{"", "favicon.ico", "/favicon.ico"},
  1384  	}
  1385  	for _, tt := range tests {
  1386  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1387  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1388  				tt.slasha,
  1389  				tt.slashb,
  1390  				tt.expected,
  1391  				got)
  1392  		}
  1393  	}
  1394  }
  1395  
  1396  func TestJoinURLPath(t *testing.T) {
  1397  	tests := []struct {
  1398  		a        *url.URL
  1399  		b        *url.URL
  1400  		wantPath string
  1401  		wantRaw  string
  1402  	}{
  1403  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1404  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1405  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1406  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1407  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1408  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1409  	}
  1410  
  1411  	for _, tt := range tests {
  1412  		p, rp := joinURLPath(tt.a, tt.b)
  1413  		if p != tt.wantPath || rp != tt.wantRaw {
  1414  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1415  				tt.a.Path, tt.a.RawPath,
  1416  				tt.b.Path, tt.b.RawPath,
  1417  				tt.wantPath, tt.wantRaw,
  1418  				p, rp)
  1419  		}
  1420  	}
  1421  }