github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmhttp/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/url"
    18  	"os"
    19  	"reflect"
    20  	"sort"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	http "github.com/hxx258456/ccgo/gmhttp"
    28  	"github.com/hxx258456/ccgo/gmhttp/httptest"
    29  	"github.com/hxx258456/ccgo/gmhttp/internal/ascii"
    30  )
    31  
    32  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    33  
    34  func init() {
    35  	inOurTests = true
    36  	hopHeaders = append(hopHeaders, fakeHopHeader)
    37  }
    38  
    39  func TestReverseProxy(t *testing.T) {
    40  	const backendResponse = "I am the backend"
    41  	const backendStatus = 404
    42  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    43  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    44  			c, _, _ := w.(http.Hijacker).Hijack()
    45  			c.Close()
    46  			return
    47  		}
    48  		if len(r.TransferEncoding) > 0 {
    49  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    50  		}
    51  		if r.Header.Get("X-Forwarded-For") == "" {
    52  			t.Errorf("didn't get X-Forwarded-For header")
    53  		}
    54  		if c := r.Header.Get("Connection"); c != "" {
    55  			t.Errorf("handler got Connection header value %q", c)
    56  		}
    57  		if c := r.Header.Get("Te"); c != "trailers" {
    58  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    59  		}
    60  		if c := r.Header.Get("Upgrade"); c != "" {
    61  			t.Errorf("handler got Upgrade header value %q", c)
    62  		}
    63  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    64  			t.Errorf("handler got Proxy-Connection header value %q", c)
    65  		}
    66  		if g, e := r.Host, "some-name"; g != e {
    67  			t.Errorf("backend got Host header %q, want %q", g, e)
    68  		}
    69  		w.Header().Set("Trailers", "not a special header field name")
    70  		w.Header().Set("Trailer", "X-Trailer")
    71  		w.Header().Set("X-Foo", "bar")
    72  		w.Header().Set("Upgrade", "foo")
    73  		w.Header().Set(fakeHopHeader, "foo")
    74  		w.Header().Add("X-Multi-Value", "foo")
    75  		w.Header().Add("X-Multi-Value", "bar")
    76  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    77  		w.WriteHeader(backendStatus)
    78  		w.Write([]byte(backendResponse))
    79  		w.Header().Set("X-Trailer", "trailer_value")
    80  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    81  	}))
    82  	defer backend.Close()
    83  	backendURL, err := url.Parse(backend.URL)
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    88  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    89  	frontend := httptest.NewServer(proxyHandler)
    90  	defer frontend.Close()
    91  	frontendClient := frontend.Client()
    92  
    93  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    94  	getReq.Host = "some-name"
    95  	getReq.Header.Set("Connection", "close, TE")
    96  	getReq.Header.Add("Te", "foo")
    97  	getReq.Header.Add("Te", "bar, trailers")
    98  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    99  	getReq.Header.Set("Upgrade", "foo")
   100  	getReq.Close = true
   101  	res, err := frontendClient.Do(getReq)
   102  	if err != nil {
   103  		t.Fatalf("Get: %v", err)
   104  	}
   105  	if g, e := res.StatusCode, backendStatus; g != e {
   106  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   107  	}
   108  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   109  		t.Errorf("got X-Foo %q; expected %q", g, e)
   110  	}
   111  	if c := res.Header.Get(fakeHopHeader); c != "" {
   112  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   113  	}
   114  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   115  		t.Errorf("header Trailers = %q; want %q", g, e)
   116  	}
   117  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   118  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   119  	}
   120  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   121  		t.Fatalf("got %d SetCookies, want %d", g, e)
   122  	}
   123  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   124  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   125  	}
   126  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   127  		t.Errorf("unexpected cookie %q", cookie.Name)
   128  	}
   129  	bodyBytes, _ := io.ReadAll(res.Body)
   130  	if g, e := string(bodyBytes), backendResponse; g != e {
   131  		t.Errorf("got body %q; expected %q", g, e)
   132  	}
   133  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   134  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   135  	}
   136  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   137  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   138  	}
   139  
   140  	// Test that a backend failing to be reached or one which doesn't return
   141  	// a response results in a StatusBadGateway.
   142  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   143  	getReq.Close = true
   144  	res, err = frontendClient.Do(getReq)
   145  	if err != nil {
   146  		t.Fatal(err)
   147  	}
   148  	res.Body.Close()
   149  	if res.StatusCode != http.StatusBadGateway {
   150  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   151  	}
   152  
   153  }
   154  
   155  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   156  // header value.
   157  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   158  	const fakeConnectionToken = "X-Fake-Connection-Token"
   159  	const backendResponse = "I am the backend"
   160  
   161  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   162  	// in the Request's Connection header.
   163  	const someConnHeader = "X-Some-Conn-Header"
   164  
   165  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   166  		if c := r.Header.Get("Connection"); c != "" {
   167  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   168  		}
   169  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   170  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   171  		}
   172  		if c := r.Header.Get(someConnHeader); c != "" {
   173  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   174  		}
   175  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   176  		w.Header().Add("Connection", someConnHeader)
   177  		w.Header().Set(someConnHeader, "should be deleted")
   178  		w.Header().Set(fakeConnectionToken, "should be deleted")
   179  		io.WriteString(w, backendResponse)
   180  	}))
   181  	defer backend.Close()
   182  	backendURL, err := url.Parse(backend.URL)
   183  	if err != nil {
   184  		t.Fatal(err)
   185  	}
   186  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   187  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   188  		proxyHandler.ServeHTTP(w, r)
   189  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   190  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   191  		}
   192  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   193  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   194  		}
   195  		c := r.Header["Connection"]
   196  		var cf []string
   197  		for _, f := range c {
   198  			for _, sf := range strings.Split(f, ",") {
   199  				if sf = strings.TrimSpace(sf); sf != "" {
   200  					cf = append(cf, sf)
   201  				}
   202  			}
   203  		}
   204  		sort.Strings(cf)
   205  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   206  		sort.Strings(expectedValues)
   207  		if !reflect.DeepEqual(cf, expectedValues) {
   208  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   209  		}
   210  	}))
   211  	defer frontend.Close()
   212  
   213  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   214  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   215  	getReq.Header.Add("Connection", someConnHeader)
   216  	getReq.Header.Set(someConnHeader, "should be deleted")
   217  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   218  	res, err := frontend.Client().Do(getReq)
   219  	if err != nil {
   220  		t.Fatalf("Get: %v", err)
   221  	}
   222  	defer res.Body.Close()
   223  	bodyBytes, err := io.ReadAll(res.Body)
   224  	if err != nil {
   225  		t.Fatalf("reading body: %v", err)
   226  	}
   227  	if got, want := string(bodyBytes), backendResponse; got != want {
   228  		t.Errorf("got body %q; want %q", got, want)
   229  	}
   230  	if c := res.Header.Get("Connection"); c != "" {
   231  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   232  	}
   233  	if c := res.Header.Get(someConnHeader); c != "" {
   234  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   235  	}
   236  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   237  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   238  	}
   239  }
   240  
   241  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   242  	// See Issue 46313.
   243  	const backendResponse = "I am the backend"
   244  
   245  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   246  	// in the Request's Connection header.
   247  	const someConnHeader = "X-Some-Conn-Header"
   248  
   249  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   250  		if c := r.Header.Values("Connection"); len(c) != 0 {
   251  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   252  		}
   253  		if c := r.Header.Get(someConnHeader); c != "" {
   254  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   255  		}
   256  		w.Header().Add("Connection", "")
   257  		w.Header().Add("Connection", someConnHeader)
   258  		w.Header().Set(someConnHeader, "should be deleted")
   259  		io.WriteString(w, backendResponse)
   260  	}))
   261  	defer backend.Close()
   262  	backendURL, err := url.Parse(backend.URL)
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   267  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   268  		proxyHandler.ServeHTTP(w, r)
   269  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   270  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   271  		}
   272  	}))
   273  	defer frontend.Close()
   274  
   275  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   276  	getReq.Header.Add("Connection", "")
   277  	getReq.Header.Add("Connection", someConnHeader)
   278  	getReq.Header.Set(someConnHeader, "should be deleted")
   279  	res, err := frontend.Client().Do(getReq)
   280  	if err != nil {
   281  		t.Fatalf("Get: %v", err)
   282  	}
   283  	defer res.Body.Close()
   284  	bodyBytes, err := io.ReadAll(res.Body)
   285  	if err != nil {
   286  		t.Fatalf("reading body: %v", err)
   287  	}
   288  	if got, want := string(bodyBytes), backendResponse; got != want {
   289  		t.Errorf("got body %q; want %q", got, want)
   290  	}
   291  	if c := res.Header.Get("Connection"); c != "" {
   292  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   293  	}
   294  	if c := res.Header.Get(someConnHeader); c != "" {
   295  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   296  	}
   297  }
   298  
   299  func TestXForwardedFor(t *testing.T) {
   300  	const prevForwardedFor = "client ip"
   301  	const backendResponse = "I am the backend"
   302  	const backendStatus = 404
   303  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   304  		if r.Header.Get("X-Forwarded-For") == "" {
   305  			t.Errorf("didn't get X-Forwarded-For header")
   306  		}
   307  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   308  			t.Errorf("X-Forwarded-For didn't contain prior data")
   309  		}
   310  		w.WriteHeader(backendStatus)
   311  		w.Write([]byte(backendResponse))
   312  	}))
   313  	defer backend.Close()
   314  	backendURL, err := url.Parse(backend.URL)
   315  	if err != nil {
   316  		t.Fatal(err)
   317  	}
   318  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   319  	frontend := httptest.NewServer(proxyHandler)
   320  	defer frontend.Close()
   321  
   322  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   323  	getReq.Host = "some-name"
   324  	getReq.Header.Set("Connection", "close")
   325  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   326  	getReq.Close = true
   327  	res, err := frontend.Client().Do(getReq)
   328  	if err != nil {
   329  		t.Fatalf("Get: %v", err)
   330  	}
   331  	if g, e := res.StatusCode, backendStatus; g != e {
   332  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   333  	}
   334  	bodyBytes, _ := io.ReadAll(res.Body)
   335  	if g, e := string(bodyBytes), backendResponse; g != e {
   336  		t.Errorf("got body %q; expected %q", g, e)
   337  	}
   338  }
   339  
   340  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   341  func TestXForwardedFor_Omit(t *testing.T) {
   342  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   343  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   344  			t.Errorf("got X-Forwarded-For header: %q", v)
   345  		}
   346  		w.Write([]byte("hi"))
   347  	}))
   348  	defer backend.Close()
   349  	backendURL, err := url.Parse(backend.URL)
   350  	if err != nil {
   351  		t.Fatal(err)
   352  	}
   353  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   354  	frontend := httptest.NewServer(proxyHandler)
   355  	defer frontend.Close()
   356  
   357  	oldDirector := proxyHandler.Director
   358  	proxyHandler.Director = func(r *http.Request) {
   359  		r.Header["X-Forwarded-For"] = nil
   360  		oldDirector(r)
   361  	}
   362  
   363  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   364  	getReq.Host = "some-name"
   365  	getReq.Close = true
   366  	res, err := frontend.Client().Do(getReq)
   367  	if err != nil {
   368  		t.Fatalf("Get: %v", err)
   369  	}
   370  	res.Body.Close()
   371  }
   372  
   373  var proxyQueryTests = []struct {
   374  	baseSuffix string // suffix to add to backend URL
   375  	reqSuffix  string // suffix to add to frontend's request URL
   376  	want       string // what backend should see for final request URL (without ?)
   377  }{
   378  	{"", "", ""},
   379  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   380  	{"", "?us=er", "us=er"},
   381  	{"?sta=tic", "", "sta=tic"},
   382  }
   383  
   384  func TestReverseProxyQuery(t *testing.T) {
   385  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   386  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   387  		w.Write([]byte("hi"))
   388  	}))
   389  	defer backend.Close()
   390  
   391  	for i, tt := range proxyQueryTests {
   392  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   393  		if err != nil {
   394  			t.Fatal(err)
   395  		}
   396  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   397  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   398  		req.Close = true
   399  		res, err := frontend.Client().Do(req)
   400  		if err != nil {
   401  			t.Fatalf("%d. Get: %v", i, err)
   402  		}
   403  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   404  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   405  		}
   406  		res.Body.Close()
   407  		frontend.Close()
   408  	}
   409  }
   410  
   411  func TestReverseProxyFlushInterval(t *testing.T) {
   412  	const expected = "hi"
   413  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   414  		w.Write([]byte(expected))
   415  	}))
   416  	defer backend.Close()
   417  
   418  	backendURL, err := url.Parse(backend.URL)
   419  	if err != nil {
   420  		t.Fatal(err)
   421  	}
   422  
   423  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   424  	proxyHandler.FlushInterval = time.Microsecond
   425  
   426  	frontend := httptest.NewServer(proxyHandler)
   427  	defer frontend.Close()
   428  
   429  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   430  	req.Close = true
   431  	res, err := frontend.Client().Do(req)
   432  	if err != nil {
   433  		t.Fatalf("Get: %v", err)
   434  	}
   435  	defer res.Body.Close()
   436  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   437  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   438  	}
   439  }
   440  
   441  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   442  	const expected = "hi"
   443  	stopCh := make(chan struct{})
   444  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   445  		w.Header().Add("MyHeader", expected)
   446  		w.WriteHeader(200)
   447  		w.(http.Flusher).Flush()
   448  		<-stopCh
   449  	}))
   450  	defer backend.Close()
   451  	defer close(stopCh)
   452  
   453  	backendURL, err := url.Parse(backend.URL)
   454  	if err != nil {
   455  		t.Fatal(err)
   456  	}
   457  
   458  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   459  	proxyHandler.FlushInterval = time.Microsecond
   460  
   461  	frontend := httptest.NewServer(proxyHandler)
   462  	defer frontend.Close()
   463  
   464  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   465  	req.Close = true
   466  
   467  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   468  	defer cancel()
   469  	req = req.WithContext(ctx)
   470  
   471  	res, err := frontend.Client().Do(req)
   472  	if err != nil {
   473  		t.Fatalf("Get: %v", err)
   474  	}
   475  	defer res.Body.Close()
   476  
   477  	if res.Header.Get("MyHeader") != expected {
   478  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   479  	}
   480  }
   481  
   482  func TestReverseProxyCancellation(t *testing.T) {
   483  	const backendResponse = "I am the backend"
   484  
   485  	reqInFlight := make(chan struct{})
   486  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   487  		close(reqInFlight) // cause the client to cancel its request
   488  
   489  		select {
   490  		case <-time.After(10 * time.Second):
   491  			// Note: this should only happen in broken implementations, and the
   492  			// closenotify case should be instantaneous.
   493  			t.Error("Handler never saw CloseNotify")
   494  			return
   495  		case <-w.(http.CloseNotifier).CloseNotify():
   496  		}
   497  
   498  		w.WriteHeader(http.StatusOK)
   499  		w.Write([]byte(backendResponse))
   500  	}))
   501  
   502  	defer backend.Close()
   503  
   504  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   505  
   506  	backendURL, err := url.Parse(backend.URL)
   507  	if err != nil {
   508  		t.Fatal(err)
   509  	}
   510  
   511  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   512  
   513  	// Discards errors of the form:
   514  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   515  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   516  
   517  	frontend := httptest.NewServer(proxyHandler)
   518  	defer frontend.Close()
   519  	frontendClient := frontend.Client()
   520  
   521  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   522  	go func() {
   523  		<-reqInFlight
   524  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   525  	}()
   526  	res, err := frontendClient.Do(getReq)
   527  	if res != nil {
   528  		t.Errorf("got response %v; want nil", res.Status)
   529  	}
   530  	if err == nil {
   531  		// This should be an error like:
   532  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   533  		//    use of closed network connection
   534  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   535  	}
   536  }
   537  
   538  func req(t *testing.T, v string) *http.Request {
   539  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   540  	if err != nil {
   541  		t.Fatal(err)
   542  	}
   543  	return req
   544  }
   545  
   546  // Issue 12344
   547  func TestNilBody(t *testing.T) {
   548  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   549  		w.Write([]byte("hi"))
   550  	}))
   551  	defer backend.Close()
   552  
   553  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   554  		backURL, _ := url.Parse(backend.URL)
   555  		rp := NewSingleHostReverseProxy(backURL)
   556  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   557  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   558  		rp.ServeHTTP(w, r)
   559  	}))
   560  	defer frontend.Close()
   561  
   562  	res, err := http.Get(frontend.URL)
   563  	if err != nil {
   564  		t.Fatal(err)
   565  	}
   566  	defer res.Body.Close()
   567  	slurp, err := io.ReadAll(res.Body)
   568  	if err != nil {
   569  		t.Fatal(err)
   570  	}
   571  	if string(slurp) != "hi" {
   572  		t.Errorf("Got %q; want %q", slurp, "hi")
   573  	}
   574  }
   575  
   576  // Issue 15524
   577  func TestUserAgentHeader(t *testing.T) {
   578  	const explicitUA = "explicit UA"
   579  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   580  		if r.URL.Path == "/noua" {
   581  			if c := r.Header.Get("User-Agent"); c != "" {
   582  				t.Errorf("handler got non-empty User-Agent header %q", c)
   583  			}
   584  			return
   585  		}
   586  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   587  			t.Errorf("handler got unexpected User-Agent header %q", c)
   588  		}
   589  	}))
   590  	defer backend.Close()
   591  	backendURL, err := url.Parse(backend.URL)
   592  	if err != nil {
   593  		t.Fatal(err)
   594  	}
   595  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   596  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   597  	frontend := httptest.NewServer(proxyHandler)
   598  	defer frontend.Close()
   599  	frontendClient := frontend.Client()
   600  
   601  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   602  	getReq.Header.Set("User-Agent", explicitUA)
   603  	getReq.Close = true
   604  	res, err := frontendClient.Do(getReq)
   605  	if err != nil {
   606  		t.Fatalf("Get: %v", err)
   607  	}
   608  	res.Body.Close()
   609  
   610  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   611  	getReq.Header.Set("User-Agent", "")
   612  	getReq.Close = true
   613  	res, err = frontendClient.Do(getReq)
   614  	if err != nil {
   615  		t.Fatalf("Get: %v", err)
   616  	}
   617  	res.Body.Close()
   618  }
   619  
   620  type bufferPool struct {
   621  	get func() []byte
   622  	put func([]byte)
   623  }
   624  
   625  func (bp bufferPool) Get() []byte  { return bp.get() }
   626  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   627  
   628  func TestReverseProxyGetPutBuffer(t *testing.T) {
   629  	const msg = "hi"
   630  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   631  		io.WriteString(w, msg)
   632  	}))
   633  	defer backend.Close()
   634  
   635  	backendURL, err := url.Parse(backend.URL)
   636  	if err != nil {
   637  		t.Fatal(err)
   638  	}
   639  
   640  	var (
   641  		mu  sync.Mutex
   642  		log []string
   643  	)
   644  	addLog := func(event string) {
   645  		mu.Lock()
   646  		defer mu.Unlock()
   647  		log = append(log, event)
   648  	}
   649  	rp := NewSingleHostReverseProxy(backendURL)
   650  	const size = 1234
   651  	rp.BufferPool = bufferPool{
   652  		get: func() []byte {
   653  			addLog("getBuf")
   654  			return make([]byte, size)
   655  		},
   656  		put: func(p []byte) {
   657  			addLog("putBuf-" + strconv.Itoa(len(p)))
   658  		},
   659  	}
   660  	frontend := httptest.NewServer(rp)
   661  	defer frontend.Close()
   662  
   663  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   664  	req.Close = true
   665  	res, err := frontend.Client().Do(req)
   666  	if err != nil {
   667  		t.Fatalf("Get: %v", err)
   668  	}
   669  	slurp, err := io.ReadAll(res.Body)
   670  	res.Body.Close()
   671  	if err != nil {
   672  		t.Fatalf("reading body: %v", err)
   673  	}
   674  	if string(slurp) != msg {
   675  		t.Errorf("msg = %q; want %q", slurp, msg)
   676  	}
   677  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   678  	mu.Lock()
   679  	defer mu.Unlock()
   680  	if !reflect.DeepEqual(log, wantLog) {
   681  		t.Errorf("Log events = %q; want %q", log, wantLog)
   682  	}
   683  }
   684  
   685  func TestReverseProxy_Post(t *testing.T) {
   686  	const backendResponse = "I am the backend"
   687  	const backendStatus = 200
   688  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   689  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   690  		slurp, err := io.ReadAll(r.Body)
   691  		if err != nil {
   692  			t.Errorf("Backend body read = %v", err)
   693  		}
   694  		if len(slurp) != len(requestBody) {
   695  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   696  		}
   697  		if !bytes.Equal(slurp, requestBody) {
   698  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   699  		}
   700  		w.Write([]byte(backendResponse))
   701  	}))
   702  	defer backend.Close()
   703  	backendURL, err := url.Parse(backend.URL)
   704  	if err != nil {
   705  		t.Fatal(err)
   706  	}
   707  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   708  	frontend := httptest.NewServer(proxyHandler)
   709  	defer frontend.Close()
   710  
   711  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   712  	res, err := frontend.Client().Do(postReq)
   713  	if err != nil {
   714  		t.Fatalf("Do: %v", err)
   715  	}
   716  	if g, e := res.StatusCode, backendStatus; g != e {
   717  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   718  	}
   719  	bodyBytes, _ := io.ReadAll(res.Body)
   720  	if g, e := string(bodyBytes), backendResponse; g != e {
   721  		t.Errorf("got body %q; expected %q", g, e)
   722  	}
   723  }
   724  
   725  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   726  
   727  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   728  	return fn(req)
   729  }
   730  
   731  // Issue 16036: send a Request with a nil Body when possible
   732  func TestReverseProxy_NilBody(t *testing.T) {
   733  	backendURL, _ := url.Parse("http://fake.tld/")
   734  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   735  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   736  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   737  		if req.Body != nil {
   738  			t.Error("Body != nil; want a nil Body")
   739  		}
   740  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   741  	})
   742  	frontend := httptest.NewServer(proxyHandler)
   743  	defer frontend.Close()
   744  
   745  	res, err := frontend.Client().Get(frontend.URL)
   746  	if err != nil {
   747  		t.Fatal(err)
   748  	}
   749  	defer res.Body.Close()
   750  	if res.StatusCode != 502 {
   751  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   752  	}
   753  }
   754  
   755  // Issue 33142: always allocate the request headers
   756  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   757  	proxyHandler := new(ReverseProxy)
   758  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   759  	proxyHandler.Director = func(*http.Request) {}     // noop
   760  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   761  		if req.Header == nil {
   762  			t.Error("Header == nil; want a non-nil Header")
   763  		}
   764  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   765  	})
   766  
   767  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   768  		Method:     "GET",
   769  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   770  		Proto:      "HTTP/1.0",
   771  		ProtoMajor: 1,
   772  	})
   773  }
   774  
   775  // Issue 14237. Test ModifyResponse and that an error from it
   776  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   777  func TestReverseProxyModifyResponse(t *testing.T) {
   778  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   779  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   780  	}))
   781  	defer backendServer.Close()
   782  
   783  	rpURL, _ := url.Parse(backendServer.URL)
   784  	rproxy := NewSingleHostReverseProxy(rpURL)
   785  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   786  	rproxy.ModifyResponse = func(resp *http.Response) error {
   787  		if resp.Header.Get("X-Hit-Mod") != "true" {
   788  			return fmt.Errorf("tried to by-pass proxy")
   789  		}
   790  		return nil
   791  	}
   792  
   793  	frontendProxy := httptest.NewServer(rproxy)
   794  	defer frontendProxy.Close()
   795  
   796  	tests := []struct {
   797  		url      string
   798  		wantCode int
   799  	}{
   800  		{frontendProxy.URL + "/mod", http.StatusOK},
   801  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   802  	}
   803  
   804  	for i, tt := range tests {
   805  		resp, err := http.Get(tt.url)
   806  		if err != nil {
   807  			t.Fatalf("failed to reach proxy: %v", err)
   808  		}
   809  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   810  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   811  		}
   812  		resp.Body.Close()
   813  	}
   814  }
   815  
   816  type failingRoundTripper struct{}
   817  
   818  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   819  	return nil, errors.New("some error")
   820  }
   821  
   822  type staticResponseRoundTripper struct{ res *http.Response }
   823  
   824  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   825  	return rt.res, nil
   826  }
   827  
   828  func TestReverseProxyErrorHandler(t *testing.T) {
   829  	tests := []struct {
   830  		name           string
   831  		wantCode       int
   832  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   833  		transport      http.RoundTripper // defaults to failingRoundTripper
   834  		modifyResponse func(*http.Response) error
   835  	}{
   836  		{
   837  			name:     "default",
   838  			wantCode: http.StatusBadGateway,
   839  		},
   840  		{
   841  			name:         "errorhandler",
   842  			wantCode:     http.StatusTeapot,
   843  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   844  		},
   845  		{
   846  			name: "modifyresponse_noerr",
   847  			transport: staticResponseRoundTripper{
   848  				&http.Response{StatusCode: 345, Body: http.NoBody},
   849  			},
   850  			modifyResponse: func(res *http.Response) error {
   851  				res.StatusCode++
   852  				return nil
   853  			},
   854  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   855  			wantCode:     346,
   856  		},
   857  		{
   858  			name: "modifyresponse_err",
   859  			transport: staticResponseRoundTripper{
   860  				&http.Response{StatusCode: 345, Body: http.NoBody},
   861  			},
   862  			modifyResponse: func(res *http.Response) error {
   863  				res.StatusCode++
   864  				return errors.New("some error to trigger errorHandler")
   865  			},
   866  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   867  			wantCode:     http.StatusTeapot,
   868  		},
   869  	}
   870  
   871  	for _, tt := range tests {
   872  		t.Run(tt.name, func(t *testing.T) {
   873  			target := &url.URL{
   874  				Scheme: "http",
   875  				Host:   "dummy.tld",
   876  				Path:   "/",
   877  			}
   878  			rproxy := NewSingleHostReverseProxy(target)
   879  			rproxy.Transport = tt.transport
   880  			rproxy.ModifyResponse = tt.modifyResponse
   881  			if rproxy.Transport == nil {
   882  				rproxy.Transport = failingRoundTripper{}
   883  			}
   884  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   885  			if tt.errorHandler != nil {
   886  				rproxy.ErrorHandler = tt.errorHandler
   887  			}
   888  			frontendProxy := httptest.NewServer(rproxy)
   889  			defer frontendProxy.Close()
   890  
   891  			resp, err := http.Get(frontendProxy.URL + "/test")
   892  			if err != nil {
   893  				t.Fatalf("failed to reach proxy: %v", err)
   894  			}
   895  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   896  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   897  			}
   898  			resp.Body.Close()
   899  		})
   900  	}
   901  }
   902  
   903  // Issue 16659: log errors from short read
   904  func TestReverseProxy_CopyBuffer(t *testing.T) {
   905  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   906  		out := "this call was relayed by the reverse proxy"
   907  		// Coerce a wrong content length to induce io.UnexpectedEOF
   908  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   909  		fmt.Fprintln(w, out)
   910  	}))
   911  	defer backendServer.Close()
   912  
   913  	rpURL, err := url.Parse(backendServer.URL)
   914  	if err != nil {
   915  		t.Fatal(err)
   916  	}
   917  
   918  	var proxyLog bytes.Buffer
   919  	rproxy := NewSingleHostReverseProxy(rpURL)
   920  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   921  	donec := make(chan bool, 1)
   922  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   923  		defer func() { donec <- true }()
   924  		rproxy.ServeHTTP(w, r)
   925  	}))
   926  	defer frontendProxy.Close()
   927  
   928  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   929  		t.Fatalf("want non-nil error")
   930  	}
   931  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   932  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   933  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   934  	// continue after Get.
   935  	<-donec
   936  
   937  	expected := []string{
   938  		"EOF",
   939  		"read",
   940  	}
   941  	for _, phrase := range expected {
   942  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   943  			t.Errorf("expected log to contain phrase %q", phrase)
   944  		}
   945  	}
   946  }
   947  
   948  type staticTransport struct {
   949  	res *http.Response
   950  }
   951  
   952  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   953  	return t.res, nil
   954  }
   955  
   956  func BenchmarkServeHTTP(b *testing.B) {
   957  	res := &http.Response{
   958  		StatusCode: 200,
   959  		Body:       io.NopCloser(strings.NewReader("")),
   960  	}
   961  	proxy := &ReverseProxy{
   962  		Director:  func(*http.Request) {},
   963  		Transport: &staticTransport{res},
   964  	}
   965  
   966  	w := httptest.NewRecorder()
   967  	r := httptest.NewRequest("GET", "/", nil)
   968  
   969  	b.ReportAllocs()
   970  	for i := 0; i < b.N; i++ {
   971  		proxy.ServeHTTP(w, r)
   972  	}
   973  }
   974  
   975  func TestServeHTTPDeepCopy(t *testing.T) {
   976  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   977  		w.Write([]byte("Hello Gopher!"))
   978  	}))
   979  	defer backend.Close()
   980  	backendURL, err := url.Parse(backend.URL)
   981  	if err != nil {
   982  		t.Fatal(err)
   983  	}
   984  
   985  	type result struct {
   986  		before, after string
   987  	}
   988  
   989  	resultChan := make(chan result, 1)
   990  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   991  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   992  		before := r.URL.String()
   993  		proxyHandler.ServeHTTP(w, r)
   994  		after := r.URL.String()
   995  		resultChan <- result{before: before, after: after}
   996  	}))
   997  	defer frontend.Close()
   998  
   999  	want := result{before: "/", after: "/"}
  1000  
  1001  	res, err := frontend.Client().Get(frontend.URL)
  1002  	if err != nil {
  1003  		t.Fatalf("Do: %v", err)
  1004  	}
  1005  	res.Body.Close()
  1006  
  1007  	got := <-resultChan
  1008  	if got != want {
  1009  		t.Errorf("got = %+v; want = %+v", got, want)
  1010  	}
  1011  }
  1012  
  1013  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1014  // before any mutations.
  1015  func TestClonesRequestHeaders(t *testing.T) {
  1016  	log.SetOutput(io.Discard)
  1017  	defer log.SetOutput(os.Stderr)
  1018  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1019  	req.RemoteAddr = "1.2.3.4:56789"
  1020  	rp := &ReverseProxy{
  1021  		Director: func(req *http.Request) {
  1022  			req.Header.Set("From-Director", "1")
  1023  		},
  1024  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1025  			if v := req.Header.Get("From-Director"); v != "1" {
  1026  				t.Errorf("From-Directory value = %q; want 1", v)
  1027  			}
  1028  			return nil, io.EOF
  1029  		}),
  1030  	}
  1031  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1032  
  1033  	if req.Header.Get("From-Director") == "1" {
  1034  		t.Error("Director header mutation modified caller's request")
  1035  	}
  1036  	if req.Header.Get("X-Forwarded-For") != "" {
  1037  		t.Error("X-Forward-For header mutation modified caller's request")
  1038  	}
  1039  
  1040  }
  1041  
  1042  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1043  
  1044  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1045  	return fn(req)
  1046  }
  1047  
  1048  func TestModifyResponseClosesBody(t *testing.T) {
  1049  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1050  	req.RemoteAddr = "1.2.3.4:56789"
  1051  	closeCheck := new(checkCloser)
  1052  	logBuf := new(bytes.Buffer)
  1053  	outErr := errors.New("ModifyResponse error")
  1054  	rp := &ReverseProxy{
  1055  		Director: func(req *http.Request) {},
  1056  		Transport: &staticTransport{&http.Response{
  1057  			StatusCode: 200,
  1058  			Body:       closeCheck,
  1059  		}},
  1060  		ErrorLog: log.New(logBuf, "", 0),
  1061  		ModifyResponse: func(*http.Response) error {
  1062  			return outErr
  1063  		},
  1064  	}
  1065  	rec := httptest.NewRecorder()
  1066  	rp.ServeHTTP(rec, req)
  1067  	res := rec.Result()
  1068  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1069  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1070  	}
  1071  	if !closeCheck.closed {
  1072  		t.Errorf("body should have been closed")
  1073  	}
  1074  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1075  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1076  	}
  1077  }
  1078  
  1079  type checkCloser struct {
  1080  	closed bool
  1081  }
  1082  
  1083  func (cc *checkCloser) Close() error {
  1084  	cc.closed = true
  1085  	return nil
  1086  }
  1087  
  1088  func (cc *checkCloser) Read(b []byte) (int, error) {
  1089  	return len(b), nil
  1090  }
  1091  
  1092  // Issue 23643: panic on body copy error
  1093  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1094  	log.SetOutput(io.Discard)
  1095  	defer log.SetOutput(os.Stderr)
  1096  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1097  		out := "this call was relayed by the reverse proxy"
  1098  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1099  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1100  		fmt.Fprintln(w, out)
  1101  	}))
  1102  	defer backendServer.Close()
  1103  
  1104  	rpURL, err := url.Parse(backendServer.URL)
  1105  	if err != nil {
  1106  		t.Fatal(err)
  1107  	}
  1108  
  1109  	rproxy := NewSingleHostReverseProxy(rpURL)
  1110  
  1111  	// Ensure that the handler panics when the body read encounters an
  1112  	// io.ErrUnexpectedEOF
  1113  	defer func() {
  1114  		err := recover()
  1115  		if err == nil {
  1116  			t.Fatal("handler should have panicked")
  1117  		}
  1118  		if err != http.ErrAbortHandler {
  1119  			t.Fatal("expected ErrAbortHandler, got", err)
  1120  		}
  1121  	}()
  1122  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1123  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1124  }
  1125  
  1126  // Issue #46866: panic without closing incoming request body causes a panic
  1127  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1128  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1129  		out := "this call was relayed by the reverse proxy"
  1130  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1131  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1132  		fmt.Fprintln(w, out)
  1133  	}))
  1134  	defer backend.Close()
  1135  	backendURL, err := url.Parse(backend.URL)
  1136  	if err != nil {
  1137  		t.Fatal(err)
  1138  	}
  1139  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1140  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1141  	frontend := httptest.NewServer(proxyHandler)
  1142  	defer frontend.Close()
  1143  	frontendClient := frontend.Client()
  1144  
  1145  	var wg sync.WaitGroup
  1146  	for i := 0; i < 2; i++ {
  1147  		wg.Add(1)
  1148  		go func() {
  1149  			defer wg.Done()
  1150  			for j := 0; j < 10; j++ {
  1151  				const reqLen = 6 * 1024 * 1024
  1152  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1153  				req.ContentLength = reqLen
  1154  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1155  				if resp != nil {
  1156  					io.Copy(io.Discard, resp.Body)
  1157  					resp.Body.Close()
  1158  				}
  1159  			}
  1160  		}()
  1161  	}
  1162  	wg.Wait()
  1163  }
  1164  
  1165  func TestSelectFlushInterval(t *testing.T) {
  1166  	tests := []struct {
  1167  		name string
  1168  		p    *ReverseProxy
  1169  		res  *http.Response
  1170  		want time.Duration
  1171  	}{
  1172  		{
  1173  			name: "default",
  1174  			res:  &http.Response{},
  1175  			p:    &ReverseProxy{FlushInterval: 123},
  1176  			want: 123,
  1177  		},
  1178  		{
  1179  			name: "server-sent events overrides non-zero",
  1180  			res: &http.Response{
  1181  				Header: http.Header{
  1182  					"Content-Type": {"text/event-stream"},
  1183  				},
  1184  			},
  1185  			p:    &ReverseProxy{FlushInterval: 123},
  1186  			want: -1,
  1187  		},
  1188  		{
  1189  			name: "server-sent events overrides zero",
  1190  			res: &http.Response{
  1191  				Header: http.Header{
  1192  					"Content-Type": {"text/event-stream"},
  1193  				},
  1194  			},
  1195  			p:    &ReverseProxy{FlushInterval: 0},
  1196  			want: -1,
  1197  		},
  1198  		{
  1199  			name: "Content-Length: -1, overrides non-zero",
  1200  			res: &http.Response{
  1201  				ContentLength: -1,
  1202  			},
  1203  			p:    &ReverseProxy{FlushInterval: 123},
  1204  			want: -1,
  1205  		},
  1206  		{
  1207  			name: "Content-Length: -1, overrides zero",
  1208  			res: &http.Response{
  1209  				ContentLength: -1,
  1210  			},
  1211  			p:    &ReverseProxy{FlushInterval: 0},
  1212  			want: -1,
  1213  		},
  1214  	}
  1215  	for _, tt := range tests {
  1216  		t.Run(tt.name, func(t *testing.T) {
  1217  			got := tt.p.flushInterval(tt.res)
  1218  			if got != tt.want {
  1219  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1220  			}
  1221  		})
  1222  	}
  1223  }
  1224  
  1225  func TestReverseProxyWebSocket(t *testing.T) {
  1226  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1227  		if upgradeType(r.Header) != "websocket" {
  1228  			t.Error("unexpected backend request")
  1229  			http.Error(w, "unexpected request", 400)
  1230  			return
  1231  		}
  1232  		c, _, err := w.(http.Hijacker).Hijack()
  1233  		if err != nil {
  1234  			t.Error(err)
  1235  			return
  1236  		}
  1237  		defer c.Close()
  1238  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1239  		bs := bufio.NewScanner(c)
  1240  		if !bs.Scan() {
  1241  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1242  			return
  1243  		}
  1244  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1245  	}))
  1246  	defer backendServer.Close()
  1247  
  1248  	backURL, _ := url.Parse(backendServer.URL)
  1249  	rproxy := NewSingleHostReverseProxy(backURL)
  1250  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1251  	rproxy.ModifyResponse = func(res *http.Response) error {
  1252  		res.Header.Add("X-Modified", "true")
  1253  		return nil
  1254  	}
  1255  
  1256  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1257  		rw.Header().Set("X-Header", "X-Value")
  1258  		rproxy.ServeHTTP(rw, req)
  1259  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1260  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1261  		}
  1262  	})
  1263  
  1264  	frontendProxy := httptest.NewServer(handler)
  1265  	defer frontendProxy.Close()
  1266  
  1267  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1268  	req.Header.Set("Connection", "Upgrade")
  1269  	req.Header.Set("Upgrade", "websocket")
  1270  
  1271  	c := frontendProxy.Client()
  1272  	res, err := c.Do(req)
  1273  	if err != nil {
  1274  		t.Fatal(err)
  1275  	}
  1276  	if res.StatusCode != 101 {
  1277  		t.Fatalf("status = %v; want 101", res.Status)
  1278  	}
  1279  
  1280  	got := res.Header.Get("X-Header")
  1281  	want := "X-Value"
  1282  	if got != want {
  1283  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1284  	}
  1285  
  1286  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1287  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1288  	}
  1289  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1290  	if !ok {
  1291  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1292  	}
  1293  	defer rwc.Close()
  1294  
  1295  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1296  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1297  	}
  1298  
  1299  	io.WriteString(rwc, "Hello\n")
  1300  	bs := bufio.NewScanner(rwc)
  1301  	if !bs.Scan() {
  1302  		t.Fatalf("Scan: %v", bs.Err())
  1303  	}
  1304  	got = bs.Text()
  1305  	want = `backend got "Hello"`
  1306  	if got != want {
  1307  		t.Errorf("got %#q, want %#q", got, want)
  1308  	}
  1309  }
  1310  
  1311  func TestReverseProxyWebSocketCancellation(t *testing.T) {
  1312  	n := 5
  1313  	triggerCancelCh := make(chan bool, n)
  1314  	nthResponse := func(i int) string {
  1315  		return fmt.Sprintf("backend response #%d\n", i)
  1316  	}
  1317  	terminalMsg := "final message"
  1318  
  1319  	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1320  		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
  1321  			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
  1322  			http.Error(w, "Unexpected request", 400)
  1323  			return
  1324  		}
  1325  		conn, bufrw, err := w.(http.Hijacker).Hijack()
  1326  		if err != nil {
  1327  			t.Error(err)
  1328  			return
  1329  		}
  1330  		defer conn.Close()
  1331  
  1332  		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
  1333  		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
  1334  			t.Error(err)
  1335  			return
  1336  		}
  1337  		if _, _, err := bufrw.ReadLine(); err != nil {
  1338  			t.Errorf("Failed to read line from client: %v", err)
  1339  			return
  1340  		}
  1341  
  1342  		for i := 0; i < n; i++ {
  1343  			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
  1344  				select {
  1345  				case <-triggerCancelCh:
  1346  				default:
  1347  					t.Errorf("Writing response #%d failed: %v", i, err)
  1348  				}
  1349  				return
  1350  			}
  1351  			bufrw.Flush()
  1352  			time.Sleep(time.Second)
  1353  		}
  1354  		if _, err := bufrw.WriteString(terminalMsg); err != nil {
  1355  			select {
  1356  			case <-triggerCancelCh:
  1357  			default:
  1358  				t.Errorf("Failed to write terminal message: %v", err)
  1359  			}
  1360  		}
  1361  		bufrw.Flush()
  1362  	}))
  1363  	defer cst.Close()
  1364  
  1365  	backendURL, _ := url.Parse(cst.URL)
  1366  	rproxy := NewSingleHostReverseProxy(backendURL)
  1367  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1368  	rproxy.ModifyResponse = func(res *http.Response) error {
  1369  		res.Header.Add("X-Modified", "true")
  1370  		return nil
  1371  	}
  1372  
  1373  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1374  		rw.Header().Set("X-Header", "X-Value")
  1375  		ctx, cancel := context.WithCancel(req.Context())
  1376  		go func() {
  1377  			<-triggerCancelCh
  1378  			cancel()
  1379  		}()
  1380  		rproxy.ServeHTTP(rw, req.WithContext(ctx))
  1381  	})
  1382  
  1383  	frontendProxy := httptest.NewServer(handler)
  1384  	defer frontendProxy.Close()
  1385  
  1386  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1387  	req.Header.Set("Connection", "Upgrade")
  1388  	req.Header.Set("Upgrade", "websocket")
  1389  
  1390  	res, err := frontendProxy.Client().Do(req)
  1391  	if err != nil {
  1392  		t.Fatalf("Dialing to frontend proxy: %v", err)
  1393  	}
  1394  	defer res.Body.Close()
  1395  	if g, w := res.StatusCode, 101; g != w {
  1396  		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
  1397  	}
  1398  
  1399  	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
  1400  		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1401  	}
  1402  
  1403  	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
  1404  		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1405  	}
  1406  
  1407  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1408  	if !ok {
  1409  		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
  1410  	}
  1411  
  1412  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1413  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1414  	}
  1415  
  1416  	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
  1417  		t.Fatalf("Failed to write first message: %v", err)
  1418  	}
  1419  
  1420  	// Read loop.
  1421  
  1422  	br := bufio.NewReader(rwc)
  1423  	for {
  1424  		line, err := br.ReadString('\n')
  1425  		switch {
  1426  		case line == terminalMsg: // this case before "err == io.EOF"
  1427  			t.Fatalf("The websocket request was not canceled, unfortunately!")
  1428  
  1429  		case err == io.EOF:
  1430  			return
  1431  
  1432  		case err != nil:
  1433  			t.Fatalf("Unexpected error: %v", err)
  1434  
  1435  		case line == nthResponse(0): // We've gotten the first response back
  1436  			// Let's trigger a cancel.
  1437  			close(triggerCancelCh)
  1438  		}
  1439  	}
  1440  }
  1441  
  1442  func TestUnannouncedTrailer(t *testing.T) {
  1443  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1444  		w.WriteHeader(http.StatusOK)
  1445  		w.(http.Flusher).Flush()
  1446  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1447  	}))
  1448  	defer backend.Close()
  1449  	backendURL, err := url.Parse(backend.URL)
  1450  	if err != nil {
  1451  		t.Fatal(err)
  1452  	}
  1453  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1454  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1455  	frontend := httptest.NewServer(proxyHandler)
  1456  	defer frontend.Close()
  1457  	frontendClient := frontend.Client()
  1458  
  1459  	res, err := frontendClient.Get(frontend.URL)
  1460  	if err != nil {
  1461  		t.Fatalf("Get: %v", err)
  1462  	}
  1463  
  1464  	io.ReadAll(res.Body)
  1465  
  1466  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1467  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1468  	}
  1469  
  1470  }
  1471  
  1472  func TestSingleJoinSlash(t *testing.T) {
  1473  	tests := []struct {
  1474  		slasha   string
  1475  		slashb   string
  1476  		expected string
  1477  	}{
  1478  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1479  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1480  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1481  		{"https://www.google.com", "", "https://www.google.com/"},
  1482  		{"", "favicon.ico", "/favicon.ico"},
  1483  	}
  1484  	for _, tt := range tests {
  1485  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1486  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1487  				tt.slasha,
  1488  				tt.slashb,
  1489  				tt.expected,
  1490  				got)
  1491  		}
  1492  	}
  1493  }
  1494  
  1495  func TestJoinURLPath(t *testing.T) {
  1496  	tests := []struct {
  1497  		a        *url.URL
  1498  		b        *url.URL
  1499  		wantPath string
  1500  		wantRaw  string
  1501  	}{
  1502  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1503  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1504  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1505  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1506  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1507  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1508  	}
  1509  
  1510  	for _, tt := range tests {
  1511  		p, rp := joinURLPath(tt.a, tt.b)
  1512  		if p != tt.wantPath || rp != tt.wantRaw {
  1513  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1514  				tt.a.Path, tt.a.RawPath,
  1515  				tt.b.Path, tt.b.RawPath,
  1516  				tt.wantPath, tt.wantRaw,
  1517  				p, rp)
  1518  		}
  1519  	}
  1520  }