github.com/c0deoo1/golang1.5@v0.0.0-20220525150107-c87c805d4593/src/net/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"io/ioutil"
    11  	"log"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"reflect"
    16  	"runtime"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    23  
    24  func init() {
    25  	hopHeaders = append(hopHeaders, fakeHopHeader)
    26  }
    27  
    28  func TestReverseProxy(t *testing.T) {
    29  	const backendResponse = "I am the backend"
    30  	const backendStatus = 404
    31  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		if len(r.TransferEncoding) > 0 {
    33  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    34  		}
    35  		if r.Header.Get("X-Forwarded-For") == "" {
    36  			t.Errorf("didn't get X-Forwarded-For header")
    37  		}
    38  		if c := r.Header.Get("Connection"); c != "" {
    39  			t.Errorf("handler got Connection header value %q", c)
    40  		}
    41  		if c := r.Header.Get("Upgrade"); c != "" {
    42  			t.Errorf("handler got Upgrade header value %q", c)
    43  		}
    44  		if g, e := r.Host, "some-name"; g != e {
    45  			t.Errorf("backend got Host header %q, want %q", g, e)
    46  		}
    47  		w.Header().Set("Trailer", "X-Trailer")
    48  		w.Header().Set("X-Foo", "bar")
    49  		w.Header().Set("Upgrade", "foo")
    50  		w.Header().Set(fakeHopHeader, "foo")
    51  		w.Header().Add("X-Multi-Value", "foo")
    52  		w.Header().Add("X-Multi-Value", "bar")
    53  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    54  		w.WriteHeader(backendStatus)
    55  		w.Write([]byte(backendResponse))
    56  		w.Header().Set("X-Trailer", "trailer_value")
    57  	}))
    58  	defer backend.Close()
    59  	backendURL, err := url.Parse(backend.URL)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    64  	frontend := httptest.NewServer(proxyHandler)
    65  	defer frontend.Close()
    66  
    67  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    68  	getReq.Host = "some-name"
    69  	getReq.Header.Set("Connection", "close")
    70  	getReq.Header.Set("Upgrade", "foo")
    71  	getReq.Close = true
    72  	res, err := http.DefaultClient.Do(getReq)
    73  	if err != nil {
    74  		t.Fatalf("Get: %v", err)
    75  	}
    76  	if g, e := res.StatusCode, backendStatus; g != e {
    77  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
    78  	}
    79  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
    80  		t.Errorf("got X-Foo %q; expected %q", g, e)
    81  	}
    82  	if c := res.Header.Get(fakeHopHeader); c != "" {
    83  		t.Errorf("got %s header value %q", fakeHopHeader, c)
    84  	}
    85  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
    86  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
    87  	}
    88  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
    89  		t.Fatalf("got %d SetCookies, want %d", g, e)
    90  	}
    91  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
    92  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
    93  	}
    94  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
    95  		t.Errorf("unexpected cookie %q", cookie.Name)
    96  	}
    97  	bodyBytes, _ := ioutil.ReadAll(res.Body)
    98  	if g, e := string(bodyBytes), backendResponse; g != e {
    99  		t.Errorf("got body %q; expected %q", g, e)
   100  	}
   101  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   102  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   103  	}
   104  
   105  }
   106  
   107  func TestXForwardedFor(t *testing.T) {
   108  	const prevForwardedFor = "client ip"
   109  	const backendResponse = "I am the backend"
   110  	const backendStatus = 404
   111  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   112  		if r.Header.Get("X-Forwarded-For") == "" {
   113  			t.Errorf("didn't get X-Forwarded-For header")
   114  		}
   115  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   116  			t.Errorf("X-Forwarded-For didn't contain prior data")
   117  		}
   118  		w.WriteHeader(backendStatus)
   119  		w.Write([]byte(backendResponse))
   120  	}))
   121  	defer backend.Close()
   122  	backendURL, err := url.Parse(backend.URL)
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   127  	frontend := httptest.NewServer(proxyHandler)
   128  	defer frontend.Close()
   129  
   130  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   131  	getReq.Host = "some-name"
   132  	getReq.Header.Set("Connection", "close")
   133  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   134  	getReq.Close = true
   135  	res, err := http.DefaultClient.Do(getReq)
   136  	if err != nil {
   137  		t.Fatalf("Get: %v", err)
   138  	}
   139  	if g, e := res.StatusCode, backendStatus; g != e {
   140  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   141  	}
   142  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   143  	if g, e := string(bodyBytes), backendResponse; g != e {
   144  		t.Errorf("got body %q; expected %q", g, e)
   145  	}
   146  }
   147  
   148  var proxyQueryTests = []struct {
   149  	baseSuffix string // suffix to add to backend URL
   150  	reqSuffix  string // suffix to add to frontend's request URL
   151  	want       string // what backend should see for final request URL (without ?)
   152  }{
   153  	{"", "", ""},
   154  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   155  	{"", "?us=er", "us=er"},
   156  	{"?sta=tic", "", "sta=tic"},
   157  }
   158  
   159  func TestReverseProxyQuery(t *testing.T) {
   160  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   161  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   162  		w.Write([]byte("hi"))
   163  	}))
   164  	defer backend.Close()
   165  
   166  	for i, tt := range proxyQueryTests {
   167  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   168  		if err != nil {
   169  			t.Fatal(err)
   170  		}
   171  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   172  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   173  		req.Close = true
   174  		res, err := http.DefaultClient.Do(req)
   175  		if err != nil {
   176  			t.Fatalf("%d. Get: %v", i, err)
   177  		}
   178  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   179  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   180  		}
   181  		res.Body.Close()
   182  		frontend.Close()
   183  	}
   184  }
   185  
   186  func TestReverseProxyFlushInterval(t *testing.T) {
   187  	const expected = "hi"
   188  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   189  		w.Write([]byte(expected))
   190  	}))
   191  	defer backend.Close()
   192  
   193  	backendURL, err := url.Parse(backend.URL)
   194  	if err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   199  	proxyHandler.FlushInterval = time.Microsecond
   200  
   201  	done := make(chan bool)
   202  	onExitFlushLoop = func() { done <- true }
   203  	defer func() { onExitFlushLoop = nil }()
   204  
   205  	frontend := httptest.NewServer(proxyHandler)
   206  	defer frontend.Close()
   207  
   208  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   209  	req.Close = true
   210  	res, err := http.DefaultClient.Do(req)
   211  	if err != nil {
   212  		t.Fatalf("Get: %v", err)
   213  	}
   214  	defer res.Body.Close()
   215  	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
   216  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   217  	}
   218  
   219  	select {
   220  	case <-done:
   221  		// OK
   222  	case <-time.After(5 * time.Second):
   223  		t.Error("maxLatencyWriter flushLoop() never exited")
   224  	}
   225  }
   226  
   227  func TestReverseProxyCancellation(t *testing.T) {
   228  	if runtime.GOOS == "plan9" {
   229  		t.Skip("skipping test; see https://golang.org/issue/9554")
   230  	}
   231  	const backendResponse = "I am the backend"
   232  
   233  	reqInFlight := make(chan struct{})
   234  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   235  		close(reqInFlight)
   236  
   237  		select {
   238  		case <-time.After(10 * time.Second):
   239  			// Note: this should only happen in broken implementations, and the
   240  			// closenotify case should be instantaneous.
   241  			t.Log("Failed to close backend connection")
   242  			t.Fail()
   243  		case <-w.(http.CloseNotifier).CloseNotify():
   244  		}
   245  
   246  		w.WriteHeader(http.StatusOK)
   247  		w.Write([]byte(backendResponse))
   248  	}))
   249  
   250  	defer backend.Close()
   251  
   252  	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
   253  
   254  	backendURL, err := url.Parse(backend.URL)
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  
   259  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   260  
   261  	// Discards errors of the form:
   262  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   263  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
   264  
   265  	frontend := httptest.NewServer(proxyHandler)
   266  	defer frontend.Close()
   267  
   268  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   269  	go func() {
   270  		<-reqInFlight
   271  		http.DefaultTransport.(*http.Transport).CancelRequest(getReq)
   272  	}()
   273  	res, err := http.DefaultClient.Do(getReq)
   274  	if res != nil {
   275  		t.Fatal("Non-nil response")
   276  	}
   277  	if err == nil {
   278  		// This should be an error like:
   279  		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
   280  		//    use of closed network connection
   281  		t.Fatal("DefaultClient.Do() returned nil error")
   282  	}
   283  }