github.com/weaviate/weaviate@v1.24.6/usecases/auth/authentication/oidc/middleware_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package oidc 13 14 import ( 15 "fmt" 16 "testing" 17 "time" 18 19 errors "github.com/go-openapi/errors" 20 "github.com/golang-jwt/jwt/v4" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 "github.com/weaviate/weaviate/usecases/config" 24 ) 25 26 func Test_Middleware_NotConfigured(t *testing.T) { 27 cfg := config.Config{ 28 Authentication: config.Authentication{ 29 OIDC: config.OIDC{ 30 Enabled: false, 31 }, 32 }, 33 } 34 expectedErr := errors.New(401, "oidc auth is not configured, please try another auth scheme or set up weaviate with OIDC configured") 35 36 client, err := New(cfg) 37 require.Nil(t, err) 38 39 principal, err := client.ValidateAndExtract("token-doesnt-matter", []string{}) 40 assert.Nil(t, principal) 41 assert.Equal(t, expectedErr, err) 42 } 43 44 func Test_Middleware_IncompleteConfiguration(t *testing.T) { 45 cfg := config.Config{ 46 Authentication: config.Authentication{ 47 OIDC: config.OIDC{ 48 Enabled: true, 49 }, 50 }, 51 } 52 expectedErr := fmt.Errorf("oidc init: invalid config: missing required field 'issuer', " + 53 "missing required field 'username_claim', missing required field 'client_id': either set a client_id or explicitly disable the check with 'skip_client_id_check: true'") 54 55 _, err := New(cfg) 56 assert.Equal(t, expectedErr, err) 57 } 58 59 type claims struct { 60 jwt.StandardClaims 61 Email string `json:"email"` 62 Groups []string `json:"groups"` 63 } 64 65 func Test_Middleware_WithValidToken(t *testing.T) { 66 t.Run("without groups set", func(t *testing.T) { 67 server := newOIDCServer(t) 68 defer server.Close() 69 70 cfg := config.Config{ 71 Authentication: config.Authentication{ 72 OIDC: config.OIDC{ 73 Enabled: true, 74 Issuer: server.URL, 75 ClientID: "best_client", 76 SkipClientIDCheck: false, 77 UsernameClaim: "sub", 78 }, 79 }, 80 } 81 82 token := token(t, "best-user", server.URL, "best_client") 83 client, err := New(cfg) 84 require.Nil(t, err) 85 86 principal, err := client.ValidateAndExtract(token, []string{}) 87 require.Nil(t, err) 88 assert.Equal(t, "best-user", principal.Username) 89 }) 90 91 t.Run("with a non-standard username claim", func(t *testing.T) { 92 server := newOIDCServer(t) 93 defer server.Close() 94 95 cfg := config.Config{ 96 Authentication: config.Authentication{ 97 OIDC: config.OIDC{ 98 Enabled: true, 99 Issuer: server.URL, 100 ClientID: "best_client", 101 SkipClientIDCheck: false, 102 UsernameClaim: "email", 103 GroupsClaim: "groups", 104 }, 105 }, 106 } 107 108 token := tokenWithEmail(t, "best-user", server.URL, "best_client", "foo@bar.com") 109 client, err := New(cfg) 110 require.Nil(t, err) 111 112 principal, err := client.ValidateAndExtract(token, []string{}) 113 require.Nil(t, err) 114 assert.Equal(t, "foo@bar.com", principal.Username) 115 }) 116 117 t.Run("with groups claim", func(t *testing.T) { 118 server := newOIDCServer(t) 119 defer server.Close() 120 121 cfg := config.Config{ 122 Authentication: config.Authentication{ 123 OIDC: config.OIDC{ 124 Enabled: true, 125 Issuer: server.URL, 126 ClientID: "best_client", 127 SkipClientIDCheck: false, 128 UsernameClaim: "sub", 129 GroupsClaim: "groups", 130 }, 131 }, 132 } 133 134 token := tokenWithGroups(t, "best-user", server.URL, "best_client", []string{"group1", "group2"}) 135 client, err := New(cfg) 136 require.Nil(t, err) 137 138 principal, err := client.ValidateAndExtract(token, []string{}) 139 require.Nil(t, err) 140 assert.Equal(t, "best-user", principal.Username) 141 assert.Equal(t, []string{"group1", "group2"}, principal.Groups) 142 }) 143 } 144 145 func token(t *testing.T, subject string, issuer string, aud string) string { 146 return tokenWithEmail(t, subject, issuer, aud, "") 147 } 148 149 func tokenWithEmail(t *testing.T, subject string, issuer string, aud string, email string) string { 150 claims := claims{ 151 Email: email, 152 } 153 154 return tokenWithClaims(t, subject, issuer, aud, claims) 155 } 156 157 func tokenWithGroups(t *testing.T, subject string, issuer string, aud string, groups []string) string { 158 claims := claims{ 159 Groups: groups, 160 } 161 162 return tokenWithClaims(t, subject, issuer, aud, claims) 163 } 164 165 func tokenWithClaims(t *testing.T, subject string, issuer string, aud string, claims claims) string { 166 //nolint:staticcheck // is deprecated, but for the purpose of this test, this doesn't matter 167 claims.StandardClaims = jwt.StandardClaims{ 168 Subject: subject, 169 Issuer: issuer, 170 Audience: aud, 171 ExpiresAt: time.Now().Add(10 * time.Second).Unix(), 172 } 173 174 token, err := signToken(claims) 175 require.Nil(t, err, "signing token should not error") 176 177 return token 178 }