github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/identity/middleware_test.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //	Licensed under the Apache License, Version 2.0 (the "License");
     4  //	you may not use this file except in compliance with the License.
     5  //	You may obtain a copy of the License at
     6  //
     7  //	    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //	Unless required by applicable law or agreed to in writing, software
    10  //	distributed under the License is distributed on an "AS IS" BASIS,
    11  //	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //	See the License for the specific language governing permissions and
    13  //	limitations under the License.
    14  package identity
    15  
    16  import (
    17  	"encoding/base64"
    18  	"encoding/json"
    19  	"fmt"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/ant0ine/go-json-rest/rest"
    26  	"github.com/ant0ine/go-json-rest/rest/test"
    27  	"github.com/gin-gonic/gin"
    28  	"github.com/stretchr/testify/assert"
    29  
    30  	"github.com/mendersoftware/go-lib-micro/log"
    31  	urest "github.com/mendersoftware/go-lib-micro/rest.utils"
    32  )
    33  
    34  func init() {
    35  	gin.SetMode(gin.ReleaseMode)
    36  }
    37  
    38  func makeFakeAuth(idty Identity) string {
    39  	b, _ := json.Marshal(idty)
    40  	claims := base64.RawURLEncoding.EncodeToString(b)
    41  	return "aGVhZGVy." + claims + ".c2lnbg"
    42  }
    43  
    44  func TestGinMiddleware(t *testing.T) {
    45  	testCases := []struct {
    46  		Name string
    47  
    48  		Request *http.Request
    49  		Options *MiddlewareOptions
    50  
    51  		Validator func(t *testing.T,
    52  			w *httptest.ResponseRecorder, req *http.Request,
    53  		)
    54  	}{{
    55  		Name: "ok, user",
    56  		Request: func() *http.Request {
    57  			req, _ := http.NewRequest("GET",
    58  				"http://localhost/api/management/v1/test?foo=bar",
    59  				nil,
    60  			)
    61  			req.Header.Set("Authorization",
    62  				"Bearer "+makeFakeAuth(Identity{
    63  					Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282",
    64  					Tenant:  "123456789012345678901234",
    65  					IsUser:  true,
    66  					Plan:    "professional",
    67  				}),
    68  			)
    69  			return req
    70  		}(),
    71  
    72  		Validator: func(t *testing.T,
    73  			w *httptest.ResponseRecorder, req *http.Request,
    74  		) {
    75  			ctx := req.Context()
    76  			expected := &Identity{
    77  				Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282",
    78  				Tenant:  "123456789012345678901234",
    79  				IsUser:  true,
    80  				Plan:    "professional",
    81  			}
    82  			actual := FromContext(ctx)
    83  			assert.EqualValues(t, expected, actual)
    84  			logger := log.FromContext(ctx)
    85  			assert.Equal(t,
    86  				"3e955f9d-53bf-47d6-a182-ff27b2c96282",
    87  				logger.Entry.Data["user_id"],
    88  			)
    89  			assert.Equal(t,
    90  				"123456789012345678901234",
    91  				logger.Entry.Data["tenant_id"],
    92  			)
    93  			assert.Equal(t,
    94  				"professional",
    95  				logger.Entry.Data["plan"],
    96  			)
    97  		},
    98  	}, {
    99  		Name: "ok, device",
   100  		Request: func() *http.Request {
   101  			req, _ := http.NewRequest("GET",
   102  				"http://localhost/api/management/v1/test?foo=bar",
   103  				nil,
   104  			)
   105  			req.Header.Set("Authorization",
   106  				"Bearer "+makeFakeAuth(Identity{
   107  					Subject:  "3e955f9d-53bf-47d6-a182-ff27b2c96282",
   108  					Tenant:   "123456789012345678901234",
   109  					IsDevice: true,
   110  				}),
   111  			)
   112  			return req
   113  		}(),
   114  
   115  		Validator: func(t *testing.T,
   116  			w *httptest.ResponseRecorder, req *http.Request,
   117  		) {
   118  			ctx := req.Context()
   119  			expected := &Identity{
   120  				Subject:  "3e955f9d-53bf-47d6-a182-ff27b2c96282",
   121  				Tenant:   "123456789012345678901234",
   122  				IsDevice: true,
   123  			}
   124  			actual := FromContext(ctx)
   125  			assert.EqualValues(t, expected, actual)
   126  			logger := log.FromContext(ctx)
   127  			assert.Equal(t,
   128  				"3e955f9d-53bf-47d6-a182-ff27b2c96282",
   129  				logger.Entry.Data["device_id"],
   130  			)
   131  			assert.Equal(t,
   132  				"123456789012345678901234",
   133  				logger.Entry.Data["tenant_id"],
   134  			)
   135  		},
   136  	}, {
   137  		Name: "ok, with option override",
   138  		Request: func() *http.Request {
   139  			req, _ := http.NewRequest("GET",
   140  				"http://localhost/api/management/v1/test?foo=bar",
   141  				nil,
   142  			)
   143  			req.Header.Set("Authorization",
   144  				"Bearer "+makeFakeAuth(Identity{
   145  					Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282",
   146  					Tenant:  "123456789012345678901234",
   147  				}),
   148  			)
   149  			return req
   150  		}(),
   151  		Options: NewMiddlewareOptions().
   152  			SetPathRegex("^/api/management/v1/test$").
   153  			SetUpdateLogger(false),
   154  
   155  		Validator: func(t *testing.T,
   156  			w *httptest.ResponseRecorder, req *http.Request,
   157  		) {
   158  			ctx := req.Context()
   159  			expected := &Identity{
   160  				Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282",
   161  				Tenant:  "123456789012345678901234",
   162  			}
   163  			actual := FromContext(ctx)
   164  			assert.EqualValues(t, expected, actual)
   165  			logger := log.FromContext(ctx)
   166  			assert.Empty(t, logger.Entry.Data)
   167  		},
   168  	}, {
   169  		Name: "ok, path does not match",
   170  		Request: func() *http.Request {
   171  			req, _ := http.NewRequest("GET",
   172  				"http://localhost/api/management/",
   173  				nil,
   174  			)
   175  			req.Header.Set("Authorization",
   176  				"Bearer "+makeFakeAuth(Identity{
   177  					Subject: "3e955f9d-53bf-47d6-a182-ff27b2c96282",
   178  					Tenant:  "123456789012345678901234",
   179  				}),
   180  			)
   181  			return req
   182  		}(),
   183  		Options: NewMiddlewareOptions().
   184  			SetPathRegex("^/api/management/v1/test$"),
   185  
   186  		Validator: func(t *testing.T,
   187  			w *httptest.ResponseRecorder, req *http.Request,
   188  		) {
   189  			ctx := req.Context()
   190  			actual := FromContext(ctx)
   191  			assert.Nil(t, actual)
   192  			logger := log.FromContext(ctx)
   193  			assert.Empty(t, logger.Entry.Data)
   194  		},
   195  	}, {
   196  		Name: "error, token not present (w/logger)",
   197  		Request: func() *http.Request {
   198  			req, _ := http.NewRequest("GET",
   199  				"http://localhost/api/management/v1/test",
   200  				nil,
   201  			)
   202  			return req
   203  		}(),
   204  		Options: NewMiddlewareOptions().
   205  			SetPathRegex("^/api/management/v1/test$"),
   206  
   207  		Validator: func(t *testing.T,
   208  			w *httptest.ResponseRecorder, req *http.Request,
   209  		) {
   210  			assert.Equal(t, 401, w.Code)
   211  			var apiErr urest.Error
   212  			_ = json.Unmarshal(w.Body.Bytes(), &apiErr)
   213  			assert.EqualError(t,
   214  				apiErr,
   215  				"Authorization not present in header",
   216  			)
   217  		},
   218  	}, {
   219  		Name: "error, token malformed (w/logger)",
   220  		Request: func() *http.Request {
   221  			req, _ := http.NewRequest("GET",
   222  				"http://localhost/api/management/v1/test",
   223  				nil,
   224  			)
   225  			req.Header.Set("Authorization", "Bearer bruh?==")
   226  			return req
   227  		}(),
   228  		Options: NewMiddlewareOptions().
   229  			SetPathRegex("^/api/management/v1/test$"),
   230  
   231  		Validator: func(t *testing.T,
   232  			w *httptest.ResponseRecorder, req *http.Request,
   233  		) {
   234  			assert.Equal(t, 401, w.Code)
   235  			var apiErr urest.Error
   236  			_ = json.Unmarshal(w.Body.Bytes(), &apiErr)
   237  			assert.EqualError(t,
   238  				apiErr,
   239  				"identity: incorrect token format",
   240  			)
   241  		},
   242  	}, {
   243  		Name: "error, token not present (base middleware)",
   244  		Request: func() *http.Request {
   245  			req, _ := http.NewRequest("GET",
   246  				"http://localhost/api/management/v1/test",
   247  				nil,
   248  			)
   249  			return req
   250  		}(),
   251  		Options: NewMiddlewareOptions().
   252  			SetUpdateLogger(false),
   253  
   254  		Validator: func(t *testing.T,
   255  			w *httptest.ResponseRecorder, req *http.Request,
   256  		) {
   257  			assert.Equal(t, 401, w.Code)
   258  			var apiErr urest.Error
   259  			_ = json.Unmarshal(w.Body.Bytes(), &apiErr)
   260  			assert.EqualError(t,
   261  				apiErr,
   262  				"Authorization not present in header",
   263  			)
   264  		},
   265  	}, {
   266  		Name: "error, token malformed (base middleware)",
   267  		Request: func() *http.Request {
   268  			req, _ := http.NewRequest("GET",
   269  				"http://localhost/api/management/v1/test",
   270  				nil,
   271  			)
   272  			req.Header.Set("Authorization", "Bearer bruh?==")
   273  			return req
   274  		}(),
   275  		Options: NewMiddlewareOptions().
   276  			SetUpdateLogger(false),
   277  
   278  		Validator: func(t *testing.T,
   279  			w *httptest.ResponseRecorder, req *http.Request,
   280  		) {
   281  			assert.Equal(t, 401, w.Code)
   282  			var apiErr urest.Error
   283  			_ = json.Unmarshal(w.Body.Bytes(), &apiErr)
   284  			assert.EqualError(t,
   285  				apiErr,
   286  				"identity: incorrect token format",
   287  			)
   288  		},
   289  	}}
   290  
   291  	for i := range testCases {
   292  		tc := testCases[i]
   293  		t.Run(tc.Name, func(t *testing.T) {
   294  			t.Parallel()
   295  			reqChan := make(chan *http.Request, 1)
   296  			router := gin.New()
   297  			router.Use(func(c *gin.Context) {
   298  				c.Next()
   299  				c.Writer.Flush()
   300  				reqChan <- c.Request
   301  			})
   302  			router.Use(Middleware(tc.Options))
   303  			router.GET("/api/management/v1/test", func(c *gin.Context) {
   304  				c.Status(200)
   305  			})
   306  			router.NoRoute(func(c *gin.Context) {
   307  				c.Status(200)
   308  			})
   309  
   310  			w := httptest.NewRecorder()
   311  			router.ServeHTTP(w, tc.Request)
   312  
   313  			var req *http.Request
   314  			select {
   315  			case req = <-reqChan:
   316  				tc.Validator(t, w, req)
   317  			case <-time.After(time.Second):
   318  				panic("[PROG ERR] Bad test case")
   319  			}
   320  		})
   321  	}
   322  
   323  }
   324  
   325  func TestIdentityMiddlewareNoIdentity(t *testing.T) {
   326  	api := rest.NewApi()
   327  
   328  	api.Use(&IdentityMiddleware{})
   329  
   330  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   331  		ctxIdentity := FromContext(r.Context())
   332  		assert.Empty(t, ctxIdentity)
   333  		w.WriteJson(map[string]string{"foo": "bar"})
   334  	}))
   335  
   336  	handler := api.MakeHandler()
   337  
   338  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   339  
   340  	recorded := test.RunRequest(t, handler, req)
   341  	recorded.CodeIs(200)
   342  	recorded.ContentTypeIsJson()
   343  }
   344  
   345  func TestIdentityMiddlewareNoSubject(t *testing.T) {
   346  	api := rest.NewApi()
   347  
   348  	api.Use(&IdentityMiddleware{})
   349  
   350  	identity := Identity{
   351  		Tenant: "bar",
   352  	}
   353  
   354  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   355  		ctxIdentity := FromContext(r.Context())
   356  		assert.Empty(t, ctxIdentity)
   357  		w.WriteJson(map[string]string{"foo": "bar"})
   358  	}))
   359  
   360  	handler := api.MakeHandler()
   361  
   362  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   363  	rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan)
   364  	req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar")
   365  
   366  	recorded := test.RunRequest(t, handler, req)
   367  	recorded.CodeIs(200)
   368  	recorded.ContentTypeIsJson()
   369  }
   370  
   371  func TestIdentityMiddlewareNoTenant(t *testing.T) {
   372  	api := rest.NewApi()
   373  
   374  	api.Use(&IdentityMiddleware{})
   375  
   376  	identity := Identity{
   377  		Subject: "foo",
   378  	}
   379  
   380  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   381  		ctxIdentity := FromContext(r.Context())
   382  		assert.Equal(t, &identity, ctxIdentity)
   383  		w.WriteJson(map[string]string{"foo": "bar"})
   384  	}))
   385  
   386  	handler := api.MakeHandler()
   387  
   388  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   389  	rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan)
   390  	req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar")
   391  
   392  	recorded := test.RunRequest(t, handler, req)
   393  	recorded.CodeIs(200)
   394  	recorded.ContentTypeIsJson()
   395  }
   396  
   397  func TestIdentityMiddleware(t *testing.T) {
   398  	api := rest.NewApi()
   399  
   400  	api.Use(&IdentityMiddleware{})
   401  
   402  	identity := Identity{
   403  		Subject: "foo",
   404  		Tenant:  "bar",
   405  		Plan:    "os",
   406  	}
   407  
   408  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   409  		ctxIdentity := FromContext(r.Context())
   410  		assert.Equal(t, &identity, ctxIdentity)
   411  		w.WriteJson(map[string]string{"foo": "bar"})
   412  	}))
   413  
   414  	handler := api.MakeHandler()
   415  
   416  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   417  	rawclaims := makeClaimsPart(identity.Subject, identity.Tenant, identity.Plan)
   418  	req.Header.Set("Authorization", "Bearer foo."+rawclaims+".bar")
   419  
   420  	recorded := test.RunRequest(t, handler, req)
   421  	recorded.CodeIs(200)
   422  	recorded.ContentTypeIsJson()
   423  }
   424  
   425  func TestIdentityMiddlewareDevice(t *testing.T) {
   426  	testCases := []struct {
   427  		identity  Identity
   428  		mw        *IdentityMiddleware
   429  		logFields map[string]interface{}
   430  	}{
   431  		{
   432  			identity: Identity{
   433  				Subject:  "device-1",
   434  				Tenant:   "bar",
   435  				Plan:     "os",
   436  				IsDevice: true,
   437  			},
   438  			mw: &IdentityMiddleware{
   439  				UpdateLogger: true,
   440  			},
   441  			logFields: map[string]interface{}{
   442  				"device_id": "device-1",
   443  				"tenant_id": "bar",
   444  				"plan":      "os",
   445  			},
   446  		},
   447  		{
   448  			identity: Identity{
   449  				Subject: "user-1",
   450  				Tenant:  "bar",
   451  				Plan:    "os",
   452  				IsUser:  true,
   453  			},
   454  			mw: &IdentityMiddleware{
   455  				UpdateLogger: true,
   456  			},
   457  			logFields: map[string]interface{}{
   458  				"user_id":   "user-1",
   459  				"tenant_id": "bar",
   460  				"plan":      "os",
   461  			},
   462  		},
   463  		{
   464  			identity: Identity{
   465  				Subject: "not-a-user-not-a-device",
   466  				Tenant:  "bar",
   467  				Plan:    "os",
   468  			},
   469  			mw: &IdentityMiddleware{
   470  				UpdateLogger: true,
   471  			},
   472  			logFields: map[string]interface{}{
   473  				"sub":       "not-a-user-not-a-device",
   474  				"tenant_id": "bar",
   475  				"plan":      "os",
   476  			},
   477  		},
   478  		{
   479  			identity: Identity{
   480  				Subject:  "123-dobby-has-no-master",
   481  				IsDevice: true,
   482  			},
   483  			mw: &IdentityMiddleware{
   484  				UpdateLogger: true,
   485  			},
   486  			logFields: map[string]interface{}{
   487  				"device_id": "123-dobby-has-no-master",
   488  				"tenant_id": nil,
   489  			},
   490  		},
   491  	}
   492  
   493  	for idx := range testCases {
   494  		tc := testCases[idx]
   495  		t.Run(fmt.Sprintf("tc %d", idx), func(t *testing.T) {
   496  			api := rest.NewApi()
   497  
   498  			api.Use(tc.mw)
   499  
   500  			api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   501  				ctxIdentity := FromContext(r.Context())
   502  
   503  				assert.Equal(t, &tc.identity, ctxIdentity)
   504  
   505  				l := log.FromContext(r.Context())
   506  				l.Infof("foobar")
   507  				for f, v := range tc.logFields {
   508  					assert.Equal(t, v, l.Data[f])
   509  				}
   510  				w.WriteJson(map[string]string{"foo": "bar"})
   511  			}))
   512  
   513  			handler := api.MakeHandler()
   514  
   515  			req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   516  
   517  			claims := makeClaimsFull(tc.identity.Subject, tc.identity.Tenant, tc.identity.Plan,
   518  				tc.identity.IsDevice, tc.identity.IsUser, false)
   519  			req.Header.Set("Authorization", "Bearer foo."+claims+".bar")
   520  
   521  			recorded := test.RunRequest(t, handler, req)
   522  			recorded.CodeIs(200)
   523  			recorded.ContentTypeIsJson()
   524  		})
   525  	}
   526  }