github.com/xmidt-org/webpa-common@v1.11.9/xhttp/redirectPolicy_test.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  	"github.com/xmidt-org/webpa-common/logging"
    11  )
    12  
    13  func testRedirectPolicyDefault(t *testing.T) {
    14  	var (
    15  		assert  = assert.New(t)
    16  		require = require.New(t)
    17  		p       = RedirectPolicy{}
    18  	)
    19  
    20  	assert.Equal(DefaultMaxRedirects, p.maxRedirects())
    21  
    22  	f := p.headerFilter()
    23  	require.NotNil(f)
    24  	assert.True(f("something"))
    25  }
    26  
    27  func testRedirectPolicyCustom(t *testing.T) {
    28  	var (
    29  		assert         = assert.New(t)
    30  		require        = require.New(t)
    31  		expectedLogger = logging.NewTestLogger(nil, t)
    32  
    33  		p = RedirectPolicy{
    34  			Logger:         expectedLogger,
    35  			MaxRedirects:   7,
    36  			ExcludeHeaders: []string{"content-type"},
    37  		}
    38  	)
    39  
    40  	assert.Equal(7, p.maxRedirects())
    41  
    42  	f := p.headerFilter()
    43  	require.NotNil(f)
    44  	assert.True(f("Accept"))
    45  	assert.False(f("Content-Type"))
    46  }
    47  
    48  func TestRedirectPolicy(t *testing.T) {
    49  	t.Run("Default", testRedirectPolicyDefault)
    50  	t.Run("Custom", testRedirectPolicyCustom)
    51  }
    52  
    53  func testCheckRedirectMaxRedirects(t *testing.T) {
    54  	var (
    55  		assert = assert.New(t)
    56  
    57  		via = []*http.Request{
    58  			httptest.NewRequest("GET", "/first", nil),
    59  			httptest.NewRequest("GET", "/second", nil),
    60  		}
    61  
    62  		checkRedirect = CheckRedirect(
    63  			RedirectPolicy{
    64  				MaxRedirects: 2,
    65  			},
    66  		)
    67  	)
    68  
    69  	assert.Error(checkRedirect(httptest.NewRequest("GET", "/", nil), via))
    70  }
    71  
    72  func testCheckRedirectCopyHeaders(t *testing.T) {
    73  	var (
    74  		assert = assert.New(t)
    75  
    76  		checkRedirect = CheckRedirect(RedirectPolicy{
    77  			Logger:         logging.NewTestLogger(nil, t),
    78  			ExcludeHeaders: []string{"content-type", "X-Supar-Sekrit"},
    79  		})
    80  
    81  		r   = httptest.NewRequest("GET", "/", nil)
    82  		via = []*http.Request{
    83  			httptest.NewRequest("GET", "/", nil),
    84  		}
    85  	)
    86  
    87  	via[len(via)-1].Header.Set("Content-Type", "text/plain")
    88  	via[len(via)-1].Header.Add("x-supar-sekrit", "don't reveal me, bro!")
    89  
    90  	via[len(via)-1].Header.Set("X-Something", "value")
    91  	via[len(via)-1].Header.Add("X-Something-Else", "value1")
    92  	via[len(via)-1].Header.Add("X-Something-Else", "value2")
    93  
    94  	checkRedirect(r, via)
    95  	assert.Equal("value", r.Header.Get("X-Something"))
    96  	assert.Equal([]string{"value1", "value2"}, r.Header["X-Something-Else"])
    97  	assert.Equal("", r.Header.Get("Content-Type"))
    98  	assert.Equal("", r.Header.Get("X-Supar-Sekrit"))
    99  }
   100  
   101  func TestCheckRedirect(t *testing.T) {
   102  	t.Run("MaxRedirects", testCheckRedirectMaxRedirects)
   103  	t.Run("CopyHeaders", testCheckRedirectCopyHeaders)
   104  }