github.com/cilium/cilium@v1.16.2/pkg/auth/manager_test.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package auth 5 6 import ( 7 "context" 8 "errors" 9 "net" 10 "testing" 11 "time" 12 13 "github.com/cilium/ebpf" 14 "github.com/sirupsen/logrus" 15 "github.com/stretchr/testify/assert" 16 "golang.org/x/exp/maps" 17 18 "github.com/cilium/cilium/api/v1/models" 19 "github.com/cilium/cilium/pkg/auth/certs" 20 "github.com/cilium/cilium/pkg/identity" 21 "github.com/cilium/cilium/pkg/policy" 22 ) 23 24 func Test_newAuthManager_clashingAuthHandlers(t *testing.T) { 25 authHandlers := []authHandler{ 26 &alwaysFailAuthHandler{}, 27 &alwaysFailAuthHandler{}, 28 } 29 30 am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second) 31 assert.ErrorContains(t, err, "multiple handlers for auth type: test-always-fail") 32 assert.Nil(t, am) 33 } 34 35 func Test_newAuthManager(t *testing.T) { 36 authHandlers := []authHandler{ 37 newAlwaysPassAuthHandler(logrus.New()), 38 &fakeAuthHandler{}, 39 } 40 41 am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second) 42 assert.NoError(t, err) 43 assert.NotNil(t, am) 44 45 assert.Len(t, am.authHandlers, 2) 46 } 47 48 func Test_authManager_authenticate(t *testing.T) { 49 tests := []struct { 50 name string 51 args authKey 52 wantErr assert.ErrorAssertionFunc 53 wantAuthenticated bool 54 wantEntries int 55 }{ 56 { 57 name: "missing handler for auth type", 58 args: authKey{ 59 localIdentity: 1000, 60 remoteIdentity: 2000, 61 remoteNodeID: 2, 62 authType: 1, 63 }, 64 wantErr: assertErrorString("unknown requested auth type: spire"), 65 wantEntries: 0, 66 }, 67 { 68 name: "missing node IP for node ID", 69 args: authKey{ 70 localIdentity: 1000, 71 remoteIdentity: 2000, 72 remoteNodeID: 1, 73 authType: 2, 74 }, 75 wantErr: assertErrorString("remote node IP not available for node ID 1"), 76 wantEntries: 0, 77 }, 78 { 79 name: "successful auth", 80 args: authKey{ 81 localIdentity: 1000, 82 remoteIdentity: 2000, 83 remoteNodeID: 2, 84 authType: 100, 85 }, 86 wantErr: assert.NoError, 87 wantEntries: 1, 88 }, 89 } 90 for _, tt := range tests { 91 t.Run(tt.name, func(t *testing.T) { 92 authMap := &fakeAuthMap{ 93 entries: map[authKey]authInfo{}, 94 } 95 am, err := newAuthManager( 96 logrus.New(), 97 []authHandler{&alwaysFailAuthHandler{}, newAlwaysPassAuthHandler(logrus.New())}, 98 authMap, 99 newFakeNodeIDHandler(map[uint16]string{ 100 2: "172.18.0.2", 101 3: "172.18.0.3", 102 }), 103 time.Second, 104 ) 105 106 assert.NoError(t, err) 107 108 err = am.authenticate(tt.args) 109 tt.wantErr(t, err) 110 111 assert.Len(t, authMap.entries, tt.wantEntries) 112 }) 113 } 114 } 115 116 func Test_authManager_handleAuthRequest(t *testing.T) { 117 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 118 119 am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second) 120 assert.NoError(t, err) 121 assert.NotNil(t, am) 122 123 handleAuthCalled := false 124 am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) { 125 handleAuthCalled = true 126 assert.False(t, reAuth) 127 assert.Equal(t, authKey{localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 0, authType: 100}, k) 128 } 129 130 err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: 1000, RemoteIdentity: 2000, RemoteNodeID: 0, AuthType: 100, Pad: 0}) 131 assert.NoError(t, err) 132 assert.True(t, handleAuthCalled) 133 } 134 135 func Test_authManager_handleAuthRequest_reservedRemoteIdentity(t *testing.T) { 136 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 137 138 am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second) 139 assert.NoError(t, err) 140 assert.NotNil(t, am) 141 142 handleAuthCalled := false 143 am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) { 144 handleAuthCalled = true 145 } 146 147 err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: 100, RemoteIdentity: identity.ReservedIdentityWorldIPv6.Uint32(), RemoteNodeID: 0, AuthType: 100, Pad: 0}) 148 assert.NoError(t, err) 149 assert.False(t, handleAuthCalled) 150 } 151 152 func Test_authManager_handleAuthRequest_reservedLocalIdentity(t *testing.T) { 153 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 154 155 am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second) 156 assert.NoError(t, err) 157 assert.NotNil(t, am) 158 159 handleAuthCalled := false 160 am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) { 161 handleAuthCalled = true 162 } 163 164 err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: identity.ReservedIdentityWorldIPv6.Uint32(), RemoteIdentity: 100, RemoteNodeID: 0, AuthType: 100, Pad: 0}) 165 assert.NoError(t, err) 166 assert.False(t, handleAuthCalled) 167 } 168 169 func Test_authManager_handleCertificateRotationEvent_Error(t *testing.T) { 170 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 171 aMap := &fakeAuthMap{ 172 failGet: true, 173 } 174 175 am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second) 176 assert.NoError(t, err) 177 assert.NotNil(t, am) 178 179 err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{Identity: identity.NumericIdentity(10)}) 180 assert.ErrorContains(t, err, "failed to get all auth map entries: failed to list entries") 181 } 182 183 func Test_authManager_handleCertificateRotationEvent(t *testing.T) { 184 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 185 aMap := &fakeAuthMap{ 186 entries: map[authKey]authInfo{ 187 {localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()}, 188 {localIdentity: 2000, remoteIdentity: 3000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()}, 189 {localIdentity: 3000, remoteIdentity: 4000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()}, 190 }, 191 } 192 193 am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second) 194 assert.NoError(t, err) 195 assert.NotNil(t, am) 196 197 handleAuthCalled := false 198 am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) { 199 handleAuthCalled = true 200 assert.True(t, reAuth) 201 assert.True(t, k.localIdentity == 2000 || k.remoteIdentity == 2000) 202 } 203 204 err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{Identity: identity.NumericIdentity(2000)}) 205 assert.NoError(t, err) 206 assert.True(t, handleAuthCalled) 207 } 208 209 func Test_authManager_handleCertificateDeletionEvent(t *testing.T) { 210 authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())} 211 aMap := &fakeAuthMap{ 212 entries: map[authKey]authInfo{ 213 {localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()}, 214 {localIdentity: 2000, remoteIdentity: 3000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()}, 215 {localIdentity: 3000, remoteIdentity: 4000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()}, 216 }, 217 } 218 219 am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second) 220 assert.NoError(t, err) 221 assert.NotNil(t, am) 222 223 err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{ 224 Identity: identity.NumericIdentity(2000), 225 Deleted: true, 226 }) 227 assert.NoError(t, err) 228 assert.Len(t, aMap.entries, 1) 229 } 230 231 // Fake NodeIDHandler 232 type fakeNodeIDHandler struct { 233 nodeIdMappings map[uint16]string 234 } 235 236 func (r *fakeNodeIDHandler) DumpNodeIDs() []*models.NodeID { 237 return []*models.NodeID{} 238 } 239 240 func (r *fakeNodeIDHandler) RestoreNodeIDs() { 241 } 242 243 func newFakeNodeIDHandler(mappings map[uint16]string) *fakeNodeIDHandler { 244 return &fakeNodeIDHandler{ 245 nodeIdMappings: mappings, 246 } 247 } 248 249 func (r *fakeNodeIDHandler) GetNodeIP(id uint16) string { 250 return r.nodeIdMappings[id] 251 } 252 253 func (r *fakeNodeIDHandler) GetNodeID(nodeIP net.IP) (uint16, bool) { 254 for id, ip := range r.nodeIdMappings { 255 if ip == nodeIP.String() { 256 return id, true 257 } 258 } 259 260 return 0, false 261 } 262 263 // Fake AuthHandler 264 type fakeAuthHandler struct { 265 } 266 267 func (r *fakeAuthHandler) authenticate(authReq *authRequest) (*authResponse, error) { 268 269 return &authResponse{}, nil 270 } 271 272 func (r *fakeAuthHandler) authType() policy.AuthType { 273 return policy.AuthType(255) 274 } 275 276 func (r *fakeAuthHandler) subscribeToRotatedIdentities() <-chan certs.CertificateRotationEvent { 277 return nil 278 } 279 280 func (r *fakeAuthHandler) certProviderStatus() *models.Status { 281 return nil 282 } 283 284 // Fake AuthMap 285 type fakeAuthMap struct { 286 entries map[authKey]authInfo 287 failDelete bool 288 failGet bool 289 } 290 291 func (r *fakeAuthMap) Delete(key authKey) error { 292 if r.failDelete { 293 return errors.New("failed to delete entry") 294 } 295 296 if _, ok := r.entries[key]; !ok { 297 return ebpf.ErrKeyNotExist 298 } 299 300 delete(r.entries, key) 301 return nil 302 } 303 304 func (r *fakeAuthMap) DeleteIf(predicate func(key authKey, info authInfo) bool) error { 305 if r.failDelete { 306 return errors.New("failed to delete entry") 307 } 308 309 maps.DeleteFunc(r.entries, predicate) 310 311 return nil 312 } 313 314 func (r *fakeAuthMap) All() (map[authKey]authInfo, error) { 315 if r.failGet { 316 return nil, errors.New("failed to list entries") 317 } 318 319 return r.entries, nil 320 } 321 322 func (r *fakeAuthMap) GetCacheInfo(key authKey) (authInfoCache, error) { 323 v, err := r.Get(key) 324 325 return authInfoCache{ 326 authInfo: v, 327 }, err 328 } 329 330 func (r *fakeAuthMap) Get(key authKey) (authInfo, error) { 331 if r.failGet { 332 return authInfo{}, errors.New("failed to get entry") 333 } 334 335 v, ok := r.entries[key] 336 if !ok { 337 return authInfo{}, errors.New("authinfo not available") 338 } 339 340 return v, nil 341 } 342 343 func (r *fakeAuthMap) Update(key authKey, info authInfo) error { 344 r.entries[authKey{ 345 localIdentity: key.localIdentity, 346 remoteIdentity: key.remoteIdentity, 347 remoteNodeID: key.remoteNodeID, 348 authType: key.authType, 349 }] = authInfo{expiration: info.expiration} 350 return nil 351 } 352 353 func (r *fakeAuthMap) MaxEntries() uint32 { 354 return 1 << 8 355 } 356 357 func assertErrorString(errString string) assert.ErrorAssertionFunc { 358 return func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { 359 return assert.EqualError(t, err, errString, msgAndArgs) 360 } 361 }