github.com/xmidt-org/webpa-common@v1.11.9/secure/handler/authorizationHandler_test.go (about) 1 package handler 2 3 import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "strconv" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/mock" 12 "github.com/stretchr/testify/require" 13 "github.com/xmidt-org/webpa-common/logging" 14 "github.com/xmidt-org/webpa-common/secure" 15 ) 16 17 const ( 18 authorizationValue = "Basic dGVzdDp0ZXN0Cg==" 19 tokenValue = "dGVzdDp0ZXN0Cg==" 20 ) 21 22 func testAuthorizationHandlerNoDecoration(t *testing.T) { 23 var ( 24 assert = assert.New(t) 25 require = require.New(t) 26 27 nextCalled = false 28 next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) { 29 nextCalled = true 30 }) 31 32 handler = AuthorizationHandler{ 33 Logger: logging.NewTestLogger(nil, t), 34 } 35 36 decorated = handler.Decorate(next) 37 ) 38 39 require.NotNil(decorated) 40 decorated.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil)) 41 assert.True(nextCalled) 42 } 43 44 func testAuthorizationHandlerNoAuthorization(t *testing.T, expectedStatusCode, configuredStatusCode int) { 45 var ( 46 assert = assert.New(t) 47 require = require.New(t) 48 49 nextCalled = false 50 next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) { 51 nextCalled = true 52 }) 53 54 validator = new(secure.MockValidator) 55 handler = AuthorizationHandler{ 56 Logger: logging.NewTestLogger(nil, t), 57 ForbiddenStatusCode: configuredStatusCode, 58 Validator: validator, 59 } 60 61 response = httptest.NewRecorder() 62 request = httptest.NewRequest("GET", "/", nil) 63 decorated = handler.Decorate(next) 64 ) 65 66 require.NotNil(decorated) 67 decorated.ServeHTTP(response, request) 68 assert.Equal(expectedStatusCode, response.Code) 69 assert.False(nextCalled) 70 validator.AssertExpectations(t) 71 } 72 73 func testAuthorizationHandlerMalformedAuthorization(t *testing.T, expectedStatusCode, configuredStatusCode int, expectedHeader, configuredHeader string) { 74 var ( 75 assert = assert.New(t) 76 require = require.New(t) 77 78 nextCalled = false 79 next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) { 80 nextCalled = true 81 }) 82 83 validator = new(secure.MockValidator) 84 handler = AuthorizationHandler{ 85 Logger: logging.NewTestLogger(nil, t), 86 HeaderName: configuredHeader, 87 ForbiddenStatusCode: configuredStatusCode, 88 Validator: validator, 89 } 90 91 response = httptest.NewRecorder() 92 request = httptest.NewRequest("GET", "/", nil) 93 decorated = handler.Decorate(next) 94 ) 95 96 require.NotNil(decorated) 97 request.Header.Set(expectedHeader, "there is no way this is a valid authorization header") 98 decorated.ServeHTTP(response, request) 99 assert.Equal(expectedStatusCode, response.Code) 100 assert.False(nextCalled) 101 validator.AssertExpectations(t) 102 } 103 104 func testAuthorizationHandlerValid(t *testing.T, expectedHeader, configuredHeader string) { 105 var ( 106 assert = assert.New(t) 107 require = require.New(t) 108 109 nextCalled = false 110 next = http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) { 111 nextCalled = true 112 values, ok := FromContext(request.Context()) 113 require.True(ok) 114 require.NotNil(values) 115 116 assert.Equal("x1:webpa-internal:5f0183", values.SatClientID) 117 assert.Equal([]string{"comcast"}, values.PartnerIDs) 118 }) 119 120 validator = new(secure.MockValidator) 121 handler = AuthorizationHandler{ 122 Logger: logging.NewTestLogger(nil, t), 123 HeaderName: configuredHeader, 124 Validator: validator, 125 } 126 127 response = httptest.NewRecorder() 128 request = httptest.NewRequest("GET", "/", nil) 129 decorated = handler.Decorate(next) 130 ) 131 132 require.NotNil(decorated) 133 request.Header.Set(expectedHeader, "Bearer eyJhbGciOiJub25lIiwia2lkIjoidGhlbWlzLTIwMTcwMSIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI4ZjA0MmIyOS03ZDE2LTRjMWYtYjBmOS1mNTJhMGFhZDI5YmMiLCJpc3MiOiJzYXRzLXByb2R1Y3Rpb24iLCJzdWIiOiJ4MTp3ZWJwYS1pbnRlcm5hbDo1ZjAxODMiLCJpYXQiOjE1Mjc3MzAwOTYsIm5iZiI6MTUyNzczMDA5NiwiZXhwIjoxNTI3NzczMjk5LCJ2ZXJzaW9uIjoiMS4wIiwiYWxsb3dlZFJlc291cmNlcyI6eyJhbGxvd2VkUGFydG5lcnMiOlsiY29tY2FzdCJdfSwiY2FwYWJpbGl0aWVzIjpbIngxOndlYnBhOmFwaTouKjphbGwiXSwiYXVkIjpbXX0.") 134 135 validator.On("Validate", mock.MatchedBy(func(context.Context) bool { return true }), mock.MatchedBy(func(*secure.Token) bool { return true })).Return(true, error(nil)).Once() 136 decorated.ServeHTTP(response, request) 137 assert.Equal(200, response.Code) 138 assert.True(nextCalled) 139 validator.AssertExpectations(t) 140 } 141 142 func testAuthorizationHandlerInvalid(t *testing.T, expectedStatusCode, configuredStatusCode int, expectedHeader, configuredHeader string) { 143 var ( 144 assert = assert.New(t) 145 require = require.New(t) 146 147 nextCalled = false 148 next = http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) { 149 nextCalled = true 150 }) 151 152 validator = new(secure.MockValidator) 153 handler = AuthorizationHandler{ 154 Logger: logging.NewTestLogger(nil, t), 155 HeaderName: configuredHeader, 156 ForbiddenStatusCode: configuredStatusCode, 157 Validator: validator, 158 } 159 160 response = httptest.NewRecorder() 161 request = httptest.NewRequest("GET", "/", nil) 162 decorated = handler.Decorate(next) 163 ) 164 165 require.NotNil(decorated) 166 request.Header.Set(expectedHeader, "Basic YWxsYWRpbjpvcGVuc2VzYW1l") 167 168 validator.On("Validate", mock.MatchedBy(func(context.Context) bool { return true }), mock.MatchedBy(func(*secure.Token) bool { return true })).Return(false, error(nil)).Once() 169 decorated.ServeHTTP(response, request) 170 assert.Equal(expectedStatusCode, response.Code) 171 assert.False(nextCalled) 172 validator.AssertExpectations(t) 173 } 174 175 func TestAuthorizationHandler(t *testing.T) { 176 t.Run("NoDecoration", testAuthorizationHandlerNoDecoration) 177 178 t.Run("NoAuthorization", func(t *testing.T) { 179 testData := []struct { 180 expectedStatusCode int 181 configuredStatusCode int 182 }{ 183 {http.StatusForbidden, 0}, 184 {http.StatusForbidden, http.StatusForbidden}, 185 {599, 599}, 186 } 187 188 for i, record := range testData { 189 t.Run(strconv.Itoa(i), func(t *testing.T) { 190 testAuthorizationHandlerNoAuthorization(t, record.expectedStatusCode, record.configuredStatusCode) 191 }) 192 } 193 }) 194 195 t.Run("MalformedAuthorization", func(t *testing.T) { 196 testData := []struct { 197 expectedStatusCode int 198 configuredStatusCode int 199 expectedHeader string 200 configuredHeader string 201 }{ 202 {http.StatusForbidden, 0, "Authorization", ""}, 203 {http.StatusForbidden, http.StatusForbidden, "Authorization", "Authorization"}, 204 {599, 599, "X-Custom", "X-Custom"}, 205 } 206 207 for i, record := range testData { 208 t.Run(strconv.Itoa(i), func(t *testing.T) { 209 testAuthorizationHandlerMalformedAuthorization(t, record.expectedStatusCode, record.configuredStatusCode, record.expectedHeader, record.configuredHeader) 210 }) 211 } 212 }) 213 214 t.Run("Valid", func(t *testing.T) { 215 testData := []struct { 216 expectedHeader string 217 configuredHeader string 218 }{ 219 {"Authorization", ""}, 220 {"Authorization", "Authorization"}, 221 {"X-Custom", "X-Custom"}, 222 } 223 224 for i, record := range testData { 225 t.Run(strconv.Itoa(i), func(t *testing.T) { 226 testAuthorizationHandlerValid(t, record.expectedHeader, record.configuredHeader) 227 }) 228 } 229 }) 230 231 t.Run("Invalid", func(t *testing.T) { 232 testData := []struct { 233 expectedStatusCode int 234 configuredStatusCode int 235 expectedHeader string 236 configuredHeader string 237 }{ 238 {http.StatusForbidden, 0, "Authorization", ""}, 239 {http.StatusForbidden, http.StatusForbidden, "Authorization", "Authorization"}, 240 {599, 599, "X-Custom", "X-Custom"}, 241 } 242 243 for i, record := range testData { 244 t.Run(strconv.Itoa(i), func(t *testing.T) { 245 testAuthorizationHandlerInvalid(t, record.expectedStatusCode, record.configuredStatusCode, record.expectedHeader, record.configuredHeader) 246 }) 247 } 248 }) 249 } 250 251 func testPopulateContextValuesNoJWT(t *testing.T) { 252 var ( 253 assert = assert.New(t) 254 require = require.New(t) 255 256 token, err = secure.ParseAuthorization("Basic abcd==") 257 ) 258 259 require.NoError(err) 260 require.NotNil(token) 261 262 values := new(ContextValues) 263 assert.NoError(populateContextValues(token, values)) 264 } 265 266 func TestPopulateContextValues(t *testing.T) { 267 t.Run("NoJWT", testPopulateContextValuesNoJWT) 268 } 269 270 //A simple verification that a pointer function signature is used 271 func TestDefineMeasures(t *testing.T) { 272 assert := assert.New(t) 273 a, m := AuthorizationHandler{}, &secure.JWTValidationMeasures{} 274 a.DefineMeasures(m) 275 assert.Equal(m, a.measures) 276 }