github.com/sean-/go@v0.0.0-20151219100004-97f854cd7bb6/src/net/http/httputil/reverseproxy_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"io"
    12  	"io/ioutil"
    13  	"log"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"net/url"
    17  	"reflect"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  )
    24  
// fakeHopHeader is a made-up header name that these tests register as a
// hop-by-hop header, so they can check that the proxy does not relay it.
const fakeHopHeader = "X-Fake-Hop-Header-For-Test"

// init appends fakeHopHeader to the package-level hopHeaders list before any
// test in this file runs.
func init() {
	hopHeaders = append(hopHeaders, fakeHopHeader)
}
    30  
    31  func TestReverseProxy(t *testing.T) {
    32  	const backendResponse = "I am the backend"
    33  	const backendStatus = 404
    34  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    35  		if len(r.TransferEncoding) > 0 {
    36  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    37  		}
    38  		if r.Header.Get("X-Forwarded-For") == "" {
    39  			t.Errorf("didn't get X-Forwarded-For header")
    40  		}
    41  		if c := r.Header.Get("Connection"); c != "" {
    42  			t.Errorf("handler got Connection header value %q", c)
    43  		}
    44  		if c := r.Header.Get("Upgrade"); c != "" {
    45  			t.Errorf("handler got Upgrade header value %q", c)
    46  		}
    47  		if g, e := r.Host, "some-name"; g != e {
    48  			t.Errorf("backend got Host header %q, want %q", g, e)
    49  		}
    50  		w.Header().Set("Trailer", "X-Trailer")
    51  		w.Header().Set("X-Foo", "bar")
    52  		w.Header().Set("Upgrade", "foo")
    53  		w.Header().Set(fakeHopHeader, "foo")
    54  		w.Header().Add("X-Multi-Value", "foo")
    55  		w.Header().Add("X-Multi-Value", "bar")
    56  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    57  		w.WriteHeader(backendStatus)
    58  		w.Write([]byte(backendResponse))
    59  		w.Header().Set("X-Trailer", "trailer_value")
    60  	}))
    61  	defer backend.Close()
    62  	backendURL, err := url.Parse(backend.URL)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    67  	frontend := httptest.NewServer(proxyHandler)
    68  	defer frontend.Close()
    69  
    70  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    71  	getReq.Host = "some-name"
    72  	getReq.Header.Set("Connection", "close")
    73  	getReq.Header.Set("Upgrade", "foo")
    74  	getReq.Close = true
    75  	res, err := http.DefaultClient.Do(getReq)
    76  	if err != nil {
    77  		t.Fatalf("Get: %v", err)
    78  	}
    79  	if g, e := res.StatusCode, backendStatus; g != e {
    80  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
    81  	}
    82  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
    83  		t.Errorf("got X-Foo %q; expected %q", g, e)
    84  	}
    85  	if c := res.Header.Get(fakeHopHeader); c != "" {
    86  		t.Errorf("got %s header value %q", fakeHopHeader, c)
    87  	}
    88  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
    89  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
    90  	}
    91  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
    92  		t.Fatalf("got %d SetCookies, want %d", g, e)
    93  	}
    94  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
    95  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
    96  	}
    97  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
    98  		t.Errorf("unexpected cookie %q", cookie.Name)
    99  	}
   100  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   101  	if g, e := string(bodyBytes), backendResponse; g != e {
   102  		t.Errorf("got body %q; expected %q", g, e)
   103  	}
   104  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   105  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   106  	}
   107  
   108  }
   109  
   110  func TestXForwardedFor(t *testing.T) {
   111  	const prevForwardedFor = "client ip"
   112  	const backendResponse = "I am the backend"
   113  	const backendStatus = 404
   114  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   115  		if r.Header.Get("X-Forwarded-For") == "" {
   116  			t.Errorf("didn't get X-Forwarded-For header")
   117  		}
   118  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   119  			t.Errorf("X-Forwarded-For didn't contain prior data")
   120  		}
   121  		w.WriteHeader(backendStatus)
   122  		w.Write([]byte(backendResponse))
   123  	}))
   124  	defer backend.Close()
   125  	backendURL, err := url.Parse(backend.URL)
   126  	if err != nil {
   127  		t.Fatal(err)
   128  	}
   129  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   130  	frontend := httptest.NewServer(proxyHandler)
   131  	defer frontend.Close()
   132  
   133  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   134  	getReq.Host = "some-name"
   135  	getReq.Header.Set("Connection", "close")
   136  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   137  	getReq.Close = true
   138  	res, err := http.DefaultClient.Do(getReq)
   139  	if err != nil {
   140  		t.Fatalf("Get: %v", err)
   141  	}
   142  	if g, e := res.StatusCode, backendStatus; g != e {
   143  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   144  	}
   145  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   146  	if g, e := string(bodyBytes), backendResponse; g != e {
   147  		t.Errorf("got body %q; expected %q", g, e)
   148  	}
   149  }
   150  
   151  var proxyQueryTests = []struct {
   152  	baseSuffix string // suffix to add to backend URL
   153  	reqSuffix  string // suffix to add to frontend's request URL
   154  	want       string // what backend should see for final request URL (without ?)
   155  }{
   156  	{"", "", ""},
   157  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   158  	{"", "?us=er", "us=er"},
   159  	{"?sta=tic", "", "sta=tic"},
   160  }
   161  
   162  func TestReverseProxyQuery(t *testing.T) {
   163  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   164  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   165  		w.Write([]byte("hi"))
   166  	}))
   167  	defer backend.Close()
   168  
   169  	for i, tt := range proxyQueryTests {
   170  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   171  		if err != nil {
   172  			t.Fatal(err)
   173  		}
   174  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   175  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   176  		req.Close = true
   177  		res, err := http.DefaultClient.Do(req)
   178  		if err != nil {
   179  			t.Fatalf("%d. Get: %v", i, err)
   180  		}
   181  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   182  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   183  		}
   184  		res.Body.Close()
   185  		frontend.Close()
   186  	}
   187  }
   188  
   189  func TestReverseProxyFlushInterval(t *testing.T) {
   190  	const expected = "hi"
   191  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   192  		w.Write([]byte(expected))
   193  	}))
   194  	defer backend.Close()
   195  
   196  	backendURL, err := url.Parse(backend.URL)
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  
   201  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   202  	proxyHandler.FlushInterval = time.Microsecond
   203  
   204  	done := make(chan bool)
   205  	onExitFlushLoop = func() { done <- true }
   206  	defer func() { onExitFlushLoop = nil }()
   207  
   208  	frontend := httptest.NewServer(proxyHandler)
   209  	defer frontend.Close()
   210  
   211  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   212  	req.Close = true
   213  	res, err := http.DefaultClient.Do(req)
   214  	if err != nil {
   215  		t.Fatalf("Get: %v", err)
   216  	}
   217  	defer res.Body.Close()
   218  	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
   219  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   220  	}
   221  
   222  	select {
   223  	case <-done:
   224  		// OK
   225  	case <-time.After(5 * time.Second):
   226  		t.Error("maxLatencyWriter flushLoop() never exited")
   227  	}
   228  }
   229  
   230  func TestReverseProxyCancelation(t *testing.T) {
   231  	const backendResponse = "I am the backend"
   232  
   233  	reqInFlight := make(chan struct{})
   234  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   235  		close(reqInFlight)
   236  
   237  		select {
   238  		case <-time.After(10 * time.Second):
   239  			// Note: this should only happen in broken implementations, and the
   240  			// closenotify case should be instantaneous.
   241  			t.Log("Failed to close backend connection")
   242  			t.Fail()
   243  		case <-w.(http.CloseNotifier).CloseNotify():
   244  		}
   245  
   246  		w.WriteHeader(http.StatusOK)
   247  		w.Write([]byte(backendResponse))
   248  	}))
   249  
   250  	defer backend.Close()
   251  
   252  	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
   253  
   254  	backendURL, err := url.Parse(backend.URL)
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  
   259  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   260  
   261  	// Discards errors of the form:
   262  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   263  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
   264  
   265  	frontend := httptest.NewServer(proxyHandler)
   266  	defer frontend.Close()
   267  
   268  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   269  	go func() {
   270  		<-reqInFlight
   271  		http.DefaultTransport.(*http.Transport).CancelRequest(getReq)
   272  	}()
   273  	res, err := http.DefaultClient.Do(getReq)
   274  	if res != nil {
   275  		t.Fatal("Non-nil response")
   276  	}
   277  	if err == nil {
   278  		// This should be an error like:
   279  		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
   280  		//    use of closed network connection
   281  		t.Fatal("DefaultClient.Do() returned nil error")
   282  	}
   283  }
   284  
   285  func req(t *testing.T, v string) *http.Request {
   286  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  	return req
   291  }
   292  
   293  // Issue 12344
   294  func TestNilBody(t *testing.T) {
   295  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   296  		w.Write([]byte("hi"))
   297  	}))
   298  	defer backend.Close()
   299  
   300  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   301  		backURL, _ := url.Parse(backend.URL)
   302  		rp := NewSingleHostReverseProxy(backURL)
   303  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   304  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   305  		rp.ServeHTTP(w, r)
   306  	}))
   307  	defer frontend.Close()
   308  
   309  	res, err := http.Get(frontend.URL)
   310  	if err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	defer res.Body.Close()
   314  	slurp, err := ioutil.ReadAll(res.Body)
   315  	if err != nil {
   316  		t.Fatal(err)
   317  	}
   318  	if string(slurp) != "hi" {
   319  		t.Errorf("Got %q; want %q", slurp, "hi")
   320  	}
   321  }
   322  
   323  type bufferPool struct {
   324  	get func() []byte
   325  	put func([]byte)
   326  }
   327  
   328  func (bp bufferPool) Get() []byte  { return bp.get() }
   329  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   330  
   331  func TestReverseProxyGetPutBuffer(t *testing.T) {
   332  	const msg = "hi"
   333  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   334  		io.WriteString(w, msg)
   335  	}))
   336  	defer backend.Close()
   337  
   338  	backendURL, err := url.Parse(backend.URL)
   339  	if err != nil {
   340  		t.Fatal(err)
   341  	}
   342  
   343  	var (
   344  		mu  sync.Mutex
   345  		log []string
   346  	)
   347  	addLog := func(event string) {
   348  		mu.Lock()
   349  		defer mu.Unlock()
   350  		log = append(log, event)
   351  	}
   352  	rp := NewSingleHostReverseProxy(backendURL)
   353  	const size = 1234
   354  	rp.BufferPool = bufferPool{
   355  		get: func() []byte {
   356  			addLog("getBuf")
   357  			return make([]byte, size)
   358  		},
   359  		put: func(p []byte) {
   360  			addLog("putBuf-" + strconv.Itoa(len(p)))
   361  		},
   362  	}
   363  	frontend := httptest.NewServer(rp)
   364  	defer frontend.Close()
   365  
   366  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   367  	req.Close = true
   368  	res, err := http.DefaultClient.Do(req)
   369  	if err != nil {
   370  		t.Fatalf("Get: %v", err)
   371  	}
   372  	slurp, err := ioutil.ReadAll(res.Body)
   373  	res.Body.Close()
   374  	if err != nil {
   375  		t.Fatalf("reading body: %v", err)
   376  	}
   377  	if string(slurp) != msg {
   378  		t.Errorf("msg = %q; want %q", slurp, msg)
   379  	}
   380  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   381  	mu.Lock()
   382  	defer mu.Unlock()
   383  	if !reflect.DeepEqual(log, wantLog) {
   384  		t.Errorf("Log events = %q; want %q", log, wantLog)
   385  	}
   386  }