github.com/ck00004/CobaltStrikeParser-Go@v1.0.14/lib/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"os"
    18  	"reflect"
    19  	"sort"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/ck00004/CobaltStrikeParser-Go/lib/url"
    27  
    28  	"github.com/ck00004/CobaltStrikeParser-Go/lib/http/internal/ascii"
    29  
    30  	"github.com/ck00004/CobaltStrikeParser-Go/lib/http"
    31  
    32  	"github.com/ck00004/CobaltStrikeParser-Go/lib/http/httptest"
    33  )
    34  
    35  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    36  
    37  func init() {
    38  	inOurTests = true
    39  	hopHeaders = append(hopHeaders, fakeHopHeader)
    40  }
    41  
    42  func TestReverseProxy(t *testing.T) {
    43  	const backendResponse = "I am the backend"
    44  	const backendStatus = 404
    45  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    47  			c, _, _ := w.(http.Hijacker).Hijack()
    48  			c.Close()
    49  			return
    50  		}
    51  		if len(r.TransferEncoding) > 0 {
    52  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    53  		}
    54  		if r.Header.Get("X-Forwarded-For") == "" {
    55  			t.Errorf("didn't get X-Forwarded-For header")
    56  		}
    57  		if c := r.Header.Get("Connection"); c != "" {
    58  			t.Errorf("handler got Connection header value %q", c)
    59  		}
    60  		if c := r.Header.Get("Te"); c != "trailers" {
    61  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    62  		}
    63  		if c := r.Header.Get("Upgrade"); c != "" {
    64  			t.Errorf("handler got Upgrade header value %q", c)
    65  		}
    66  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    67  			t.Errorf("handler got Proxy-Connection header value %q", c)
    68  		}
    69  		if g, e := r.Host, "some-name"; g != e {
    70  			t.Errorf("backend got Host header %q, want %q", g, e)
    71  		}
    72  		w.Header().Set("Trailers", "not a special header field name")
    73  		w.Header().Set("Trailer", "X-Trailer")
    74  		w.Header().Set("X-Foo", "bar")
    75  		w.Header().Set("Upgrade", "foo")
    76  		w.Header().Set(fakeHopHeader, "foo")
    77  		w.Header().Add("X-Multi-Value", "foo")
    78  		w.Header().Add("X-Multi-Value", "bar")
    79  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    80  		w.WriteHeader(backendStatus)
    81  		w.Write([]byte(backendResponse))
    82  		w.Header().Set("X-Trailer", "trailer_value")
    83  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    84  	}))
    85  	defer backend.Close()
    86  	backendURL, err := url.Parse(backend.URL)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    91  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    92  	frontend := httptest.NewServer(proxyHandler)
    93  	defer frontend.Close()
    94  	frontendClient := frontend.Client()
    95  
    96  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    97  	getReq.Host = "some-name"
    98  	getReq.Header.Set("Connection", "close, TE")
    99  	getReq.Header.Add("Te", "foo")
   100  	getReq.Header.Add("Te", "bar, trailers")
   101  	getReq.Header.Set("Proxy-Connection", "should be deleted")
   102  	getReq.Header.Set("Upgrade", "foo")
   103  	getReq.Close = true
   104  	res, err := frontendClient.Do(getReq)
   105  	if err != nil {
   106  		t.Fatalf("Get: %v", err)
   107  	}
   108  	if g, e := res.StatusCode, backendStatus; g != e {
   109  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   110  	}
   111  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   112  		t.Errorf("got X-Foo %q; expected %q", g, e)
   113  	}
   114  	if c := res.Header.Get(fakeHopHeader); c != "" {
   115  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   116  	}
   117  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   118  		t.Errorf("header Trailers = %q; want %q", g, e)
   119  	}
   120  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   121  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   122  	}
   123  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   124  		t.Fatalf("got %d SetCookies, want %d", g, e)
   125  	}
   126  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   127  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   128  	}
   129  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   130  		t.Errorf("unexpected cookie %q", cookie.Name)
   131  	}
   132  	bodyBytes, _ := io.ReadAll(res.Body)
   133  	if g, e := string(bodyBytes), backendResponse; g != e {
   134  		t.Errorf("got body %q; expected %q", g, e)
   135  	}
   136  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   137  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   138  	}
   139  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   140  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   141  	}
   142  
   143  	// Test that a backend failing to be reached or one which doesn't return
   144  	// a response results in a StatusBadGateway.
   145  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   146  	getReq.Close = true
   147  	res, err = frontendClient.Do(getReq)
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  	res.Body.Close()
   152  	if res.StatusCode != http.StatusBadGateway {
   153  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   154  	}
   155  
   156  }
   157  
   158  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   159  // header value.
   160  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   161  	const fakeConnectionToken = "X-Fake-Connection-Token"
   162  	const backendResponse = "I am the backend"
   163  
   164  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   165  	// in the Request's Connection header.
   166  	const someConnHeader = "X-Some-Conn-Header"
   167  
   168  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   169  		if c := r.Header.Get("Connection"); c != "" {
   170  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   171  		}
   172  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   173  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   174  		}
   175  		if c := r.Header.Get(someConnHeader); c != "" {
   176  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   177  		}
   178  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   179  		w.Header().Add("Connection", someConnHeader)
   180  		w.Header().Set(someConnHeader, "should be deleted")
   181  		w.Header().Set(fakeConnectionToken, "should be deleted")
   182  		io.WriteString(w, backendResponse)
   183  	}))
   184  	defer backend.Close()
   185  	backendURL, err := url.Parse(backend.URL)
   186  	if err != nil {
   187  		t.Fatal(err)
   188  	}
   189  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   190  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   191  		proxyHandler.ServeHTTP(w, r)
   192  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   193  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   194  		}
   195  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   196  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   197  		}
   198  		c := r.Header["Connection"]
   199  		var cf []string
   200  		for _, f := range c {
   201  			for _, sf := range strings.Split(f, ",") {
   202  				if sf = strings.TrimSpace(sf); sf != "" {
   203  					cf = append(cf, sf)
   204  				}
   205  			}
   206  		}
   207  		sort.Strings(cf)
   208  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   209  		sort.Strings(expectedValues)
   210  		if !reflect.DeepEqual(cf, expectedValues) {
   211  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   212  		}
   213  	}))
   214  	defer frontend.Close()
   215  
   216  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   217  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   218  	getReq.Header.Add("Connection", someConnHeader)
   219  	getReq.Header.Set(someConnHeader, "should be deleted")
   220  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   221  	res, err := frontend.Client().Do(getReq)
   222  	if err != nil {
   223  		t.Fatalf("Get: %v", err)
   224  	}
   225  	defer res.Body.Close()
   226  	bodyBytes, err := io.ReadAll(res.Body)
   227  	if err != nil {
   228  		t.Fatalf("reading body: %v", err)
   229  	}
   230  	if got, want := string(bodyBytes), backendResponse; got != want {
   231  		t.Errorf("got body %q; want %q", got, want)
   232  	}
   233  	if c := res.Header.Get("Connection"); c != "" {
   234  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   235  	}
   236  	if c := res.Header.Get(someConnHeader); c != "" {
   237  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   238  	}
   239  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   240  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   241  	}
   242  }
   243  
   244  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   245  	// See Issue 46313.
   246  	const backendResponse = "I am the backend"
   247  
   248  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   249  	// in the Request's Connection header.
   250  	const someConnHeader = "X-Some-Conn-Header"
   251  
   252  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   253  		if c := r.Header.Values("Connection"); len(c) != 0 {
   254  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   255  		}
   256  		if c := r.Header.Get(someConnHeader); c != "" {
   257  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   258  		}
   259  		w.Header().Add("Connection", "")
   260  		w.Header().Add("Connection", someConnHeader)
   261  		w.Header().Set(someConnHeader, "should be deleted")
   262  		io.WriteString(w, backendResponse)
   263  	}))
   264  	defer backend.Close()
   265  	backendURL, err := url.Parse(backend.URL)
   266  	if err != nil {
   267  		t.Fatal(err)
   268  	}
   269  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   270  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   271  		proxyHandler.ServeHTTP(w, r)
   272  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   273  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   274  		}
   275  	}))
   276  	defer frontend.Close()
   277  
   278  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   279  	getReq.Header.Add("Connection", "")
   280  	getReq.Header.Add("Connection", someConnHeader)
   281  	getReq.Header.Set(someConnHeader, "should be deleted")
   282  	res, err := frontend.Client().Do(getReq)
   283  	if err != nil {
   284  		t.Fatalf("Get: %v", err)
   285  	}
   286  	defer res.Body.Close()
   287  	bodyBytes, err := io.ReadAll(res.Body)
   288  	if err != nil {
   289  		t.Fatalf("reading body: %v", err)
   290  	}
   291  	if got, want := string(bodyBytes), backendResponse; got != want {
   292  		t.Errorf("got body %q; want %q", got, want)
   293  	}
   294  	if c := res.Header.Get("Connection"); c != "" {
   295  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   296  	}
   297  	if c := res.Header.Get(someConnHeader); c != "" {
   298  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   299  	}
   300  }
   301  
   302  func TestXForwardedFor(t *testing.T) {
   303  	const prevForwardedFor = "client ip"
   304  	const backendResponse = "I am the backend"
   305  	const backendStatus = 404
   306  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   307  		if r.Header.Get("X-Forwarded-For") == "" {
   308  			t.Errorf("didn't get X-Forwarded-For header")
   309  		}
   310  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   311  			t.Errorf("X-Forwarded-For didn't contain prior data")
   312  		}
   313  		w.WriteHeader(backendStatus)
   314  		w.Write([]byte(backendResponse))
   315  	}))
   316  	defer backend.Close()
   317  	backendURL, err := url.Parse(backend.URL)
   318  	if err != nil {
   319  		t.Fatal(err)
   320  	}
   321  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   322  	frontend := httptest.NewServer(proxyHandler)
   323  	defer frontend.Close()
   324  
   325  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   326  	getReq.Host = "some-name"
   327  	getReq.Header.Set("Connection", "close")
   328  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   329  	getReq.Close = true
   330  	res, err := frontend.Client().Do(getReq)
   331  	if err != nil {
   332  		t.Fatalf("Get: %v", err)
   333  	}
   334  	if g, e := res.StatusCode, backendStatus; g != e {
   335  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   336  	}
   337  	bodyBytes, _ := io.ReadAll(res.Body)
   338  	if g, e := string(bodyBytes), backendResponse; g != e {
   339  		t.Errorf("got body %q; expected %q", g, e)
   340  	}
   341  }
   342  
   343  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   344  func TestXForwardedFor_Omit(t *testing.T) {
   345  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   346  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   347  			t.Errorf("got X-Forwarded-For header: %q", v)
   348  		}
   349  		w.Write([]byte("hi"))
   350  	}))
   351  	defer backend.Close()
   352  	backendURL, err := url.Parse(backend.URL)
   353  	if err != nil {
   354  		t.Fatal(err)
   355  	}
   356  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   357  	frontend := httptest.NewServer(proxyHandler)
   358  	defer frontend.Close()
   359  
   360  	oldDirector := proxyHandler.Director
   361  	proxyHandler.Director = func(r *http.Request) {
   362  		r.Header["X-Forwarded-For"] = nil
   363  		oldDirector(r)
   364  	}
   365  
   366  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   367  	getReq.Host = "some-name"
   368  	getReq.Close = true
   369  	res, err := frontend.Client().Do(getReq)
   370  	if err != nil {
   371  		t.Fatalf("Get: %v", err)
   372  	}
   373  	res.Body.Close()
   374  }
   375  
   376  var proxyQueryTests = []struct {
   377  	baseSuffix string // suffix to add to backend URL
   378  	reqSuffix  string // suffix to add to frontend's request URL
   379  	want       string // what backend should see for final request URL (without ?)
   380  }{
   381  	{"", "", ""},
   382  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   383  	{"", "?us=er", "us=er"},
   384  	{"?sta=tic", "", "sta=tic"},
   385  }
   386  
   387  func TestReverseProxyQuery(t *testing.T) {
   388  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   389  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   390  		w.Write([]byte("hi"))
   391  	}))
   392  	defer backend.Close()
   393  
   394  	for i, tt := range proxyQueryTests {
   395  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   396  		if err != nil {
   397  			t.Fatal(err)
   398  		}
   399  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   400  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   401  		req.Close = true
   402  		res, err := frontend.Client().Do(req)
   403  		if err != nil {
   404  			t.Fatalf("%d. Get: %v", i, err)
   405  		}
   406  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   407  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   408  		}
   409  		res.Body.Close()
   410  		frontend.Close()
   411  	}
   412  }
   413  
   414  func TestReverseProxyFlushInterval(t *testing.T) {
   415  	const expected = "hi"
   416  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   417  		w.Write([]byte(expected))
   418  	}))
   419  	defer backend.Close()
   420  
   421  	backendURL, err := url.Parse(backend.URL)
   422  	if err != nil {
   423  		t.Fatal(err)
   424  	}
   425  
   426  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   427  	proxyHandler.FlushInterval = time.Microsecond
   428  
   429  	frontend := httptest.NewServer(proxyHandler)
   430  	defer frontend.Close()
   431  
   432  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   433  	req.Close = true
   434  	res, err := frontend.Client().Do(req)
   435  	if err != nil {
   436  		t.Fatalf("Get: %v", err)
   437  	}
   438  	defer res.Body.Close()
   439  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   440  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   441  	}
   442  }
   443  
   444  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   445  	const expected = "hi"
   446  	stopCh := make(chan struct{})
   447  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   448  		w.Header().Add("MyHeader", expected)
   449  		w.WriteHeader(200)
   450  		w.(http.Flusher).Flush()
   451  		<-stopCh
   452  	}))
   453  	defer backend.Close()
   454  	defer close(stopCh)
   455  
   456  	backendURL, err := url.Parse(backend.URL)
   457  	if err != nil {
   458  		t.Fatal(err)
   459  	}
   460  
   461  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   462  	proxyHandler.FlushInterval = time.Microsecond
   463  
   464  	frontend := httptest.NewServer(proxyHandler)
   465  	defer frontend.Close()
   466  
   467  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   468  	req.Close = true
   469  
   470  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   471  	defer cancel()
   472  	req = req.WithContext(ctx)
   473  
   474  	res, err := frontend.Client().Do(req)
   475  	if err != nil {
   476  		t.Fatalf("Get: %v", err)
   477  	}
   478  	defer res.Body.Close()
   479  
   480  	if res.Header.Get("MyHeader") != expected {
   481  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   482  	}
   483  }
   484  
   485  func TestReverseProxyCancellation(t *testing.T) {
   486  	const backendResponse = "I am the backend"
   487  
   488  	reqInFlight := make(chan struct{})
   489  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   490  		close(reqInFlight) // cause the client to cancel its request
   491  
   492  		select {
   493  		case <-time.After(10 * time.Second):
   494  			// Note: this should only happen in broken implementations, and the
   495  			// closenotify case should be instantaneous.
   496  			t.Error("Handler never saw CloseNotify")
   497  			return
   498  		case <-w.(http.CloseNotifier).CloseNotify():
   499  		}
   500  
   501  		w.WriteHeader(http.StatusOK)
   502  		w.Write([]byte(backendResponse))
   503  	}))
   504  
   505  	defer backend.Close()
   506  
   507  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   508  
   509  	backendURL, err := url.Parse(backend.URL)
   510  	if err != nil {
   511  		t.Fatal(err)
   512  	}
   513  
   514  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   515  
   516  	// Discards errors of the form:
   517  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   518  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   519  
   520  	frontend := httptest.NewServer(proxyHandler)
   521  	defer frontend.Close()
   522  	frontendClient := frontend.Client()
   523  
   524  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   525  	go func() {
   526  		<-reqInFlight
   527  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   528  	}()
   529  	res, err := frontendClient.Do(getReq)
   530  	if res != nil {
   531  		t.Errorf("got response %v; want nil", res.Status)
   532  	}
   533  	if err == nil {
   534  		// This should be an error like:
   535  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   536  		//    use of closed network connection
   537  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   538  	}
   539  }
   540  
   541  func req(t *testing.T, v string) *http.Request {
   542  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   543  	if err != nil {
   544  		t.Fatal(err)
   545  	}
   546  	return req
   547  }
   548  
   549  // Issue 12344
   550  func TestNilBody(t *testing.T) {
   551  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   552  		w.Write([]byte("hi"))
   553  	}))
   554  	defer backend.Close()
   555  
   556  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   557  		backURL, _ := url.Parse(backend.URL)
   558  		rp := NewSingleHostReverseProxy(backURL)
   559  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   560  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   561  		rp.ServeHTTP(w, r)
   562  	}))
   563  	defer frontend.Close()
   564  
   565  	res, err := http.Get(frontend.URL)
   566  	if err != nil {
   567  		t.Fatal(err)
   568  	}
   569  	defer res.Body.Close()
   570  	slurp, err := io.ReadAll(res.Body)
   571  	if err != nil {
   572  		t.Fatal(err)
   573  	}
   574  	if string(slurp) != "hi" {
   575  		t.Errorf("Got %q; want %q", slurp, "hi")
   576  	}
   577  }
   578  
   579  // Issue 15524
   580  func TestUserAgentHeader(t *testing.T) {
   581  	const explicitUA = "explicit UA"
   582  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   583  		if r.URL.Path == "/noua" {
   584  			if c := r.Header.Get("User-Agent"); c != "" {
   585  				t.Errorf("handler got non-empty User-Agent header %q", c)
   586  			}
   587  			return
   588  		}
   589  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   590  			t.Errorf("handler got unexpected User-Agent header %q", c)
   591  		}
   592  	}))
   593  	defer backend.Close()
   594  	backendURL, err := url.Parse(backend.URL)
   595  	if err != nil {
   596  		t.Fatal(err)
   597  	}
   598  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   599  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   600  	frontend := httptest.NewServer(proxyHandler)
   601  	defer frontend.Close()
   602  	frontendClient := frontend.Client()
   603  
   604  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   605  	getReq.Header.Set("User-Agent", explicitUA)
   606  	getReq.Close = true
   607  	res, err := frontendClient.Do(getReq)
   608  	if err != nil {
   609  		t.Fatalf("Get: %v", err)
   610  	}
   611  	res.Body.Close()
   612  
   613  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   614  	getReq.Header.Set("User-Agent", "")
   615  	getReq.Close = true
   616  	res, err = frontendClient.Do(getReq)
   617  	if err != nil {
   618  		t.Fatalf("Get: %v", err)
   619  	}
   620  	res.Body.Close()
   621  }
   622  
   623  type bufferPool struct {
   624  	get func() []byte
   625  	put func([]byte)
   626  }
   627  
   628  func (bp bufferPool) Get() []byte  { return bp.get() }
   629  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   630  
   631  func TestReverseProxyGetPutBuffer(t *testing.T) {
   632  	const msg = "hi"
   633  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   634  		io.WriteString(w, msg)
   635  	}))
   636  	defer backend.Close()
   637  
   638  	backendURL, err := url.Parse(backend.URL)
   639  	if err != nil {
   640  		t.Fatal(err)
   641  	}
   642  
   643  	var (
   644  		mu  sync.Mutex
   645  		log []string
   646  	)
   647  	addLog := func(event string) {
   648  		mu.Lock()
   649  		defer mu.Unlock()
   650  		log = append(log, event)
   651  	}
   652  	rp := NewSingleHostReverseProxy(backendURL)
   653  	const size = 1234
   654  	rp.BufferPool = bufferPool{
   655  		get: func() []byte {
   656  			addLog("getBuf")
   657  			return make([]byte, size)
   658  		},
   659  		put: func(p []byte) {
   660  			addLog("putBuf-" + strconv.Itoa(len(p)))
   661  		},
   662  	}
   663  	frontend := httptest.NewServer(rp)
   664  	defer frontend.Close()
   665  
   666  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   667  	req.Close = true
   668  	res, err := frontend.Client().Do(req)
   669  	if err != nil {
   670  		t.Fatalf("Get: %v", err)
   671  	}
   672  	slurp, err := io.ReadAll(res.Body)
   673  	res.Body.Close()
   674  	if err != nil {
   675  		t.Fatalf("reading body: %v", err)
   676  	}
   677  	if string(slurp) != msg {
   678  		t.Errorf("msg = %q; want %q", slurp, msg)
   679  	}
   680  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   681  	mu.Lock()
   682  	defer mu.Unlock()
   683  	if !reflect.DeepEqual(log, wantLog) {
   684  		t.Errorf("Log events = %q; want %q", log, wantLog)
   685  	}
   686  }
   687  
   688  func TestReverseProxy_Post(t *testing.T) {
   689  	const backendResponse = "I am the backend"
   690  	const backendStatus = 200
   691  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   692  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   693  		slurp, err := io.ReadAll(r.Body)
   694  		if err != nil {
   695  			t.Errorf("Backend body read = %v", err)
   696  		}
   697  		if len(slurp) != len(requestBody) {
   698  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   699  		}
   700  		if !bytes.Equal(slurp, requestBody) {
   701  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   702  		}
   703  		w.Write([]byte(backendResponse))
   704  	}))
   705  	defer backend.Close()
   706  	backendURL, err := url.Parse(backend.URL)
   707  	if err != nil {
   708  		t.Fatal(err)
   709  	}
   710  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   711  	frontend := httptest.NewServer(proxyHandler)
   712  	defer frontend.Close()
   713  
   714  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   715  	res, err := frontend.Client().Do(postReq)
   716  	if err != nil {
   717  		t.Fatalf("Do: %v", err)
   718  	}
   719  	if g, e := res.StatusCode, backendStatus; g != e {
   720  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   721  	}
   722  	bodyBytes, _ := io.ReadAll(res.Body)
   723  	if g, e := string(bodyBytes), backendResponse; g != e {
   724  		t.Errorf("got body %q; expected %q", g, e)
   725  	}
   726  }
   727  
   728  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   729  
   730  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   731  	return fn(req)
   732  }
   733  
   734  // Issue 16036: send a Request with a nil Body when possible
   735  func TestReverseProxy_NilBody(t *testing.T) {
   736  	backendURL, _ := url.Parse("http://fake.tld/")
   737  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   738  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   739  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   740  		if req.Body != nil {
   741  			t.Error("Body != nil; want a nil Body")
   742  		}
   743  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   744  	})
   745  	frontend := httptest.NewServer(proxyHandler)
   746  	defer frontend.Close()
   747  
   748  	res, err := frontend.Client().Get(frontend.URL)
   749  	if err != nil {
   750  		t.Fatal(err)
   751  	}
   752  	defer res.Body.Close()
   753  	if res.StatusCode != 502 {
   754  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   755  	}
   756  }
   757  
   758  // Issue 33142: always allocate the request headers
   759  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   760  	proxyHandler := new(ReverseProxy)
   761  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   762  	proxyHandler.Director = func(*http.Request) {}     // noop
   763  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   764  		if req.Header == nil {
   765  			t.Error("Header == nil; want a non-nil Header")
   766  		}
   767  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   768  	})
   769  
   770  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   771  		Method:     "GET",
   772  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   773  		Proto:      "HTTP/1.0",
   774  		ProtoMajor: 1,
   775  	})
   776  }
   777  
   778  // Issue 14237. Test ModifyResponse and that an error from it
   779  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   780  func TestReverseProxyModifyResponse(t *testing.T) {
   781  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   782  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   783  	}))
   784  	defer backendServer.Close()
   785  
   786  	rpURL, _ := url.Parse(backendServer.URL)
   787  	rproxy := NewSingleHostReverseProxy(rpURL)
   788  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   789  	rproxy.ModifyResponse = func(resp *http.Response) error {
   790  		if resp.Header.Get("X-Hit-Mod") != "true" {
   791  			return fmt.Errorf("tried to by-pass proxy")
   792  		}
   793  		return nil
   794  	}
   795  
   796  	frontendProxy := httptest.NewServer(rproxy)
   797  	defer frontendProxy.Close()
   798  
   799  	tests := []struct {
   800  		url      string
   801  		wantCode int
   802  	}{
   803  		{frontendProxy.URL + "/mod", http.StatusOK},
   804  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   805  	}
   806  
   807  	for i, tt := range tests {
   808  		resp, err := http.Get(tt.url)
   809  		if err != nil {
   810  			t.Fatalf("failed to reach proxy: %v", err)
   811  		}
   812  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   813  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   814  		}
   815  		resp.Body.Close()
   816  	}
   817  }
   818  
   819  type failingRoundTripper struct{}
   820  
   821  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   822  	return nil, errors.New("some error")
   823  }
   824  
   825  type staticResponseRoundTripper struct{ res *http.Response }
   826  
   827  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   828  	return rt.res, nil
   829  }
   830  
   831  func TestReverseProxyErrorHandler(t *testing.T) {
   832  	tests := []struct {
   833  		name           string
   834  		wantCode       int
   835  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   836  		transport      http.RoundTripper // defaults to failingRoundTripper
   837  		modifyResponse func(*http.Response) error
   838  	}{
   839  		{
   840  			name:     "default",
   841  			wantCode: http.StatusBadGateway,
   842  		},
   843  		{
   844  			name:         "errorhandler",
   845  			wantCode:     http.StatusTeapot,
   846  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   847  		},
   848  		{
   849  			name: "modifyresponse_noerr",
   850  			transport: staticResponseRoundTripper{
   851  				&http.Response{StatusCode: 345, Body: http.NoBody},
   852  			},
   853  			modifyResponse: func(res *http.Response) error {
   854  				res.StatusCode++
   855  				return nil
   856  			},
   857  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   858  			wantCode:     346,
   859  		},
   860  		{
   861  			name: "modifyresponse_err",
   862  			transport: staticResponseRoundTripper{
   863  				&http.Response{StatusCode: 345, Body: http.NoBody},
   864  			},
   865  			modifyResponse: func(res *http.Response) error {
   866  				res.StatusCode++
   867  				return errors.New("some error to trigger errorHandler")
   868  			},
   869  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   870  			wantCode:     http.StatusTeapot,
   871  		},
   872  	}
   873  
   874  	for _, tt := range tests {
   875  		t.Run(tt.name, func(t *testing.T) {
   876  			target := &url.URL{
   877  				Scheme: "http",
   878  				Host:   "dummy.tld",
   879  				Path:   "/",
   880  			}
   881  			rproxy := NewSingleHostReverseProxy(target)
   882  			rproxy.Transport = tt.transport
   883  			rproxy.ModifyResponse = tt.modifyResponse
   884  			if rproxy.Transport == nil {
   885  				rproxy.Transport = failingRoundTripper{}
   886  			}
   887  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   888  			if tt.errorHandler != nil {
   889  				rproxy.ErrorHandler = tt.errorHandler
   890  			}
   891  			frontendProxy := httptest.NewServer(rproxy)
   892  			defer frontendProxy.Close()
   893  
   894  			resp, err := http.Get(frontendProxy.URL + "/test")
   895  			if err != nil {
   896  				t.Fatalf("failed to reach proxy: %v", err)
   897  			}
   898  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   899  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   900  			}
   901  			resp.Body.Close()
   902  		})
   903  	}
   904  }
   905  
   906  // Issue 16659: log errors from short read
   907  func TestReverseProxy_CopyBuffer(t *testing.T) {
   908  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   909  		out := "this call was relayed by the reverse proxy"
   910  		// Coerce a wrong content length to induce io.UnexpectedEOF
   911  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   912  		fmt.Fprintln(w, out)
   913  	}))
   914  	defer backendServer.Close()
   915  
   916  	rpURL, err := url.Parse(backendServer.URL)
   917  	if err != nil {
   918  		t.Fatal(err)
   919  	}
   920  
   921  	var proxyLog bytes.Buffer
   922  	rproxy := NewSingleHostReverseProxy(rpURL)
   923  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   924  	donec := make(chan bool, 1)
   925  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   926  		defer func() { donec <- true }()
   927  		rproxy.ServeHTTP(w, r)
   928  	}))
   929  	defer frontendProxy.Close()
   930  
   931  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   932  		t.Fatalf("want non-nil error")
   933  	}
   934  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   935  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   936  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   937  	// continue after Get.
   938  	<-donec
   939  
   940  	expected := []string{
   941  		"EOF",
   942  		"read",
   943  	}
   944  	for _, phrase := range expected {
   945  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   946  			t.Errorf("expected log to contain phrase %q", phrase)
   947  		}
   948  	}
   949  }
   950  
   951  type staticTransport struct {
   952  	res *http.Response
   953  }
   954  
   955  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   956  	return t.res, nil
   957  }
   958  
   959  func BenchmarkServeHTTP(b *testing.B) {
   960  	res := &http.Response{
   961  		StatusCode: 200,
   962  		Body:       io.NopCloser(strings.NewReader("")),
   963  	}
   964  	proxy := &ReverseProxy{
   965  		Director:  func(*http.Request) {},
   966  		Transport: &staticTransport{res},
   967  	}
   968  
   969  	w := httptest.NewRecorder()
   970  	r := httptest.NewRequest("GET", "/", nil)
   971  
   972  	b.ReportAllocs()
   973  	for i := 0; i < b.N; i++ {
   974  		proxy.ServeHTTP(w, r)
   975  	}
   976  }
   977  
   978  func TestServeHTTPDeepCopy(t *testing.T) {
   979  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   980  		w.Write([]byte("Hello Gopher!"))
   981  	}))
   982  	defer backend.Close()
   983  	backendURL, err := url.Parse(backend.URL)
   984  	if err != nil {
   985  		t.Fatal(err)
   986  	}
   987  
   988  	type result struct {
   989  		before, after string
   990  	}
   991  
   992  	resultChan := make(chan result, 1)
   993  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   994  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   995  		before := r.URL.String()
   996  		proxyHandler.ServeHTTP(w, r)
   997  		after := r.URL.String()
   998  		resultChan <- result{before: before, after: after}
   999  	}))
  1000  	defer frontend.Close()
  1001  
  1002  	want := result{before: "/", after: "/"}
  1003  
  1004  	res, err := frontend.Client().Get(frontend.URL)
  1005  	if err != nil {
  1006  		t.Fatalf("Do: %v", err)
  1007  	}
  1008  	res.Body.Close()
  1009  
  1010  	got := <-resultChan
  1011  	if got != want {
  1012  		t.Errorf("got = %+v; want = %+v", got, want)
  1013  	}
  1014  }
  1015  
  1016  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1017  // before any mutations.
  1018  func TestClonesRequestHeaders(t *testing.T) {
  1019  	log.SetOutput(io.Discard)
  1020  	defer log.SetOutput(os.Stderr)
  1021  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1022  	req.RemoteAddr = "1.2.3.4:56789"
  1023  	rp := &ReverseProxy{
  1024  		Director: func(req *http.Request) {
  1025  			req.Header.Set("From-Director", "1")
  1026  		},
  1027  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1028  			if v := req.Header.Get("From-Director"); v != "1" {
  1029  				t.Errorf("From-Directory value = %q; want 1", v)
  1030  			}
  1031  			return nil, io.EOF
  1032  		}),
  1033  	}
  1034  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1035  
  1036  	if req.Header.Get("From-Director") == "1" {
  1037  		t.Error("Director header mutation modified caller's request")
  1038  	}
  1039  	if req.Header.Get("X-Forwarded-For") != "" {
  1040  		t.Error("X-Forward-For header mutation modified caller's request")
  1041  	}
  1042  
  1043  }
  1044  
  1045  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1046  
  1047  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1048  	return fn(req)
  1049  }
  1050  
  1051  func TestModifyResponseClosesBody(t *testing.T) {
  1052  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1053  	req.RemoteAddr = "1.2.3.4:56789"
  1054  	closeCheck := new(checkCloser)
  1055  	logBuf := new(bytes.Buffer)
  1056  	outErr := errors.New("ModifyResponse error")
  1057  	rp := &ReverseProxy{
  1058  		Director: func(req *http.Request) {},
  1059  		Transport: &staticTransport{&http.Response{
  1060  			StatusCode: 200,
  1061  			Body:       closeCheck,
  1062  		}},
  1063  		ErrorLog: log.New(logBuf, "", 0),
  1064  		ModifyResponse: func(*http.Response) error {
  1065  			return outErr
  1066  		},
  1067  	}
  1068  	rec := httptest.NewRecorder()
  1069  	rp.ServeHTTP(rec, req)
  1070  	res := rec.Result()
  1071  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1072  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1073  	}
  1074  	if !closeCheck.closed {
  1075  		t.Errorf("body should have been closed")
  1076  	}
  1077  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1078  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1079  	}
  1080  }
  1081  
  1082  type checkCloser struct {
  1083  	closed bool
  1084  }
  1085  
  1086  func (cc *checkCloser) Close() error {
  1087  	cc.closed = true
  1088  	return nil
  1089  }
  1090  
  1091  func (cc *checkCloser) Read(b []byte) (int, error) {
  1092  	return len(b), nil
  1093  }
  1094  
  1095  // Issue 23643: panic on body copy error
  1096  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1097  	log.SetOutput(io.Discard)
  1098  	defer log.SetOutput(os.Stderr)
  1099  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1100  		out := "this call was relayed by the reverse proxy"
  1101  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1102  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1103  		fmt.Fprintln(w, out)
  1104  	}))
  1105  	defer backendServer.Close()
  1106  
  1107  	rpURL, err := url.Parse(backendServer.URL)
  1108  	if err != nil {
  1109  		t.Fatal(err)
  1110  	}
  1111  
  1112  	rproxy := NewSingleHostReverseProxy(rpURL)
  1113  
  1114  	// Ensure that the handler panics when the body read encounters an
  1115  	// io.ErrUnexpectedEOF
  1116  	defer func() {
  1117  		err := recover()
  1118  		if err == nil {
  1119  			t.Fatal("handler should have panicked")
  1120  		}
  1121  		if err != http.ErrAbortHandler {
  1122  			t.Fatal("expected ErrAbortHandler, got", err)
  1123  		}
  1124  	}()
  1125  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1126  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1127  }
  1128  
  1129  // Issue #46866: panic without closing incoming request body causes a panic
  1130  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1131  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1132  		out := "this call was relayed by the reverse proxy"
  1133  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1134  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1135  		fmt.Fprintln(w, out)
  1136  	}))
  1137  	defer backend.Close()
  1138  	backendURL, err := url.Parse(backend.URL)
  1139  	if err != nil {
  1140  		t.Fatal(err)
  1141  	}
  1142  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1143  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1144  	frontend := httptest.NewServer(proxyHandler)
  1145  	defer frontend.Close()
  1146  	frontendClient := frontend.Client()
  1147  
  1148  	var wg sync.WaitGroup
  1149  	for i := 0; i < 2; i++ {
  1150  		wg.Add(1)
  1151  		go func() {
  1152  			defer wg.Done()
  1153  			for j := 0; j < 10; j++ {
  1154  				const reqLen = 6 * 1024 * 1024
  1155  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1156  				req.ContentLength = reqLen
  1157  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1158  				if resp != nil {
  1159  					io.Copy(io.Discard, resp.Body)
  1160  					resp.Body.Close()
  1161  				}
  1162  			}
  1163  		}()
  1164  	}
  1165  	wg.Wait()
  1166  }
  1167  
  1168  func TestSelectFlushInterval(t *testing.T) {
  1169  	tests := []struct {
  1170  		name string
  1171  		p    *ReverseProxy
  1172  		res  *http.Response
  1173  		want time.Duration
  1174  	}{
  1175  		{
  1176  			name: "default",
  1177  			res:  &http.Response{},
  1178  			p:    &ReverseProxy{FlushInterval: 123},
  1179  			want: 123,
  1180  		},
  1181  		{
  1182  			name: "server-sent events overrides non-zero",
  1183  			res: &http.Response{
  1184  				Header: http.Header{
  1185  					"Content-Type": {"text/event-stream"},
  1186  				},
  1187  			},
  1188  			p:    &ReverseProxy{FlushInterval: 123},
  1189  			want: -1,
  1190  		},
  1191  		{
  1192  			name: "server-sent events overrides zero",
  1193  			res: &http.Response{
  1194  				Header: http.Header{
  1195  					"Content-Type": {"text/event-stream"},
  1196  				},
  1197  			},
  1198  			p:    &ReverseProxy{FlushInterval: 0},
  1199  			want: -1,
  1200  		},
  1201  		{
  1202  			name: "Content-Length: -1, overrides non-zero",
  1203  			res: &http.Response{
  1204  				ContentLength: -1,
  1205  			},
  1206  			p:    &ReverseProxy{FlushInterval: 123},
  1207  			want: -1,
  1208  		},
  1209  		{
  1210  			name: "Content-Length: -1, overrides zero",
  1211  			res: &http.Response{
  1212  				ContentLength: -1,
  1213  			},
  1214  			p:    &ReverseProxy{FlushInterval: 0},
  1215  			want: -1,
  1216  		},
  1217  	}
  1218  	for _, tt := range tests {
  1219  		t.Run(tt.name, func(t *testing.T) {
  1220  			got := tt.p.flushInterval(tt.res)
  1221  			if got != tt.want {
  1222  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1223  			}
  1224  		})
  1225  	}
  1226  }
  1227  
  1228  func TestReverseProxyWebSocket(t *testing.T) {
  1229  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1230  		if upgradeType(r.Header) != "websocket" {
  1231  			t.Error("unexpected backend request")
  1232  			http.Error(w, "unexpected request", 400)
  1233  			return
  1234  		}
  1235  		c, _, err := w.(http.Hijacker).Hijack()
  1236  		if err != nil {
  1237  			t.Error(err)
  1238  			return
  1239  		}
  1240  		defer c.Close()
  1241  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1242  		bs := bufio.NewScanner(c)
  1243  		if !bs.Scan() {
  1244  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1245  			return
  1246  		}
  1247  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1248  	}))
  1249  	defer backendServer.Close()
  1250  
  1251  	backURL, _ := url.Parse(backendServer.URL)
  1252  	rproxy := NewSingleHostReverseProxy(backURL)
  1253  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1254  	rproxy.ModifyResponse = func(res *http.Response) error {
  1255  		res.Header.Add("X-Modified", "true")
  1256  		return nil
  1257  	}
  1258  
  1259  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1260  		rw.Header().Set("X-Header", "X-Value")
  1261  		rproxy.ServeHTTP(rw, req)
  1262  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1263  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1264  		}
  1265  	})
  1266  
  1267  	frontendProxy := httptest.NewServer(handler)
  1268  	defer frontendProxy.Close()
  1269  
  1270  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1271  	req.Header.Set("Connection", "Upgrade")
  1272  	req.Header.Set("Upgrade", "websocket")
  1273  
  1274  	c := frontendProxy.Client()
  1275  	res, err := c.Do(req)
  1276  	if err != nil {
  1277  		t.Fatal(err)
  1278  	}
  1279  	if res.StatusCode != 101 {
  1280  		t.Fatalf("status = %v; want 101", res.Status)
  1281  	}
  1282  
  1283  	got := res.Header.Get("X-Header")
  1284  	want := "X-Value"
  1285  	if got != want {
  1286  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1287  	}
  1288  
  1289  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1290  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1291  	}
  1292  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1293  	if !ok {
  1294  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1295  	}
  1296  	defer rwc.Close()
  1297  
  1298  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1299  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1300  	}
  1301  
  1302  	io.WriteString(rwc, "Hello\n")
  1303  	bs := bufio.NewScanner(rwc)
  1304  	if !bs.Scan() {
  1305  		t.Fatalf("Scan: %v", bs.Err())
  1306  	}
  1307  	got = bs.Text()
  1308  	want = `backend got "Hello"`
  1309  	if got != want {
  1310  		t.Errorf("got %#q, want %#q", got, want)
  1311  	}
  1312  }
  1313  
// TestReverseProxyWebSocketCancellation checks that cancelling the context of
// the request passed to ServeHTTP tears down a proxied WebSocket connection.
//
// The backend upgrades the connection and waits for one line from the client.
// It then writes n numbered responses one second apart, followed by
// terminalMsg. As soon as the client reads the first response, it closes
// triggerCancelCh, which makes the frontend handler cancel the request
// context. The test passes if the client then sees io.EOF. It fails if
// terminalMsg arrives, meaning the connection outlived the cancellation.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	n := 5
	// triggerCancelCh is never sent on; the client read loop below closes it.
	// The frontend handler goroutine waits on the close to cancel the context.
	// The backend checks it to tell expected write failures from real ones.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// terminalMsg has no trailing newline, so the client's ReadString('\n')
	// can only return it together with an error (io.EOF once the connection
	// ends). This is why the read loop checks it before the err cases.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first message before streaming responses.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// Write errors are tolerated once cancellation has been
				// triggered: the channel is then closed, so this receive
				// succeeds. Otherwise the error is reported.
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			// Space the responses out so the client can trigger cancellation
			// well before terminalMsg would be written.
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		// Cancel the proxied request's context once the client closes
		// triggerCancelCh.
		// NOTE(review): if the test fails before closing triggerCancelCh,
		// this goroutine stays blocked on the receive.
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	// On a successful upgrade the response body is the bidirectional stream.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	// This line unblocks the backend's ReadLine so it starts streaming.
	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read loop.
	// Read lines until the connection ends. The first response triggers the
	// cancellation, and a clean io.EOF afterwards is the passing outcome.

	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg: // this case before "err == io.EOF"
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0): // We've gotten the first response back
			// Let's trigger a cancel.
			close(triggerCancelCh)
		}
	}
}
  1444  
  1445  func TestUnannouncedTrailer(t *testing.T) {
  1446  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1447  		w.WriteHeader(http.StatusOK)
  1448  		w.(http.Flusher).Flush()
  1449  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1450  	}))
  1451  	defer backend.Close()
  1452  	backendURL, err := url.Parse(backend.URL)
  1453  	if err != nil {
  1454  		t.Fatal(err)
  1455  	}
  1456  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1457  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1458  	frontend := httptest.NewServer(proxyHandler)
  1459  	defer frontend.Close()
  1460  	frontendClient := frontend.Client()
  1461  
  1462  	res, err := frontendClient.Get(frontend.URL)
  1463  	if err != nil {
  1464  		t.Fatalf("Get: %v", err)
  1465  	}
  1466  
  1467  	io.ReadAll(res.Body)
  1468  
  1469  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1470  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1471  	}
  1472  
  1473  }
  1474  
  1475  func TestSingleJoinSlash(t *testing.T) {
  1476  	tests := []struct {
  1477  		slasha   string
  1478  		slashb   string
  1479  		expected string
  1480  	}{
  1481  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1482  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1483  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1484  		{"https://www.google.com", "", "https://www.google.com/"},
  1485  		{"", "favicon.ico", "/favicon.ico"},
  1486  	}
  1487  	for _, tt := range tests {
  1488  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1489  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1490  				tt.slasha,
  1491  				tt.slashb,
  1492  				tt.expected,
  1493  				got)
  1494  		}
  1495  	}
  1496  }
  1497  
  1498  func TestJoinURLPath(t *testing.T) {
  1499  	tests := []struct {
  1500  		a        *url.URL
  1501  		b        *url.URL
  1502  		wantPath string
  1503  		wantRaw  string
  1504  	}{
  1505  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1506  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1507  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1508  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1509  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1510  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1511  	}
  1512  
  1513  	for _, tt := range tests {
  1514  		p, rp := joinURLPath(tt.a, tt.b)
  1515  		if p != tt.wantPath || rp != tt.wantRaw {
  1516  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1517  				tt.a.Path, tt.a.RawPath,
  1518  				tt.b.Path, tt.b.RawPath,
  1519  				tt.wantPath, tt.wantRaw,
  1520  				p, rp)
  1521  		}
  1522  	}
  1523  }