github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/net/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/http/httptrace"
    20  	"github.com/mtsmfm/go/src/net/http/internal/ascii"
    21  	"net/textproto"
    22  	"net/url"
    23  	"os"
    24  	"reflect"
    25  	"sort"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  )
    32  
    33  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    34  
    35  func init() {
    36  	inOurTests = true
    37  	hopHeaders = append(hopHeaders, fakeHopHeader)
    38  }
    39  
    40  func TestReverseProxy(t *testing.T) {
    41  	const backendResponse = "I am the backend"
    42  	const backendStatus = 404
    43  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    45  			c, _, _ := w.(http.Hijacker).Hijack()
    46  			c.Close()
    47  			return
    48  		}
    49  		if len(r.TransferEncoding) > 0 {
    50  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    51  		}
    52  		if r.Header.Get("X-Forwarded-For") == "" {
    53  			t.Errorf("didn't get X-Forwarded-For header")
    54  		}
    55  		if r.Header.Get("X-Forwarded-Host") == "" {
    56  			t.Errorf("didn't get X-Forwarded-Host header")
    57  		}
    58  		if r.Header.Get("X-Forwarded-Proto") == "" {
    59  			t.Errorf("didn't get X-Forwarded-Proto header")
    60  		}
    61  		if c := r.Header.Get("Connection"); c != "" {
    62  			t.Errorf("handler got Connection header value %q", c)
    63  		}
    64  		if c := r.Header.Get("Te"); c != "trailers" {
    65  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    66  		}
    67  		if c := r.Header.Get("Upgrade"); c != "" {
    68  			t.Errorf("handler got Upgrade header value %q", c)
    69  		}
    70  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    71  			t.Errorf("handler got Proxy-Connection header value %q", c)
    72  		}
    73  		if g, e := r.Host, "some-name"; g != e {
    74  			t.Errorf("backend got Host header %q, want %q", g, e)
    75  		}
    76  		w.Header().Set("Trailers", "not a special header field name")
    77  		w.Header().Set("Trailer", "X-Trailer")
    78  		w.Header().Set("X-Foo", "bar")
    79  		w.Header().Set("Upgrade", "foo")
    80  		w.Header().Set(fakeHopHeader, "foo")
    81  		w.Header().Add("X-Multi-Value", "foo")
    82  		w.Header().Add("X-Multi-Value", "bar")
    83  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    84  		w.WriteHeader(backendStatus)
    85  		w.Write([]byte(backendResponse))
    86  		w.Header().Set("X-Trailer", "trailer_value")
    87  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    88  	}))
    89  	defer backend.Close()
    90  	backendURL, err := url.Parse(backend.URL)
    91  	if err != nil {
    92  		t.Fatal(err)
    93  	}
    94  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    95  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    96  	frontend := httptest.NewServer(proxyHandler)
    97  	defer frontend.Close()
    98  	frontendClient := frontend.Client()
    99  
   100  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   101  	getReq.Host = "some-name"
   102  	getReq.Header.Set("Connection", "close, TE")
   103  	getReq.Header.Add("Te", "foo")
   104  	getReq.Header.Add("Te", "bar, trailers")
   105  	getReq.Header.Set("Proxy-Connection", "should be deleted")
   106  	getReq.Header.Set("Upgrade", "foo")
   107  	getReq.Close = true
   108  	res, err := frontendClient.Do(getReq)
   109  	if err != nil {
   110  		t.Fatalf("Get: %v", err)
   111  	}
   112  	if g, e := res.StatusCode, backendStatus; g != e {
   113  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   114  	}
   115  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   116  		t.Errorf("got X-Foo %q; expected %q", g, e)
   117  	}
   118  	if c := res.Header.Get(fakeHopHeader); c != "" {
   119  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   120  	}
   121  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   122  		t.Errorf("header Trailers = %q; want %q", g, e)
   123  	}
   124  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   125  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   126  	}
   127  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   128  		t.Fatalf("got %d SetCookies, want %d", g, e)
   129  	}
   130  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   131  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   132  	}
   133  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   134  		t.Errorf("unexpected cookie %q", cookie.Name)
   135  	}
   136  	bodyBytes, _ := io.ReadAll(res.Body)
   137  	if g, e := string(bodyBytes), backendResponse; g != e {
   138  		t.Errorf("got body %q; expected %q", g, e)
   139  	}
   140  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   141  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   142  	}
   143  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   144  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   145  	}
   146  
   147  	// Test that a backend failing to be reached or one which doesn't return
   148  	// a response results in a StatusBadGateway.
   149  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   150  	getReq.Close = true
   151  	res, err = frontendClient.Do(getReq)
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	res.Body.Close()
   156  	if res.StatusCode != http.StatusBadGateway {
   157  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   158  	}
   159  
   160  }
   161  
   162  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   163  // header value.
   164  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   165  	const fakeConnectionToken = "X-Fake-Connection-Token"
   166  	const backendResponse = "I am the backend"
   167  
   168  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   169  	// in the Request's Connection header.
   170  	const someConnHeader = "X-Some-Conn-Header"
   171  
   172  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   173  		if c := r.Header.Get("Connection"); c != "" {
   174  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   175  		}
   176  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   177  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   178  		}
   179  		if c := r.Header.Get(someConnHeader); c != "" {
   180  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   181  		}
   182  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   183  		w.Header().Add("Connection", someConnHeader)
   184  		w.Header().Set(someConnHeader, "should be deleted")
   185  		w.Header().Set(fakeConnectionToken, "should be deleted")
   186  		io.WriteString(w, backendResponse)
   187  	}))
   188  	defer backend.Close()
   189  	backendURL, err := url.Parse(backend.URL)
   190  	if err != nil {
   191  		t.Fatal(err)
   192  	}
   193  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   194  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   195  		proxyHandler.ServeHTTP(w, r)
   196  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   197  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   198  		}
   199  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   200  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   201  		}
   202  		c := r.Header["Connection"]
   203  		var cf []string
   204  		for _, f := range c {
   205  			for _, sf := range strings.Split(f, ",") {
   206  				if sf = strings.TrimSpace(sf); sf != "" {
   207  					cf = append(cf, sf)
   208  				}
   209  			}
   210  		}
   211  		sort.Strings(cf)
   212  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   213  		sort.Strings(expectedValues)
   214  		if !reflect.DeepEqual(cf, expectedValues) {
   215  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   216  		}
   217  	}))
   218  	defer frontend.Close()
   219  
   220  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   221  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   222  	getReq.Header.Add("Connection", someConnHeader)
   223  	getReq.Header.Set(someConnHeader, "should be deleted")
   224  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   225  	res, err := frontend.Client().Do(getReq)
   226  	if err != nil {
   227  		t.Fatalf("Get: %v", err)
   228  	}
   229  	defer res.Body.Close()
   230  	bodyBytes, err := io.ReadAll(res.Body)
   231  	if err != nil {
   232  		t.Fatalf("reading body: %v", err)
   233  	}
   234  	if got, want := string(bodyBytes), backendResponse; got != want {
   235  		t.Errorf("got body %q; want %q", got, want)
   236  	}
   237  	if c := res.Header.Get("Connection"); c != "" {
   238  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   239  	}
   240  	if c := res.Header.Get(someConnHeader); c != "" {
   241  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   242  	}
   243  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   244  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   245  	}
   246  }
   247  
   248  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   249  	// See Issue 46313.
   250  	const backendResponse = "I am the backend"
   251  
   252  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   253  	// in the Request's Connection header.
   254  	const someConnHeader = "X-Some-Conn-Header"
   255  
   256  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   257  		if c := r.Header.Values("Connection"); len(c) != 0 {
   258  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   259  		}
   260  		if c := r.Header.Get(someConnHeader); c != "" {
   261  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   262  		}
   263  		w.Header().Add("Connection", "")
   264  		w.Header().Add("Connection", someConnHeader)
   265  		w.Header().Set(someConnHeader, "should be deleted")
   266  		io.WriteString(w, backendResponse)
   267  	}))
   268  	defer backend.Close()
   269  	backendURL, err := url.Parse(backend.URL)
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   274  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   275  		proxyHandler.ServeHTTP(w, r)
   276  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   277  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   278  		}
   279  	}))
   280  	defer frontend.Close()
   281  
   282  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   283  	getReq.Header.Add("Connection", "")
   284  	getReq.Header.Add("Connection", someConnHeader)
   285  	getReq.Header.Set(someConnHeader, "should be deleted")
   286  	res, err := frontend.Client().Do(getReq)
   287  	if err != nil {
   288  		t.Fatalf("Get: %v", err)
   289  	}
   290  	defer res.Body.Close()
   291  	bodyBytes, err := io.ReadAll(res.Body)
   292  	if err != nil {
   293  		t.Fatalf("reading body: %v", err)
   294  	}
   295  	if got, want := string(bodyBytes), backendResponse; got != want {
   296  		t.Errorf("got body %q; want %q", got, want)
   297  	}
   298  	if c := res.Header.Get("Connection"); c != "" {
   299  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   300  	}
   301  	if c := res.Header.Get(someConnHeader); c != "" {
   302  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   303  	}
   304  }
   305  
   306  func TestXForwardedFor(t *testing.T) {
   307  	const prevForwardedFor = "client ip"
   308  	const backendResponse = "I am the backend"
   309  	const backendStatus = 404
   310  	const host = "some-name"
   311  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   312  		if r.Header.Get("X-Forwarded-For") == "" {
   313  			t.Errorf("didn't get X-Forwarded-For header")
   314  		}
   315  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   316  			t.Errorf("X-Forwarded-For didn't contain prior data")
   317  		}
   318  		if got, want := r.Header.Get("X-Forwarded-Host"), host; got != want {
   319  			t.Errorf("X-Forwarded-Host = %q, want %q", got, want)
   320  		}
   321  		if got, want := r.Header.Get("X-Forwarded-Proto"), "http"; got != want {
   322  			t.Errorf("X-Forwarded-Proto = %q, want %q", got, want)
   323  		}
   324  		w.WriteHeader(backendStatus)
   325  		w.Write([]byte(backendResponse))
   326  	}))
   327  	defer backend.Close()
   328  	backendURL, err := url.Parse(backend.URL)
   329  	if err != nil {
   330  		t.Fatal(err)
   331  	}
   332  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   333  	frontend := httptest.NewServer(proxyHandler)
   334  	defer frontend.Close()
   335  
   336  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   337  	getReq.Host = host
   338  	getReq.Header.Set("Connection", "close")
   339  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   340  	getReq.Close = true
   341  	res, err := frontend.Client().Do(getReq)
   342  	if err != nil {
   343  		t.Fatalf("Get: %v", err)
   344  	}
   345  	if g, e := res.StatusCode, backendStatus; g != e {
   346  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   347  	}
   348  	bodyBytes, _ := io.ReadAll(res.Body)
   349  	if g, e := string(bodyBytes), backendResponse; g != e {
   350  		t.Errorf("got body %q; expected %q", g, e)
   351  	}
   352  }
   353  
   354  func TestXForwardedProtoTLS(t *testing.T) {
   355  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   356  		if got, want := r.Header.Get("X-Forwarded-Proto"), "https"; got != want {
   357  			t.Errorf("X-Forwarded-Proto = %q, want %q", got, want)
   358  		}
   359  	}))
   360  	defer backend.Close()
   361  	backendURL, err := url.Parse(backend.URL)
   362  	if err != nil {
   363  		t.Fatal(err)
   364  	}
   365  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   366  	frontend := httptest.NewTLSServer(proxyHandler)
   367  	defer frontend.Close()
   368  
   369  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   370  	getReq.Host = "some-host"
   371  	_, err = frontend.Client().Do(getReq)
   372  	if err != nil {
   373  		t.Fatalf("Get: %v", err)
   374  	}
   375  }
   376  
   377  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   378  func TestXForwardedFor_Omit(t *testing.T) {
   379  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   380  		for _, h := range []string{"X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto"} {
   381  			if v := r.Header.Get(h); v != "" {
   382  				t.Errorf("got %v header: %q", h, v)
   383  			}
   384  		}
   385  		w.Write([]byte("hi"))
   386  	}))
   387  	defer backend.Close()
   388  	backendURL, err := url.Parse(backend.URL)
   389  	if err != nil {
   390  		t.Fatal(err)
   391  	}
   392  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   393  	frontend := httptest.NewServer(proxyHandler)
   394  	defer frontend.Close()
   395  
   396  	oldDirector := proxyHandler.Director
   397  	proxyHandler.Director = func(r *http.Request) {
   398  		r.Header["X-Forwarded-For"] = nil
   399  		r.Header["X-Forwarded-Host"] = nil
   400  		r.Header["X-Forwarded-Proto"] = nil
   401  		oldDirector(r)
   402  	}
   403  
   404  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   405  	getReq.Host = "some-name"
   406  	getReq.Close = true
   407  	res, err := frontend.Client().Do(getReq)
   408  	if err != nil {
   409  		t.Fatalf("Get: %v", err)
   410  	}
   411  	res.Body.Close()
   412  }
   413  
   414  func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
   415  	headers := []string{
   416  		"Forwarded",
   417  		"X-Forwarded-For",
   418  		"X-Forwarded-Host",
   419  		"X-Forwarded-Proto",
   420  	}
   421  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   422  		for _, h := range headers {
   423  			if v := r.Header.Get(h); v != "" {
   424  				t.Errorf("got %v header: %q", h, v)
   425  			}
   426  		}
   427  	}))
   428  	defer backend.Close()
   429  	backendURL, err := url.Parse(backend.URL)
   430  	if err != nil {
   431  		t.Fatal(err)
   432  	}
   433  	proxyHandler := &ReverseProxy{
   434  		Rewrite: func(r *ProxyRequest) {
   435  			r.SetURL(backendURL)
   436  		},
   437  	}
   438  	frontend := httptest.NewServer(proxyHandler)
   439  	defer frontend.Close()
   440  
   441  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   442  	getReq.Host = "some-name"
   443  	getReq.Close = true
   444  	for _, h := range headers {
   445  		getReq.Header.Set(h, "x")
   446  	}
   447  	res, err := frontend.Client().Do(getReq)
   448  	if err != nil {
   449  		t.Fatalf("Get: %v", err)
   450  	}
   451  	res.Body.Close()
   452  }
   453  
   454  var proxyQueryTests = []struct {
   455  	baseSuffix string // suffix to add to backend URL
   456  	reqSuffix  string // suffix to add to frontend's request URL
   457  	want       string // what backend should see for final request URL (without ?)
   458  }{
   459  	{"", "", ""},
   460  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   461  	{"", "?us=er", "us=er"},
   462  	{"?sta=tic", "", "sta=tic"},
   463  }
   464  
   465  func TestReverseProxyQuery(t *testing.T) {
   466  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   467  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   468  		w.Write([]byte("hi"))
   469  	}))
   470  	defer backend.Close()
   471  
   472  	for i, tt := range proxyQueryTests {
   473  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   474  		if err != nil {
   475  			t.Fatal(err)
   476  		}
   477  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   478  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   479  		req.Close = true
   480  		res, err := frontend.Client().Do(req)
   481  		if err != nil {
   482  			t.Fatalf("%d. Get: %v", i, err)
   483  		}
   484  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   485  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   486  		}
   487  		res.Body.Close()
   488  		frontend.Close()
   489  	}
   490  }
   491  
   492  func TestReverseProxyFlushInterval(t *testing.T) {
   493  	const expected = "hi"
   494  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   495  		w.Write([]byte(expected))
   496  	}))
   497  	defer backend.Close()
   498  
   499  	backendURL, err := url.Parse(backend.URL)
   500  	if err != nil {
   501  		t.Fatal(err)
   502  	}
   503  
   504  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   505  	proxyHandler.FlushInterval = time.Microsecond
   506  
   507  	frontend := httptest.NewServer(proxyHandler)
   508  	defer frontend.Close()
   509  
   510  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   511  	req.Close = true
   512  	res, err := frontend.Client().Do(req)
   513  	if err != nil {
   514  		t.Fatalf("Get: %v", err)
   515  	}
   516  	defer res.Body.Close()
   517  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   518  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   519  	}
   520  }
   521  
   522  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   523  	const expected = "hi"
   524  	stopCh := make(chan struct{})
   525  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   526  		w.Header().Add("MyHeader", expected)
   527  		w.WriteHeader(200)
   528  		w.(http.Flusher).Flush()
   529  		<-stopCh
   530  	}))
   531  	defer backend.Close()
   532  	defer close(stopCh)
   533  
   534  	backendURL, err := url.Parse(backend.URL)
   535  	if err != nil {
   536  		t.Fatal(err)
   537  	}
   538  
   539  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   540  	proxyHandler.FlushInterval = time.Microsecond
   541  
   542  	frontend := httptest.NewServer(proxyHandler)
   543  	defer frontend.Close()
   544  
   545  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   546  	req.Close = true
   547  
   548  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   549  	defer cancel()
   550  	req = req.WithContext(ctx)
   551  
   552  	res, err := frontend.Client().Do(req)
   553  	if err != nil {
   554  		t.Fatalf("Get: %v", err)
   555  	}
   556  	defer res.Body.Close()
   557  
   558  	if res.Header.Get("MyHeader") != expected {
   559  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   560  	}
   561  }
   562  
   563  func TestReverseProxyCancellation(t *testing.T) {
   564  	const backendResponse = "I am the backend"
   565  
   566  	reqInFlight := make(chan struct{})
   567  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   568  		close(reqInFlight) // cause the client to cancel its request
   569  
   570  		select {
   571  		case <-time.After(10 * time.Second):
   572  			// Note: this should only happen in broken implementations, and the
   573  			// closenotify case should be instantaneous.
   574  			t.Error("Handler never saw CloseNotify")
   575  			return
   576  		case <-w.(http.CloseNotifier).CloseNotify():
   577  		}
   578  
   579  		w.WriteHeader(http.StatusOK)
   580  		w.Write([]byte(backendResponse))
   581  	}))
   582  
   583  	defer backend.Close()
   584  
   585  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   586  
   587  	backendURL, err := url.Parse(backend.URL)
   588  	if err != nil {
   589  		t.Fatal(err)
   590  	}
   591  
   592  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   593  
   594  	// Discards errors of the form:
   595  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   596  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   597  
   598  	frontend := httptest.NewServer(proxyHandler)
   599  	defer frontend.Close()
   600  	frontendClient := frontend.Client()
   601  
   602  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   603  	go func() {
   604  		<-reqInFlight
   605  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   606  	}()
   607  	res, err := frontendClient.Do(getReq)
   608  	if res != nil {
   609  		t.Errorf("got response %v; want nil", res.Status)
   610  	}
   611  	if err == nil {
   612  		// This should be an error like:
   613  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   614  		//    use of closed network connection
   615  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   616  	}
   617  }
   618  
   619  func req(t *testing.T, v string) *http.Request {
   620  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	return req
   625  }
   626  
   627  // Issue 12344
   628  func TestNilBody(t *testing.T) {
   629  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   630  		w.Write([]byte("hi"))
   631  	}))
   632  	defer backend.Close()
   633  
   634  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   635  		backURL, _ := url.Parse(backend.URL)
   636  		rp := NewSingleHostReverseProxy(backURL)
   637  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   638  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   639  		rp.ServeHTTP(w, r)
   640  	}))
   641  	defer frontend.Close()
   642  
   643  	res, err := http.Get(frontend.URL)
   644  	if err != nil {
   645  		t.Fatal(err)
   646  	}
   647  	defer res.Body.Close()
   648  	slurp, err := io.ReadAll(res.Body)
   649  	if err != nil {
   650  		t.Fatal(err)
   651  	}
   652  	if string(slurp) != "hi" {
   653  		t.Errorf("Got %q; want %q", slurp, "hi")
   654  	}
   655  }
   656  
   657  // Issue 15524
   658  func TestUserAgentHeader(t *testing.T) {
   659  	var gotUA string
   660  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   661  		gotUA = r.Header.Get("User-Agent")
   662  	}))
   663  	defer backend.Close()
   664  	backendURL, err := url.Parse(backend.URL)
   665  	if err != nil {
   666  		t.Fatal(err)
   667  	}
   668  
   669  	proxyHandler := new(ReverseProxy)
   670  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   671  	proxyHandler.Director = func(req *http.Request) {
   672  		req.URL = backendURL
   673  	}
   674  	frontend := httptest.NewServer(proxyHandler)
   675  	defer frontend.Close()
   676  	frontendClient := frontend.Client()
   677  
   678  	for _, sentUA := range []string{"explicit UA", ""} {
   679  		getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   680  		getReq.Header.Set("User-Agent", sentUA)
   681  		getReq.Close = true
   682  		res, err := frontendClient.Do(getReq)
   683  		if err != nil {
   684  			t.Fatalf("Get: %v", err)
   685  		}
   686  		res.Body.Close()
   687  		if got, want := gotUA, sentUA; got != want {
   688  			t.Errorf("got forwarded User-Agent %q, want %q", got, want)
   689  		}
   690  	}
   691  }
   692  
   693  type bufferPool struct {
   694  	get func() []byte
   695  	put func([]byte)
   696  }
   697  
   698  func (bp bufferPool) Get() []byte  { return bp.get() }
   699  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   700  
   701  func TestReverseProxyGetPutBuffer(t *testing.T) {
   702  	const msg = "hi"
   703  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   704  		io.WriteString(w, msg)
   705  	}))
   706  	defer backend.Close()
   707  
   708  	backendURL, err := url.Parse(backend.URL)
   709  	if err != nil {
   710  		t.Fatal(err)
   711  	}
   712  
   713  	var (
   714  		mu  sync.Mutex
   715  		log []string
   716  	)
   717  	addLog := func(event string) {
   718  		mu.Lock()
   719  		defer mu.Unlock()
   720  		log = append(log, event)
   721  	}
   722  	rp := NewSingleHostReverseProxy(backendURL)
   723  	const size = 1234
   724  	rp.BufferPool = bufferPool{
   725  		get: func() []byte {
   726  			addLog("getBuf")
   727  			return make([]byte, size)
   728  		},
   729  		put: func(p []byte) {
   730  			addLog("putBuf-" + strconv.Itoa(len(p)))
   731  		},
   732  	}
   733  	frontend := httptest.NewServer(rp)
   734  	defer frontend.Close()
   735  
   736  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   737  	req.Close = true
   738  	res, err := frontend.Client().Do(req)
   739  	if err != nil {
   740  		t.Fatalf("Get: %v", err)
   741  	}
   742  	slurp, err := io.ReadAll(res.Body)
   743  	res.Body.Close()
   744  	if err != nil {
   745  		t.Fatalf("reading body: %v", err)
   746  	}
   747  	if string(slurp) != msg {
   748  		t.Errorf("msg = %q; want %q", slurp, msg)
   749  	}
   750  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   751  	mu.Lock()
   752  	defer mu.Unlock()
   753  	if !reflect.DeepEqual(log, wantLog) {
   754  		t.Errorf("Log events = %q; want %q", log, wantLog)
   755  	}
   756  }
   757  
   758  func TestReverseProxy_Post(t *testing.T) {
   759  	const backendResponse = "I am the backend"
   760  	const backendStatus = 200
   761  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   762  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   763  		slurp, err := io.ReadAll(r.Body)
   764  		if err != nil {
   765  			t.Errorf("Backend body read = %v", err)
   766  		}
   767  		if len(slurp) != len(requestBody) {
   768  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   769  		}
   770  		if !bytes.Equal(slurp, requestBody) {
   771  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   772  		}
   773  		w.Write([]byte(backendResponse))
   774  	}))
   775  	defer backend.Close()
   776  	backendURL, err := url.Parse(backend.URL)
   777  	if err != nil {
   778  		t.Fatal(err)
   779  	}
   780  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   781  	frontend := httptest.NewServer(proxyHandler)
   782  	defer frontend.Close()
   783  
   784  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   785  	res, err := frontend.Client().Do(postReq)
   786  	if err != nil {
   787  		t.Fatalf("Do: %v", err)
   788  	}
   789  	if g, e := res.StatusCode, backendStatus; g != e {
   790  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   791  	}
   792  	bodyBytes, _ := io.ReadAll(res.Body)
   793  	if g, e := string(bodyBytes), backendResponse; g != e {
   794  		t.Errorf("got body %q; expected %q", g, e)
   795  	}
   796  }
   797  
   798  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   799  
   800  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   801  	return fn(req)
   802  }
   803  
   804  // Issue 16036: send a Request with a nil Body when possible
   805  func TestReverseProxy_NilBody(t *testing.T) {
   806  	backendURL, _ := url.Parse("http://fake.tld/")
   807  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   808  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   809  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   810  		if req.Body != nil {
   811  			t.Error("Body != nil; want a nil Body")
   812  		}
   813  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   814  	})
   815  	frontend := httptest.NewServer(proxyHandler)
   816  	defer frontend.Close()
   817  
   818  	res, err := frontend.Client().Get(frontend.URL)
   819  	if err != nil {
   820  		t.Fatal(err)
   821  	}
   822  	defer res.Body.Close()
   823  	if res.StatusCode != 502 {
   824  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   825  	}
   826  }
   827  
   828  // Issue 33142: always allocate the request headers
   829  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   830  	proxyHandler := new(ReverseProxy)
   831  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   832  	proxyHandler.Director = func(*http.Request) {}     // noop
   833  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   834  		if req.Header == nil {
   835  			t.Error("Header == nil; want a non-nil Header")
   836  		}
   837  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   838  	})
   839  
   840  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   841  		Method:     "GET",
   842  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   843  		Proto:      "HTTP/1.0",
   844  		ProtoMajor: 1,
   845  	})
   846  }
   847  
   848  // Issue 14237. Test ModifyResponse and that an error from it
   849  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   850  func TestReverseProxyModifyResponse(t *testing.T) {
   851  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   852  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   853  	}))
   854  	defer backendServer.Close()
   855  
   856  	rpURL, _ := url.Parse(backendServer.URL)
   857  	rproxy := NewSingleHostReverseProxy(rpURL)
   858  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   859  	rproxy.ModifyResponse = func(resp *http.Response) error {
   860  		if resp.Header.Get("X-Hit-Mod") != "true" {
   861  			return fmt.Errorf("tried to by-pass proxy")
   862  		}
   863  		return nil
   864  	}
   865  
   866  	frontendProxy := httptest.NewServer(rproxy)
   867  	defer frontendProxy.Close()
   868  
   869  	tests := []struct {
   870  		url      string
   871  		wantCode int
   872  	}{
   873  		{frontendProxy.URL + "/mod", http.StatusOK},
   874  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   875  	}
   876  
   877  	for i, tt := range tests {
   878  		resp, err := http.Get(tt.url)
   879  		if err != nil {
   880  			t.Fatalf("failed to reach proxy: %v", err)
   881  		}
   882  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   883  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   884  		}
   885  		resp.Body.Close()
   886  	}
   887  }
   888  
   889  type failingRoundTripper struct{}
   890  
   891  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   892  	return nil, errors.New("some error")
   893  }
   894  
   895  type staticResponseRoundTripper struct{ res *http.Response }
   896  
   897  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   898  	return rt.res, nil
   899  }
   900  
   901  func TestReverseProxyErrorHandler(t *testing.T) {
   902  	tests := []struct {
   903  		name           string
   904  		wantCode       int
   905  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   906  		transport      http.RoundTripper // defaults to failingRoundTripper
   907  		modifyResponse func(*http.Response) error
   908  	}{
   909  		{
   910  			name:     "default",
   911  			wantCode: http.StatusBadGateway,
   912  		},
   913  		{
   914  			name:         "errorhandler",
   915  			wantCode:     http.StatusTeapot,
   916  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   917  		},
   918  		{
   919  			name: "modifyresponse_noerr",
   920  			transport: staticResponseRoundTripper{
   921  				&http.Response{StatusCode: 345, Body: http.NoBody},
   922  			},
   923  			modifyResponse: func(res *http.Response) error {
   924  				res.StatusCode++
   925  				return nil
   926  			},
   927  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   928  			wantCode:     346,
   929  		},
   930  		{
   931  			name: "modifyresponse_err",
   932  			transport: staticResponseRoundTripper{
   933  				&http.Response{StatusCode: 345, Body: http.NoBody},
   934  			},
   935  			modifyResponse: func(res *http.Response) error {
   936  				res.StatusCode++
   937  				return errors.New("some error to trigger errorHandler")
   938  			},
   939  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   940  			wantCode:     http.StatusTeapot,
   941  		},
   942  	}
   943  
   944  	for _, tt := range tests {
   945  		t.Run(tt.name, func(t *testing.T) {
   946  			target := &url.URL{
   947  				Scheme: "http",
   948  				Host:   "dummy.tld",
   949  				Path:   "/",
   950  			}
   951  			rproxy := NewSingleHostReverseProxy(target)
   952  			rproxy.Transport = tt.transport
   953  			rproxy.ModifyResponse = tt.modifyResponse
   954  			if rproxy.Transport == nil {
   955  				rproxy.Transport = failingRoundTripper{}
   956  			}
   957  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   958  			if tt.errorHandler != nil {
   959  				rproxy.ErrorHandler = tt.errorHandler
   960  			}
   961  			frontendProxy := httptest.NewServer(rproxy)
   962  			defer frontendProxy.Close()
   963  
   964  			resp, err := http.Get(frontendProxy.URL + "/test")
   965  			if err != nil {
   966  				t.Fatalf("failed to reach proxy: %v", err)
   967  			}
   968  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   969  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   970  			}
   971  			resp.Body.Close()
   972  		})
   973  	}
   974  }
   975  
   976  // Issue 16659: log errors from short read
   977  func TestReverseProxy_CopyBuffer(t *testing.T) {
   978  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   979  		out := "this call was relayed by the reverse proxy"
   980  		// Coerce a wrong content length to induce io.UnexpectedEOF
   981  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   982  		fmt.Fprintln(w, out)
   983  	}))
   984  	defer backendServer.Close()
   985  
   986  	rpURL, err := url.Parse(backendServer.URL)
   987  	if err != nil {
   988  		t.Fatal(err)
   989  	}
   990  
   991  	var proxyLog bytes.Buffer
   992  	rproxy := NewSingleHostReverseProxy(rpURL)
   993  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   994  	donec := make(chan bool, 1)
   995  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   996  		defer func() { donec <- true }()
   997  		rproxy.ServeHTTP(w, r)
   998  	}))
   999  	defer frontendProxy.Close()
  1000  
  1001  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
  1002  		t.Fatalf("want non-nil error")
  1003  	}
  1004  	// The race detector complains about the proxyLog usage in logf in copyBuffer
  1005  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
  1006  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
  1007  	// continue after Get.
  1008  	<-donec
  1009  
  1010  	expected := []string{
  1011  		"EOF",
  1012  		"read",
  1013  	}
  1014  	for _, phrase := range expected {
  1015  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
  1016  			t.Errorf("expected log to contain phrase %q", phrase)
  1017  		}
  1018  	}
  1019  }
  1020  
  1021  type staticTransport struct {
  1022  	res *http.Response
  1023  }
  1024  
  1025  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  1026  	return t.res, nil
  1027  }
  1028  
  1029  func BenchmarkServeHTTP(b *testing.B) {
  1030  	res := &http.Response{
  1031  		StatusCode: 200,
  1032  		Body:       io.NopCloser(strings.NewReader("")),
  1033  	}
  1034  	proxy := &ReverseProxy{
  1035  		Director:  func(*http.Request) {},
  1036  		Transport: &staticTransport{res},
  1037  	}
  1038  
  1039  	w := httptest.NewRecorder()
  1040  	r := httptest.NewRequest("GET", "/", nil)
  1041  
  1042  	b.ReportAllocs()
  1043  	for i := 0; i < b.N; i++ {
  1044  		proxy.ServeHTTP(w, r)
  1045  	}
  1046  }
  1047  
  1048  func TestServeHTTPDeepCopy(t *testing.T) {
  1049  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1050  		w.Write([]byte("Hello Gopher!"))
  1051  	}))
  1052  	defer backend.Close()
  1053  	backendURL, err := url.Parse(backend.URL)
  1054  	if err != nil {
  1055  		t.Fatal(err)
  1056  	}
  1057  
  1058  	type result struct {
  1059  		before, after string
  1060  	}
  1061  
  1062  	resultChan := make(chan result, 1)
  1063  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1064  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1065  		before := r.URL.String()
  1066  		proxyHandler.ServeHTTP(w, r)
  1067  		after := r.URL.String()
  1068  		resultChan <- result{before: before, after: after}
  1069  	}))
  1070  	defer frontend.Close()
  1071  
  1072  	want := result{before: "/", after: "/"}
  1073  
  1074  	res, err := frontend.Client().Get(frontend.URL)
  1075  	if err != nil {
  1076  		t.Fatalf("Do: %v", err)
  1077  	}
  1078  	res.Body.Close()
  1079  
  1080  	got := <-resultChan
  1081  	if got != want {
  1082  		t.Errorf("got = %+v; want = %+v", got, want)
  1083  	}
  1084  }
  1085  
  1086  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1087  // before any mutations.
  1088  func TestClonesRequestHeaders(t *testing.T) {
  1089  	log.SetOutput(io.Discard)
  1090  	defer log.SetOutput(os.Stderr)
  1091  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1092  	req.RemoteAddr = "1.2.3.4:56789"
  1093  	rp := &ReverseProxy{
  1094  		Director: func(req *http.Request) {
  1095  			req.Header.Set("From-Director", "1")
  1096  		},
  1097  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1098  			if v := req.Header.Get("From-Director"); v != "1" {
  1099  				t.Errorf("From-Directory value = %q; want 1", v)
  1100  			}
  1101  			return nil, io.EOF
  1102  		}),
  1103  	}
  1104  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1105  
  1106  	for _, h := range []string{
  1107  		"From-Director",
  1108  		"X-Forwarded-For",
  1109  		"X-Forwarded-Host",
  1110  		"X-Forwarded-Proto",
  1111  	} {
  1112  		if req.Header.Get(h) != "" {
  1113  			t.Errorf("%v header mutation modified caller's request", h)
  1114  		}
  1115  	}
  1116  }
  1117  
  1118  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1119  
  1120  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1121  	return fn(req)
  1122  }
  1123  
  1124  func TestModifyResponseClosesBody(t *testing.T) {
  1125  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1126  	req.RemoteAddr = "1.2.3.4:56789"
  1127  	closeCheck := new(checkCloser)
  1128  	logBuf := new(strings.Builder)
  1129  	outErr := errors.New("ModifyResponse error")
  1130  	rp := &ReverseProxy{
  1131  		Director: func(req *http.Request) {},
  1132  		Transport: &staticTransport{&http.Response{
  1133  			StatusCode: 200,
  1134  			Body:       closeCheck,
  1135  		}},
  1136  		ErrorLog: log.New(logBuf, "", 0),
  1137  		ModifyResponse: func(*http.Response) error {
  1138  			return outErr
  1139  		},
  1140  	}
  1141  	rec := httptest.NewRecorder()
  1142  	rp.ServeHTTP(rec, req)
  1143  	res := rec.Result()
  1144  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1145  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1146  	}
  1147  	if !closeCheck.closed {
  1148  		t.Errorf("body should have been closed")
  1149  	}
  1150  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1151  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1152  	}
  1153  }
  1154  
  1155  type checkCloser struct {
  1156  	closed bool
  1157  }
  1158  
  1159  func (cc *checkCloser) Close() error {
  1160  	cc.closed = true
  1161  	return nil
  1162  }
  1163  
  1164  func (cc *checkCloser) Read(b []byte) (int, error) {
  1165  	return len(b), nil
  1166  }
  1167  
  1168  // Issue 23643: panic on body copy error
  1169  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1170  	log.SetOutput(io.Discard)
  1171  	defer log.SetOutput(os.Stderr)
  1172  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1173  		out := "this call was relayed by the reverse proxy"
  1174  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1175  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1176  		fmt.Fprintln(w, out)
  1177  	}))
  1178  	defer backendServer.Close()
  1179  
  1180  	rpURL, err := url.Parse(backendServer.URL)
  1181  	if err != nil {
  1182  		t.Fatal(err)
  1183  	}
  1184  
  1185  	rproxy := NewSingleHostReverseProxy(rpURL)
  1186  
  1187  	// Ensure that the handler panics when the body read encounters an
  1188  	// io.ErrUnexpectedEOF
  1189  	defer func() {
  1190  		err := recover()
  1191  		if err == nil {
  1192  			t.Fatal("handler should have panicked")
  1193  		}
  1194  		if err != http.ErrAbortHandler {
  1195  			t.Fatal("expected ErrAbortHandler, got", err)
  1196  		}
  1197  	}()
  1198  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1199  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1200  }
  1201  
  1202  // Issue #46866: panic without closing incoming request body causes a panic
  1203  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1204  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1205  		out := "this call was relayed by the reverse proxy"
  1206  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1207  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1208  		fmt.Fprintln(w, out)
  1209  	}))
  1210  	defer backend.Close()
  1211  	backendURL, err := url.Parse(backend.URL)
  1212  	if err != nil {
  1213  		t.Fatal(err)
  1214  	}
  1215  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1216  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1217  	frontend := httptest.NewServer(proxyHandler)
  1218  	defer frontend.Close()
  1219  	frontendClient := frontend.Client()
  1220  
  1221  	var wg sync.WaitGroup
  1222  	for i := 0; i < 2; i++ {
  1223  		wg.Add(1)
  1224  		go func() {
  1225  			defer wg.Done()
  1226  			for j := 0; j < 10; j++ {
  1227  				const reqLen = 6 * 1024 * 1024
  1228  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1229  				req.ContentLength = reqLen
  1230  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1231  				if resp != nil {
  1232  					io.Copy(io.Discard, resp.Body)
  1233  					resp.Body.Close()
  1234  				}
  1235  			}
  1236  		}()
  1237  	}
  1238  	wg.Wait()
  1239  }
  1240  
  1241  func TestSelectFlushInterval(t *testing.T) {
  1242  	tests := []struct {
  1243  		name string
  1244  		p    *ReverseProxy
  1245  		res  *http.Response
  1246  		want time.Duration
  1247  	}{
  1248  		{
  1249  			name: "default",
  1250  			res:  &http.Response{},
  1251  			p:    &ReverseProxy{FlushInterval: 123},
  1252  			want: 123,
  1253  		},
  1254  		{
  1255  			name: "server-sent events overrides non-zero",
  1256  			res: &http.Response{
  1257  				Header: http.Header{
  1258  					"Content-Type": {"text/event-stream"},
  1259  				},
  1260  			},
  1261  			p:    &ReverseProxy{FlushInterval: 123},
  1262  			want: -1,
  1263  		},
  1264  		{
  1265  			name: "server-sent events overrides zero",
  1266  			res: &http.Response{
  1267  				Header: http.Header{
  1268  					"Content-Type": {"text/event-stream"},
  1269  				},
  1270  			},
  1271  			p:    &ReverseProxy{FlushInterval: 0},
  1272  			want: -1,
  1273  		},
  1274  		{
  1275  			name: "server-sent events with media-type parameters overrides non-zero",
  1276  			res: &http.Response{
  1277  				Header: http.Header{
  1278  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1279  				},
  1280  			},
  1281  			p:    &ReverseProxy{FlushInterval: 123},
  1282  			want: -1,
  1283  		},
  1284  		{
  1285  			name: "server-sent events with media-type parameters overrides zero",
  1286  			res: &http.Response{
  1287  				Header: http.Header{
  1288  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1289  				},
  1290  			},
  1291  			p:    &ReverseProxy{FlushInterval: 0},
  1292  			want: -1,
  1293  		},
  1294  		{
  1295  			name: "Content-Length: -1, overrides non-zero",
  1296  			res: &http.Response{
  1297  				ContentLength: -1,
  1298  			},
  1299  			p:    &ReverseProxy{FlushInterval: 123},
  1300  			want: -1,
  1301  		},
  1302  		{
  1303  			name: "Content-Length: -1, overrides zero",
  1304  			res: &http.Response{
  1305  				ContentLength: -1,
  1306  			},
  1307  			p:    &ReverseProxy{FlushInterval: 0},
  1308  			want: -1,
  1309  		},
  1310  	}
  1311  	for _, tt := range tests {
  1312  		t.Run(tt.name, func(t *testing.T) {
  1313  			got := tt.p.flushInterval(tt.res)
  1314  			if got != tt.want {
  1315  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1316  			}
  1317  		})
  1318  	}
  1319  }
  1320  
  1321  func TestReverseProxyWebSocket(t *testing.T) {
  1322  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1323  		if upgradeType(r.Header) != "websocket" {
  1324  			t.Error("unexpected backend request")
  1325  			http.Error(w, "unexpected request", 400)
  1326  			return
  1327  		}
  1328  		c, _, err := w.(http.Hijacker).Hijack()
  1329  		if err != nil {
  1330  			t.Error(err)
  1331  			return
  1332  		}
  1333  		defer c.Close()
  1334  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1335  		bs := bufio.NewScanner(c)
  1336  		if !bs.Scan() {
  1337  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1338  			return
  1339  		}
  1340  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1341  	}))
  1342  	defer backendServer.Close()
  1343  
  1344  	backURL, _ := url.Parse(backendServer.URL)
  1345  	rproxy := NewSingleHostReverseProxy(backURL)
  1346  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1347  	rproxy.ModifyResponse = func(res *http.Response) error {
  1348  		res.Header.Add("X-Modified", "true")
  1349  		return nil
  1350  	}
  1351  
  1352  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1353  		rw.Header().Set("X-Header", "X-Value")
  1354  		rproxy.ServeHTTP(rw, req)
  1355  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1356  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1357  		}
  1358  	})
  1359  
  1360  	frontendProxy := httptest.NewServer(handler)
  1361  	defer frontendProxy.Close()
  1362  
  1363  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1364  	req.Header.Set("Connection", "Upgrade")
  1365  	req.Header.Set("Upgrade", "websocket")
  1366  
  1367  	c := frontendProxy.Client()
  1368  	res, err := c.Do(req)
  1369  	if err != nil {
  1370  		t.Fatal(err)
  1371  	}
  1372  	if res.StatusCode != 101 {
  1373  		t.Fatalf("status = %v; want 101", res.Status)
  1374  	}
  1375  
  1376  	got := res.Header.Get("X-Header")
  1377  	want := "X-Value"
  1378  	if got != want {
  1379  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1380  	}
  1381  
  1382  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1383  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1384  	}
  1385  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1386  	if !ok {
  1387  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1388  	}
  1389  	defer rwc.Close()
  1390  
  1391  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1392  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1393  	}
  1394  
  1395  	io.WriteString(rwc, "Hello\n")
  1396  	bs := bufio.NewScanner(rwc)
  1397  	if !bs.Scan() {
  1398  		t.Fatalf("Scan: %v", bs.Err())
  1399  	}
  1400  	got = bs.Text()
  1401  	want = `backend got "Hello"`
  1402  	if got != want {
  1403  		t.Errorf("got %#q, want %#q", got, want)
  1404  	}
  1405  }
  1406  
// TestReverseProxyWebSocketCancellation proxies a "websocket" upgrade whose
// request context is canceled while the backend is still streaming.
//
// The backend writes up to n numbered responses, one per second, followed by
// a terminal message. The client closes triggerCancelCh as soon as it reads
// the first response, which makes the frontend handler cancel the context of
// the proxied request. The test passes if the client's read loop then ends
// with io.EOF, and fails if the terminal message ever reaches the client.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	n := 5
	// Closed by the client read loop below to request cancellation. Once
	// closed, every receive on it succeeds immediately.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// Note: no trailing newline, unlike the numbered responses.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		// Take over the raw connection to speak the upgraded protocol.
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first line before streaming responses.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// A write error is only reported if cancellation has not
				// been triggered; the receive succeeds only once the
				// channel has been closed.
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			// Same rule as above: ignore the error once cancellation has
			// been triggered.
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		// Run the proxy under a cancelable context; the goroutine cancels
		// it as soon as triggerCancelCh is closed.
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	// The test needs a writable body to send its first message.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read loop.

	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		// terminalMsg has no trailing newline, so ReadString would return
		// it together with an error; it must be checked before the error
		// cases or it would be mistaken for a clean end of stream.
		case line == terminalMsg: // this case before "err == io.EOF"
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			// The stream ended without the terminal message: success.
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0): // We've gotten the first response back
			// Let's trigger a cancel.
			close(triggerCancelCh)
		}
	}
}
  1537  
  1538  func TestUnannouncedTrailer(t *testing.T) {
  1539  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1540  		w.WriteHeader(http.StatusOK)
  1541  		w.(http.Flusher).Flush()
  1542  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1543  	}))
  1544  	defer backend.Close()
  1545  	backendURL, err := url.Parse(backend.URL)
  1546  	if err != nil {
  1547  		t.Fatal(err)
  1548  	}
  1549  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1550  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1551  	frontend := httptest.NewServer(proxyHandler)
  1552  	defer frontend.Close()
  1553  	frontendClient := frontend.Client()
  1554  
  1555  	res, err := frontendClient.Get(frontend.URL)
  1556  	if err != nil {
  1557  		t.Fatalf("Get: %v", err)
  1558  	}
  1559  
  1560  	io.ReadAll(res.Body)
  1561  
  1562  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1563  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1564  	}
  1565  
  1566  }
  1567  
  1568  func TestSetURL(t *testing.T) {
  1569  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1570  		w.Write([]byte(r.Host))
  1571  	}))
  1572  	defer backend.Close()
  1573  	backendURL, err := url.Parse(backend.URL)
  1574  	if err != nil {
  1575  		t.Fatal(err)
  1576  	}
  1577  	proxyHandler := &ReverseProxy{
  1578  		Rewrite: func(r *ProxyRequest) {
  1579  			r.SetURL(backendURL)
  1580  		},
  1581  	}
  1582  	frontend := httptest.NewServer(proxyHandler)
  1583  	defer frontend.Close()
  1584  	frontendClient := frontend.Client()
  1585  
  1586  	res, err := frontendClient.Get(frontend.URL)
  1587  	if err != nil {
  1588  		t.Fatalf("Get: %v", err)
  1589  	}
  1590  	defer res.Body.Close()
  1591  
  1592  	body, err := io.ReadAll(res.Body)
  1593  	if err != nil {
  1594  		t.Fatalf("Reading body: %v", err)
  1595  	}
  1596  
  1597  	if got, want := string(body), backendURL.Host; got != want {
  1598  		t.Errorf("backend got Host %q, want %q", got, want)
  1599  	}
  1600  }
  1601  
  1602  func TestSingleJoinSlash(t *testing.T) {
  1603  	tests := []struct {
  1604  		slasha   string
  1605  		slashb   string
  1606  		expected string
  1607  	}{
  1608  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1609  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1610  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1611  		{"https://www.google.com", "", "https://www.google.com/"},
  1612  		{"", "favicon.ico", "/favicon.ico"},
  1613  	}
  1614  	for _, tt := range tests {
  1615  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1616  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1617  				tt.slasha,
  1618  				tt.slashb,
  1619  				tt.expected,
  1620  				got)
  1621  		}
  1622  	}
  1623  }
  1624  
  1625  func TestJoinURLPath(t *testing.T) {
  1626  	tests := []struct {
  1627  		a        *url.URL
  1628  		b        *url.URL
  1629  		wantPath string
  1630  		wantRaw  string
  1631  	}{
  1632  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1633  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1634  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1635  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1636  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1637  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1638  	}
  1639  
  1640  	for _, tt := range tests {
  1641  		p, rp := joinURLPath(tt.a, tt.b)
  1642  		if p != tt.wantPath || rp != tt.wantRaw {
  1643  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1644  				tt.a.Path, tt.a.RawPath,
  1645  				tt.b.Path, tt.b.RawPath,
  1646  				tt.wantPath, tt.wantRaw,
  1647  				p, rp)
  1648  		}
  1649  	}
  1650  }
  1651  
  1652  func TestReverseProxyRewriteReplacesOut(t *testing.T) {
  1653  	const content = "response_content"
  1654  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1655  		w.Write([]byte(content))
  1656  	}))
  1657  	defer backend.Close()
  1658  	proxyHandler := &ReverseProxy{
  1659  		Rewrite: func(r *ProxyRequest) {
  1660  			r.Out, _ = http.NewRequest("GET", backend.URL, nil)
  1661  		},
  1662  	}
  1663  	frontend := httptest.NewServer(proxyHandler)
  1664  	defer frontend.Close()
  1665  
  1666  	res, err := frontend.Client().Get(frontend.URL)
  1667  	if err != nil {
  1668  		t.Fatalf("Get: %v", err)
  1669  	}
  1670  	defer res.Body.Close()
  1671  	body, _ := io.ReadAll(res.Body)
  1672  	if got, want := string(body), content; got != want {
  1673  		t.Errorf("got response %q, want %q", got, want)
  1674  	}
  1675  }
  1676  
  1677  func Test1xxResponses(t *testing.T) {
  1678  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1679  		h := w.Header()
  1680  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1681  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1682  		w.WriteHeader(http.StatusEarlyHints)
  1683  
  1684  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1685  		w.WriteHeader(http.StatusProcessing)
  1686  
  1687  		w.Write([]byte("Hello"))
  1688  	}))
  1689  	defer backend.Close()
  1690  	backendURL, err := url.Parse(backend.URL)
  1691  	if err != nil {
  1692  		t.Fatal(err)
  1693  	}
  1694  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1695  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1696  	frontend := httptest.NewServer(proxyHandler)
  1697  	defer frontend.Close()
  1698  	frontendClient := frontend.Client()
  1699  
  1700  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1701  		t.Helper()
  1702  
  1703  		if len(expected) != len(got) {
  1704  			t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
  1705  		}
  1706  
  1707  		for i := range expected {
  1708  			if i >= len(got) {
  1709  				t.Errorf("Expected %q link header; got nothing", expected[i])
  1710  
  1711  				continue
  1712  			}
  1713  
  1714  			if expected[i] != got[i] {
  1715  				t.Errorf("Expected %q link header; got %q", expected[i], got[i])
  1716  			}
  1717  		}
  1718  	}
  1719  
  1720  	var respCounter uint8
  1721  	trace := &httptrace.ClientTrace{
  1722  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1723  			switch code {
  1724  			case http.StatusEarlyHints:
  1725  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1726  			case http.StatusProcessing:
  1727  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1728  			default:
  1729  				t.Error("Unexpected 1xx response")
  1730  			}
  1731  
  1732  			respCounter++
  1733  
  1734  			return nil
  1735  		},
  1736  	}
  1737  	req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
  1738  
  1739  	res, err := frontendClient.Do(req)
  1740  	if err != nil {
  1741  		t.Fatalf("Get: %v", err)
  1742  	}
  1743  
  1744  	defer res.Body.Close()
  1745  
  1746  	if respCounter != 2 {
  1747  		t.Errorf("Expected 2 1xx responses; got %d", respCounter)
  1748  	}
  1749  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1750  
  1751  	body, _ := io.ReadAll(res.Body)
  1752  	if string(body) != "Hello" {
  1753  		t.Errorf("Read body %q; want Hello", body)
  1754  	}
  1755  }
  1756  
// Named values for the wantCleanQuery argument of
// testReverseProxyQueryParameterSmuggling: testWantsCleanQuery selects the
// test case's cleanQuery as the query the backend is expected to receive,
// testWantsRawQuery selects the rawQuery exactly as the client sent it.
const (
	testWantsCleanQuery = true
	testWantsRawQuery   = false
)
  1761  
  1762  func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
  1763  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  1764  		proxyHandler := NewSingleHostReverseProxy(u)
  1765  		oldDirector := proxyHandler.Director
  1766  		proxyHandler.Director = func(r *http.Request) {
  1767  			oldDirector(r)
  1768  		}
  1769  		return proxyHandler
  1770  	})
  1771  }
  1772  
  1773  func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
  1774  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  1775  		proxyHandler := NewSingleHostReverseProxy(u)
  1776  		oldDirector := proxyHandler.Director
  1777  		proxyHandler.Director = func(r *http.Request) {
  1778  			// Parsing the form causes ReverseProxy to remove unparsable
  1779  			// query parameters before forwarding.
  1780  			r.FormValue("a")
  1781  			oldDirector(r)
  1782  		}
  1783  		return proxyHandler
  1784  	})
  1785  }
  1786  
  1787  func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
  1788  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  1789  		return &ReverseProxy{
  1790  			Rewrite: func(r *ProxyRequest) {
  1791  				r.SetURL(u)
  1792  			},
  1793  		}
  1794  	})
  1795  }
  1796  
  1797  func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
  1798  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  1799  		return &ReverseProxy{
  1800  			Rewrite: func(r *ProxyRequest) {
  1801  				r.SetURL(u)
  1802  				r.Out.URL.RawQuery = r.In.URL.RawQuery
  1803  			},
  1804  		}
  1805  	})
  1806  }
  1807  
  1808  func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
  1809  	const content = "response_content"
  1810  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1811  		w.Write([]byte(r.URL.RawQuery))
  1812  	}))
  1813  	defer backend.Close()
  1814  	backendURL, err := url.Parse(backend.URL)
  1815  	if err != nil {
  1816  		t.Fatal(err)
  1817  	}
  1818  	proxyHandler := newProxy(backendURL)
  1819  	frontend := httptest.NewServer(proxyHandler)
  1820  	defer frontend.Close()
  1821  
  1822  	// Don't spam output with logs of queries containing semicolons.
  1823  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
  1824  	frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
  1825  
  1826  	for _, test := range []struct {
  1827  		rawQuery   string
  1828  		cleanQuery string
  1829  	}{{
  1830  		rawQuery:   "a=1&a=2;b=3",
  1831  		cleanQuery: "a=1",
  1832  	}, {
  1833  		rawQuery:   "a=1&a=%zz&b=3",
  1834  		cleanQuery: "a=1&b=3",
  1835  	}} {
  1836  		res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
  1837  		if err != nil {
  1838  			t.Fatalf("Get: %v", err)
  1839  		}
  1840  		defer res.Body.Close()
  1841  		body, _ := io.ReadAll(res.Body)
  1842  		wantQuery := test.rawQuery
  1843  		if wantCleanQuery {
  1844  			wantQuery = test.cleanQuery
  1845  		}
  1846  		if got, want := string(body), wantQuery; got != want {
  1847  			t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
  1848  		}
  1849  	}
  1850  }