github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/proxy_headers_test.go (about)

     1  package handlers
     2  
     3  import (
     4  	"testing"
     5  
     6  	http "github.com/hxx258456/ccgo/gmhttp"
     7  	"github.com/hxx258456/ccgo/gmhttp/httptest"
     8  )
     9  
    10  type headerTable struct {
    11  	key      string // header key
    12  	val      string // header val
    13  	expected string // expected result
    14  }
    15  
    16  func TestGetIP(t *testing.T) {
    17  	headers := []headerTable{
    18  		{xForwardedFor, "8.8.8.8", "8.8.8.8"},                                   // Single address
    19  		{xForwardedFor, "8.8.8.8, 8.8.4.4", "8.8.8.8"},                          // Multiple
    20  		{xForwardedFor, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"}, // IPv6 address
    21  		{xForwardedFor, "", ""},                                                 // None
    22  		{xRealIP, "8.8.8.8", "8.8.8.8"},                                         // Single address
    23  		{xRealIP, "8.8.8.8, 8.8.4.4", "8.8.8.8, 8.8.4.4"},                       // Multiple
    24  		{xRealIP, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"},       // IPv6 address
    25  		{xRealIP, "", ""},                       // None
    26  		{forwarded, `for="_gazonk"`, "_gazonk"}, // Hostname
    27  		{forwarded, `For="[2001:db8:cafe::17]:4711`, `[2001:db8:cafe::17]:4711`},      // IPv6 address
    28  		{forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, `192.0.2.60`},        // Multiple params
    29  		{forwarded, `for=192.0.2.43, for=198.51.100.17`, "192.0.2.43"},                // Multiple params
    30  		{forwarded, `for="workstation.local",for=198.51.100.17`, "workstation.local"}, // Hostname
    31  	}
    32  
    33  	for _, v := range headers {
    34  		req := &http.Request{
    35  			Header: http.Header{
    36  				v.key: []string{v.val},
    37  			}}
    38  		res := getIP(req)
    39  		if res != v.expected {
    40  			t.Fatalf("wrong header for %s: got %s want %s", v.key, res,
    41  				v.expected)
    42  		}
    43  	}
    44  }
    45  
    46  func TestGetScheme(t *testing.T) {
    47  	headers := []headerTable{
    48  		{xForwardedProto, "https", "https"},
    49  		{xForwardedProto, "http", "http"},
    50  		{xForwardedProto, "HTTP", "http"},
    51  		{xForwardedScheme, "https", "https"},
    52  		{xForwardedScheme, "http", "http"},
    53  		{xForwardedScheme, "HTTP", "http"},
    54  		{forwarded, `For="[2001:db8:cafe::17]:4711`, ""},                      // No proto
    55  		{forwarded, `for=192.0.2.43, for=198.51.100.17;proto=https`, "https"}, // Multiple params before proto
    56  		{forwarded, `for=172.32.10.15; proto=https;by=127.0.0.1`, "https"},    // Space before proto
    57  		{forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, "http"},      // Multiple params
    58  	}
    59  
    60  	for _, v := range headers {
    61  		req := &http.Request{
    62  			Header: http.Header{
    63  				v.key: []string{v.val},
    64  			},
    65  		}
    66  		res := getScheme(req)
    67  		if res != v.expected {
    68  			t.Fatalf("wrong header for %s: got %s want %s", v.key, res,
    69  				v.expected)
    70  		}
    71  	}
    72  }
    73  
    74  // Test the middleware end-to-end
    75  func TestProxyHeaders(t *testing.T) {
    76  	rr := httptest.NewRecorder()
    77  	r := newRequest("GET", "/")
    78  
    79  	r.Header.Set(xForwardedFor, "8.8.8.8")
    80  	r.Header.Set(xForwardedProto, "https")
    81  	r.Header.Set(xForwardedHost, "google.com")
    82  	var (
    83  		addr  string
    84  		proto string
    85  		host  string
    86  	)
    87  	ProxyHeaders(http.HandlerFunc(
    88  		func(w http.ResponseWriter, r *http.Request) {
    89  			addr = r.RemoteAddr
    90  			proto = r.URL.Scheme
    91  			host = r.Host
    92  		})).ServeHTTP(rr, r)
    93  
    94  	if rr.Code != http.StatusOK {
    95  		t.Fatalf("bad status: got %d want %d", rr.Code, http.StatusOK)
    96  	}
    97  
    98  	if addr != r.Header.Get(xForwardedFor) {
    99  		t.Fatalf("wrong address: got %s want %s", addr,
   100  			r.Header.Get(xForwardedFor))
   101  	}
   102  
   103  	if proto != r.Header.Get(xForwardedProto) {
   104  		t.Fatalf("wrong address: got %s want %s", proto,
   105  			r.Header.Get(xForwardedProto))
   106  	}
   107  	if host != r.Header.Get(xForwardedHost) {
   108  		t.Fatalf("wrong address: got %s want %s", host,
   109  			r.Header.Get(xForwardedHost))
   110  	}
   111  
   112  }