github.com/xmidt-org/webpa-common@v1.11.9/secure/handler/authorizationHandler_test.go

     1  package handler
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"strconv"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/xmidt-org/webpa-common/logging"
    14  	"github.com/xmidt-org/webpa-common/secure"
    15  )
    16  
    17  const (
    18  	authorizationValue = "Basic dGVzdDp0ZXN0Cg=="
    19  	tokenValue         = "dGVzdDp0ZXN0Cg=="
    20  )
    21  
    22  func testAuthorizationHandlerNoDecoration(t *testing.T) {
    23  	var (
    24  		assert  = assert.New(t)
    25  		require = require.New(t)
    26  
    27  		nextCalled = false
    28  		next       = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
    29  			nextCalled = true
    30  		})
    31  
    32  		handler = AuthorizationHandler{
    33  			Logger: logging.NewTestLogger(nil, t),
    34  		}
    35  
    36  		decorated = handler.Decorate(next)
    37  	)
    38  
    39  	require.NotNil(decorated)
    40  	decorated.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
    41  	assert.True(nextCalled)
    42  }
    43  
    44  func testAuthorizationHandlerNoAuthorization(t *testing.T, expectedStatusCode, configuredStatusCode int) {
    45  	var (
    46  		assert  = assert.New(t)
    47  		require = require.New(t)
    48  
    49  		nextCalled = false
    50  		next       = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
    51  			nextCalled = true
    52  		})
    53  
    54  		validator = new(secure.MockValidator)
    55  		handler   = AuthorizationHandler{
    56  			Logger:              logging.NewTestLogger(nil, t),
    57  			ForbiddenStatusCode: configuredStatusCode,
    58  			Validator:           validator,
    59  		}
    60  
    61  		response  = httptest.NewRecorder()
    62  		request   = httptest.NewRequest("GET", "/", nil)
    63  		decorated = handler.Decorate(next)
    64  	)
    65  
    66  	require.NotNil(decorated)
    67  	decorated.ServeHTTP(response, request)
    68  	assert.Equal(expectedStatusCode, response.Code)
    69  	assert.False(nextCalled)
    70  	validator.AssertExpectations(t)
    71  }
    72  
    73  func testAuthorizationHandlerMalformedAuthorization(t *testing.T, expectedStatusCode, configuredStatusCode int, expectedHeader, configuredHeader string) {
    74  	var (
    75  		assert  = assert.New(t)
    76  		require = require.New(t)
    77  
    78  		nextCalled = false
    79  		next       = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
    80  			nextCalled = true
    81  		})
    82  
    83  		validator = new(secure.MockValidator)
    84  		handler   = AuthorizationHandler{
    85  			Logger:              logging.NewTestLogger(nil, t),
    86  			HeaderName:          configuredHeader,
    87  			ForbiddenStatusCode: configuredStatusCode,
    88  			Validator:           validator,
    89  		}
    90  
    91  		response  = httptest.NewRecorder()
    92  		request   = httptest.NewRequest("GET", "/", nil)
    93  		decorated = handler.Decorate(next)
    94  	)
    95  
    96  	require.NotNil(decorated)
    97  	request.Header.Set(expectedHeader, "there is no way this is a valid authorization header")
    98  	decorated.ServeHTTP(response, request)
    99  	assert.Equal(expectedStatusCode, response.Code)
   100  	assert.False(nextCalled)
   101  	validator.AssertExpectations(t)
   102  }
   103  
   104  func testAuthorizationHandlerValid(t *testing.T, expectedHeader, configuredHeader string) {
   105  	var (
   106  		assert  = assert.New(t)
   107  		require = require.New(t)
   108  
   109  		nextCalled = false
   110  		next       = http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
   111  			nextCalled = true
   112  			values, ok := FromContext(request.Context())
   113  			require.True(ok)
   114  			require.NotNil(values)
   115  
   116  			assert.Equal("x1:webpa-internal:5f0183", values.SatClientID)
   117  			assert.Equal([]string{"comcast"}, values.PartnerIDs)
   118  		})
   119  
   120  		validator = new(secure.MockValidator)
   121  		handler   = AuthorizationHandler{
   122  			Logger:     logging.NewTestLogger(nil, t),
   123  			HeaderName: configuredHeader,
   124  			Validator:  validator,
   125  		}
   126  
   127  		response  = httptest.NewRecorder()
   128  		request   = httptest.NewRequest("GET", "/", nil)
   129  		decorated = handler.Decorate(next)
   130  	)
   131  
   132  	require.NotNil(decorated)
   133  	request.Header.Set(expectedHeader, "Bearer eyJhbGciOiJub25lIiwia2lkIjoidGhlbWlzLTIwMTcwMSIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI4ZjA0MmIyOS03ZDE2LTRjMWYtYjBmOS1mNTJhMGFhZDI5YmMiLCJpc3MiOiJzYXRzLXByb2R1Y3Rpb24iLCJzdWIiOiJ4MTp3ZWJwYS1pbnRlcm5hbDo1ZjAxODMiLCJpYXQiOjE1Mjc3MzAwOTYsIm5iZiI6MTUyNzczMDA5NiwiZXhwIjoxNTI3NzczMjk5LCJ2ZXJzaW9uIjoiMS4wIiwiYWxsb3dlZFJlc291cmNlcyI6eyJhbGxvd2VkUGFydG5lcnMiOlsiY29tY2FzdCJdfSwiY2FwYWJpbGl0aWVzIjpbIngxOndlYnBhOmFwaTouKjphbGwiXSwiYXVkIjpbXX0.")
   134  
   135  	validator.On("Validate", mock.MatchedBy(func(context.Context) bool { return true }), mock.MatchedBy(func(*secure.Token) bool { return true })).Return(true, error(nil)).Once()
   136  	decorated.ServeHTTP(response, request)
   137  	assert.Equal(200, response.Code)
   138  	assert.True(nextCalled)
   139  	validator.AssertExpectations(t)
   140  }
   141  
   142  func testAuthorizationHandlerInvalid(t *testing.T, expectedStatusCode, configuredStatusCode int, expectedHeader, configuredHeader string) {
   143  	var (
   144  		assert  = assert.New(t)
   145  		require = require.New(t)
   146  
   147  		nextCalled = false
   148  		next       = http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
   149  			nextCalled = true
   150  		})
   151  
   152  		validator = new(secure.MockValidator)
   153  		handler   = AuthorizationHandler{
   154  			Logger:              logging.NewTestLogger(nil, t),
   155  			HeaderName:          configuredHeader,
   156  			ForbiddenStatusCode: configuredStatusCode,
   157  			Validator:           validator,
   158  		}
   159  
   160  		response  = httptest.NewRecorder()
   161  		request   = httptest.NewRequest("GET", "/", nil)
   162  		decorated = handler.Decorate(next)
   163  	)
   164  
   165  	require.NotNil(decorated)
   166  	request.Header.Set(expectedHeader, "Basic YWxsYWRpbjpvcGVuc2VzYW1l")
   167  
   168  	validator.On("Validate", mock.MatchedBy(func(context.Context) bool { return true }), mock.MatchedBy(func(*secure.Token) bool { return true })).Return(false, error(nil)).Once()
   169  	decorated.ServeHTTP(response, request)
   170  	assert.Equal(expectedStatusCode, response.Code)
   171  	assert.False(nextCalled)
   172  	validator.AssertExpectations(t)
   173  }
   174  
   175  func TestAuthorizationHandler(t *testing.T) {
   176  	t.Run("NoDecoration", testAuthorizationHandlerNoDecoration)
   177  
   178  	t.Run("NoAuthorization", func(t *testing.T) {
   179  		testData := []struct {
   180  			expectedStatusCode   int
   181  			configuredStatusCode int
   182  		}{
   183  			{http.StatusForbidden, 0},
   184  			{http.StatusForbidden, http.StatusForbidden},
   185  			{599, 599},
   186  		}
   187  
   188  		for i, record := range testData {
   189  			t.Run(strconv.Itoa(i), func(t *testing.T) {
   190  				testAuthorizationHandlerNoAuthorization(t, record.expectedStatusCode, record.configuredStatusCode)
   191  			})
   192  		}
   193  	})
   194  
   195  	t.Run("MalformedAuthorization", func(t *testing.T) {
   196  		testData := []struct {
   197  			expectedStatusCode   int
   198  			configuredStatusCode int
   199  			expectedHeader       string
   200  			configuredHeader     string
   201  		}{
   202  			{http.StatusForbidden, 0, "Authorization", ""},
   203  			{http.StatusForbidden, http.StatusForbidden, "Authorization", "Authorization"},
   204  			{599, 599, "X-Custom", "X-Custom"},
   205  		}
   206  
   207  		for i, record := range testData {
   208  			t.Run(strconv.Itoa(i), func(t *testing.T) {
   209  				testAuthorizationHandlerMalformedAuthorization(t, record.expectedStatusCode, record.configuredStatusCode, record.expectedHeader, record.configuredHeader)
   210  			})
   211  		}
   212  	})
   213  
   214  	t.Run("Valid", func(t *testing.T) {
   215  		testData := []struct {
   216  			expectedHeader   string
   217  			configuredHeader string
   218  		}{
   219  			{"Authorization", ""},
   220  			{"Authorization", "Authorization"},
   221  			{"X-Custom", "X-Custom"},
   222  		}
   223  
   224  		for i, record := range testData {
   225  			t.Run(strconv.Itoa(i), func(t *testing.T) {
   226  				testAuthorizationHandlerValid(t, record.expectedHeader, record.configuredHeader)
   227  			})
   228  		}
   229  	})
   230  
   231  	t.Run("Invalid", func(t *testing.T) {
   232  		testData := []struct {
   233  			expectedStatusCode   int
   234  			configuredStatusCode int
   235  			expectedHeader       string
   236  			configuredHeader     string
   237  		}{
   238  			{http.StatusForbidden, 0, "Authorization", ""},
   239  			{http.StatusForbidden, http.StatusForbidden, "Authorization", "Authorization"},
   240  			{599, 599, "X-Custom", "X-Custom"},
   241  		}
   242  
   243  		for i, record := range testData {
   244  			t.Run(strconv.Itoa(i), func(t *testing.T) {
   245  				testAuthorizationHandlerInvalid(t, record.expectedStatusCode, record.configuredStatusCode, record.expectedHeader, record.configuredHeader)
   246  			})
   247  		}
   248  	})
   249  }
   250  
   251  func testPopulateContextValuesNoJWT(t *testing.T) {
   252  	var (
   253  		assert  = assert.New(t)
   254  		require = require.New(t)
   255  
   256  		token, err = secure.ParseAuthorization("Basic abcd==")
   257  	)
   258  
   259  	require.NoError(err)
   260  	require.NotNil(token)
   261  
   262  	values := new(ContextValues)
   263  	assert.NoError(populateContextValues(token, values))
   264  }
   265  
   266  func TestPopulateContextValues(t *testing.T) {
   267  	t.Run("NoJWT", testPopulateContextValuesNoJWT)
   268  }
   269  
   270  //A simple verification that a pointer function signature is used
   271  func TestDefineMeasures(t *testing.T) {
   272  	assert := assert.New(t)
   273  	a, m := AuthorizationHandler{}, &secure.JWTValidationMeasures{}
   274  	a.DefineMeasures(m)
   275  	assert.Equal(m, a.measures)
   276  }