github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/identity/middleware_test.go (about) 1 // Copyright 2023 Northern.tech AS 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 package identity 15 16 import ( 17 "encoding/base64" 18 "encoding/json" 19 "fmt" 20 "net/http" 21 "net/http/httptest" 22 "testing" 23 "time" 24 25 "github.com/ant0ine/go-json-rest/rest" 26 "github.com/ant0ine/go-json-rest/rest/test" 27 "github.com/gin-gonic/gin" 28 "github.com/stretchr/testify/assert" 29 30 "github.com/mendersoftware/go-lib-micro/log" 31 urest "github.com/mendersoftware/go-lib-micro/rest.utils" 32 ) 33 34 func init() { 35 gin.SetMode(gin.ReleaseMode) 36 } 37 38 func makeFakeAuth(idty Identity) string { 39 b, _ := json.Marshal(idty) 40 claims := base64.RawURLEncoding.EncodeToString(b) 41 return "aGVhZGVy." + claims + ".c2lnbg" 42 } 43 44 func TestGinMiddleware(t *testing.T) { 45 testCases := []struct { 46 Name string 47 48 Request *http.Request 49 Options *MiddlewareOptions 50 51 Validator func(t *testing.T, 52 w *httptest.ResponseRecorder, req *http.Request, 53 ) 54 }{{ 55 Name: "ok, user", 56 Request: func() *http.Request { 57 req, _ := http.NewRequest("GET", 58 "http://localhost/api/management/v1/test?foo=bar", 59 nil, 60 ) 61 req.Header.Set("Authorization", 62 "Bearer "+makeFakeAuth(Identity{ 63 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 64 Tenant: "123456789012345678901234", 65 IsUser: true, 66 Plan: "professional", 67 }), 68 ) 69 return req 70 }(), 71 72 Validator: func(t *testing.T, 73 w *httptest.ResponseRecorder, req *http.Request, 74 ) { 75 ctx := req.Context() 76 expected := &Identity{ 77 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 78 Tenant: "123456789012345678901234", 79 IsUser: true, 80 Plan: "professional", 81 } 82 actual := FromContext(ctx) 83 assert.EqualValues(t, expected, actual) 84 logger := log.FromContext(ctx) 85 assert.Equal(t, 86 "3e955f9d-53bf-47d6-a182-ff27b2c96282", 87 logger.Entry.Data["user_id"], 88 ) 89 assert.Equal(t, 90 "123456789012345678901234", 91 logger.Entry.Data["tenant_id"], 92 ) 93 assert.Equal(t, 94 "professional", 95 logger.Entry.Data["plan"], 96 ) 97 }, 98 }, { 99 Name: "ok, device", 100 Request: func() *http.Request { 101 req, _ := http.NewRequest("GET", 102 "http://localhost/api/management/v1/test?foo=bar", 103 nil, 104 ) 105 req.Header.Set("Authorization", 106 "Bearer "+makeFakeAuth(Identity{ 107 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 108 Tenant: "123456789012345678901234", 109 IsDevice: true, 110 }), 111 ) 112 return req 113 }(), 114 115 Validator: func(t *testing.T, 116 w *httptest.ResponseRecorder, req *http.Request, 117 ) { 118 ctx := req.Context() 119 expected := &Identity{ 120 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 121 Tenant: "123456789012345678901234", 122 IsDevice: true, 123 } 124 actual := FromContext(ctx) 125 assert.EqualValues(t, expected, actual) 126 logger := log.FromContext(ctx) 127 assert.Equal(t, 128 "3e955f9d-53bf-47d6-a182-ff27b2c96282", 129 logger.Entry.Data["device_id"], 130 ) 131 assert.Equal(t, 132 "123456789012345678901234", 133 logger.Entry.Data["tenant_id"], 134 ) 135 }, 136 }, { 137 Name: "ok, with option override", 138 Request: func() *http.Request { 139 req, _ := http.NewRequest("GET", 140 "http://localhost/api/management/v1/test?foo=bar", 141 nil, 142 ) 143 req.Header.Set("Authorization", 144 "Bearer "+makeFakeAuth(Identity{ 145 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 146 Tenant: "123456789012345678901234", 147 }), 148 ) 149 return req 150 }(), 151 Options: NewMiddlewareOptions(). 152 SetPathRegex("^/api/management/v1/test$"). 153 SetUpdateLogger(false), 154 155 Validator: func(t *testing.T, 156 w *httptest.ResponseRecorder, req *http.Request, 157 ) { 158 ctx := req.Context() 159 expected := &Identity{ 160 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 161 Tenant: "123456789012345678901234", 162 } 163 actual := FromContext(ctx) 164 assert.EqualValues(t, expected, actual) 165 logger := log.FromContext(ctx) 166 assert.Empty(t, logger.Entry.Data) 167 }, 168 }, { 169 Name: "ok, path does not match", 170 Request: func() *http.Request { 171 req, _ := http.NewRequest("GET", 172 "http://localhost/api/management/", 173 nil, 174 ) 175 req.Header.Set("Authorization", 176 "Bearer "+makeFakeAuth(Identity{ 177 Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282", 178 Tenant: "123456789012345678901234", 179 }), 180 ) 181 return req 182 }(), 183 Options: NewMiddlewareOptions(). 184 SetPathRegex("^/api/management/v1/test$"), 185 186 Validator: func(t *testing.T, 187 w *httptest.ResponseRecorder, req *http.Request, 188 ) { 189 ctx := req.Context() 190 actual := FromContext(ctx) 191 assert.Nil(t, actual) 192 logger := log.FromContext(ctx) 193 assert.Empty(t, logger.Entry.Data) 194 }, 195 }, { 196 Name: "error, token not present (w/logger)", 197 Request: func() *http.Request { 198 req, _ := http.NewRequest("GET", 199 "http://localhost/api/management/v1/test", 200 nil, 201 ) 202 return req 203 }(), 204 Options: NewMiddlewareOptions(). 205 SetPathRegex("^/api/management/v1/test$"), 206 207 Validator: func(t *testing.T, 208 w *httptest.ResponseRecorder, req *http.Request, 209 ) { 210 assert.Equal(t, 401, w.Code) 211 var apiErr urest.Error 212 _ = json.Unmarshal(w.Body.Bytes(), &apiErr) 213 assert.EqualError(t, 214 apiErr, 215 "Authorization not present in header", 216 ) 217 }, 218 }, { 219 Name: "error, token malformed (w/logger)", 220 Request: func() *http.Request { 221 req, _ := http.NewRequest("GET", 222 "http://localhost/api/management/v1/test", 223 nil, 224 ) 225 req.Header.Set("Authorization", "Bearer bruh?==") 226 return req 227 }(), 228 Options: NewMiddlewareOptions(). 229 SetPathRegex("^/api/management/v1/test$"), 230 231 Validator: func(t *testing.T, 232 w *httptest.ResponseRecorder, req *http.Request, 233 ) { 234 assert.Equal(t, 401, w.Code) 235 var apiErr urest.Error 236 _ = json.Unmarshal(w.Body.Bytes(), &apiErr) 237 assert.EqualError(t, 238 apiErr, 239 "identity: incorrect token format", 240 ) 241 }, 242 }, { 243 Name: "error, token not present (base middleware)", 244 Request: func() *http.Request { 245 req, _ := http.NewRequest("GET", 246 "http://localhost/api/management/v1/test", 247 nil, 248 ) 249 return req 250 }(), 251 Options: NewMiddlewareOptions(). 252 SetUpdateLogger(false), 253 254 Validator: func(t *testing.T, 255 w *httptest.ResponseRecorder, req *http.Request, 256 ) { 257 assert.Equal(t, 401, w.Code) 258 var apiErr urest.Error 259 _ = json.Unmarshal(w.Body.Bytes(), &apiErr) 260 assert.EqualError(t, 261 apiErr, 262 "Authorization not present in header", 263 ) 264 }, 265 }, { 266 Name: "error, token malformed (base middleware)", 267 Request: func() *http.Request { 268 req, _ := http.NewRequest("GET", 269 "http://localhost/api/management/v1/test", 270 nil, 271 ) 272 req.Header.Set("Authorization", "Bearer bruh?==") 273 return req 274 }(), 275 Options: NewMiddlewareOptions(). 276 SetUpdateLogger(false), 277 278 Validator: func(t *testing.T, 279 w *httptest.ResponseRecorder, req *http.Request, 280 ) { 281 assert.Equal(t, 401, w.Code) 282 var apiErr urest.Error 283 _ = json.Unmarshal(w.Body.Bytes(), &apiErr) 284 assert.EqualError(t, 285 apiErr, 286 "identity: incorrect token format", 287 ) 288 }, 289 }} 290 291 for i := range testCases { 292 tc := testCases[i] 293 t.Run(tc.Name, func(t *testing.T) { 294 t.Parallel() 295 reqChan := make(chan *http.Request, 1) 296 router := gin.New() 297 router.Use(func(c *gin.Context) { 298 c.Next() 299 c.Writer.Flush() 300 reqChan <- c.Request 301 }) 302 router.Use(Middleware(tc.Options)) 303 router.GET("/api/management/v1/test", func(c *gin.Context) { 304 c.Status(200) 305 }) 306 router.NoRoute(func(c *gin.Context) { 307 c.Status(200) 308 }) 309 310 w := httptest.NewRecorder() 311 router.ServeHTTP(w, tc.Request) 312 313 var req *http.Request 314 select { 315 case req = <-reqChan: 316 tc.Validator(t, w, req) 317 case <-time.After(time.Second): 318 panic("[PROG ERR] Bad test case") 319 } 320 }) 321 } 322 323 } 324 325 func TestIdentityMiddlewareNoIdentity(t *testing.T) { 326 api := rest.NewApi() 327 328 api.Use(&IdentityMiddleware{}) 329 330 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 331 ctxIdentity := FromContext(r.Context()) 332 assert.Empty(t, ctxIdentity) 333 w.WriteJson(map[string]string{"foo": "bar"}) 334 })) 335 336 handler := api.MakeHandler() 337 338 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 339 340 recorded := test.RunRequest(t, handler, req) 341 recorded.CodeIs(200) 342 recorded.ContentTypeIsJson() 343 } 344 345 func TestIdentityMiddlewareNoSubject(t *testing.T) { 346 api := rest.NewApi() 347 348 api.Use(&IdentityMiddleware{}) 349 350 identity := Identity{ 351 Tenant: "bar", 352 } 353 354 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 355 ctxIdentity := FromContext(r.Context()) 356 assert.Empty(t, ctxIdentity) 357 w.WriteJson(map[string]string{"foo": "bar"}) 358 })) 359 360 handler := api.MakeHandler() 361 362 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 363 rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan) 364 req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar") 365 366 recorded := test.RunRequest(t, handler, req) 367 recorded.CodeIs(200) 368 recorded.ContentTypeIsJson() 369 } 370 371 func TestIdentityMiddlewareNoTenant(t *testing.T) { 372 api := rest.NewApi() 373 374 api.Use(&IdentityMiddleware{}) 375 376 identity := Identity{ 377 Subject: "foo", 378 } 379 380 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 381 ctxIdentity := FromContext(r.Context()) 382 assert.Equal(t, &identity, ctxIdentity) 383 w.WriteJson(map[string]string{"foo": "bar"}) 384 })) 385 386 handler := api.MakeHandler() 387 388 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 389 rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan) 390 req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar") 391 392 recorded := test.RunRequest(t, handler, req) 393 recorded.CodeIs(200) 394 recorded.ContentTypeIsJson() 395 } 396 397 func TestIdentityMiddleware(t *testing.T) { 398 api := rest.NewApi() 399 400 api.Use(&IdentityMiddleware{}) 401 402 identity := Identity{ 403 Subject: "foo", 404 Tenant: "bar", 405 Plan: "os", 406 } 407 408 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 409 ctxIdentity := FromContext(r.Context()) 410 assert.Equal(t, &identity, ctxIdentity) 411 w.WriteJson(map[string]string{"foo": "bar"}) 412 })) 413 414 handler := api.MakeHandler() 415 416 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 417 rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan) 418 req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar") 419 420 recorded := test.RunRequest(t, handler, req) 421 recorded.CodeIs(200) 422 recorded.ContentTypeIsJson() 423 } 424 425 func TestIdentityMiddlewareDevice(t *testing.T) { 426 testCases := []struct { 427 identity Identity 428 mw *IdentityMiddleware 429 logFields map[string]interface{} 430 }{ 431 { 432 identity: Identity{ 433 Subject: "device-1", 434 Tenant: "bar", 435 Plan: "os", 436 IsDevice: true, 437 }, 438 mw: &IdentityMiddleware{ 439 UpdateLogger: true, 440 }, 441 logFields: map[string]interface{}{ 442 "device_id": "device-1", 443 "tenant_id": "bar", 444 "plan": "os", 445 }, 446 }, 447 { 448 identity: Identity{ 449 Subject: "user-1", 450 Tenant: "bar", 451 Plan: "os", 452 IsUser: true, 453 }, 454 mw: &IdentityMiddleware{ 455 UpdateLogger: true, 456 }, 457 logFields: map[string]interface{}{ 458 "user_id": "user-1", 459 "tenant_id": "bar", 460 "plan": "os", 461 }, 462 }, 463 { 464 identity: Identity{ 465 Subject: "not-a-user-not-a-device", 466 Tenant: "bar", 467 Plan: "os", 468 }, 469 mw: &IdentityMiddleware{ 470 UpdateLogger: true, 471 }, 472 logFields: map[string]interface{}{ 473 "sub": "not-a-user-not-a-device", 474 "tenant_id": "bar", 475 "plan": "os", 476 }, 477 }, 478 { 479 identity: Identity{ 480 Subject: "123-dobby-has-no-master", 481 IsDevice: true, 482 }, 483 mw: &IdentityMiddleware{ 484 UpdateLogger: true, 485 }, 486 logFields: map[string]interface{}{ 487 "device_id": "123-dobby-has-no-master", 488 "tenant_id": nil, 489 }, 490 }, 491 } 492 493 for idx := range testCases { 494 tc := testCases[idx] 495 t.Run(fmt.Sprintf("tc %d", idx), func(t *testing.T) { 496 api := rest.NewApi() 497 498 api.Use(tc.mw) 499 500 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 501 ctxIdentity := FromContext(r.Context()) 502 503 assert.Equal(t, &tc.identity, ctxIdentity) 504 505 l := log.FromContext(r.Context()) 506 l.Infof("foobar") 507 for f, v := range tc.logFields { 508 assert.Equal(t, v, l.Data[f]) 509 } 510 w.WriteJson(map[string]string{"foo": "bar"}) 511 })) 512 513 handler := api.MakeHandler() 514 515 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 516 517 claims := makeClaimsFull(tc.identity.Subject, tc.identity.Tenant, tc.identity.Plan, 518 tc.identity.IsDevice, tc.identity.IsUser, false) 519 req.Header.Set("Authorization", "Bearer foo."+claims+".bar") 520 521 recorded := test.RunRequest(t, handler, req) 522 recorded.CodeIs(200) 523 recorded.ContentTypeIsJson() 524 }) 525 } 526 }