github.com/cilium/cilium@v1.16.2/pkg/auth/mutual_authhandler_test.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package auth 5 6 import ( 7 "context" 8 "crypto/ecdsa" 9 "crypto/elliptic" 10 "crypto/rand" 11 "crypto/tls" 12 "crypto/x509" 13 "crypto/x509/pkix" 14 "fmt" 15 "math/big" 16 "net" 17 "net/url" 18 "reflect" 19 "strings" 20 "testing" 21 "time" 22 23 "github.com/sirupsen/logrus" 24 25 "github.com/cilium/cilium/api/v1/models" 26 "github.com/cilium/cilium/pkg/auth/certs" 27 "github.com/cilium/cilium/pkg/endpoint" 28 "github.com/cilium/cilium/pkg/identity" 29 ) 30 31 var ( 32 id1000 = identity.NumericIdentity(1000) 33 id1001 = identity.NumericIdentity(1001) 34 idbad1 = identity.NumericIdentity(9999) 35 ) 36 37 type fakeEndpointGetter struct{} 38 39 func (f *fakeEndpointGetter) GetEndpoints() []*endpoint.Endpoint { 40 ep := []*endpoint.Endpoint{} 41 42 for _, id := range []identity.NumericIdentity{id1000, id1001, idbad1} { 43 ep = append(ep, &endpoint.Endpoint{ 44 SecurityIdentity: &identity.Identity{ 45 ID: id, 46 }, 47 }) 48 } 49 50 return ep 51 } 52 53 type fakeCertificateProvider struct { 54 certMap map[string]*x509.Certificate 55 privkeyMap map[string]*ecdsa.PrivateKey 56 caPool *x509.CertPool 57 } 58 59 func (f *fakeCertificateProvider) GetTrustBundle() (*x509.CertPool, error) { 60 return f.caPool, nil 61 } 62 63 func (f *fakeCertificateProvider) GetCertificateForIdentity(id identity.NumericIdentity) (*tls.Certificate, error) { 64 uriSAN := "spiffe://spiffe.cilium/identity/" + id.String() 65 cert, ok := f.certMap[uriSAN] 66 if !ok { 67 return nil, fmt.Errorf("no certificate for %s", uriSAN) 68 } 69 70 // convert the x509 cert to tls cert 71 certBytes := cert.Raw 72 tlsCert := tls.Certificate{ 73 Certificate: [][]byte{certBytes}, 74 PrivateKey: f.privkeyMap[uriSAN], 75 Leaf: cert, 76 } 77 return &tlsCert, nil 78 } 79 80 func (f *fakeCertificateProvider) ValidateIdentity(id identity.NumericIdentity, cert *x509.Certificate) (bool, error) { 81 for _, uri := range cert.URIs { 82 if uri.String() == fmt.Sprintf("spiffe://spiffe.cilium/identity/%d", id) { 83 return true, nil 84 } 85 } 86 return false, nil 87 } 88 89 func (f *fakeCertificateProvider) NumericIdentityToSNI(id identity.NumericIdentity) string { 90 return id.String() + "." + "spiffe.cilium" 91 } 92 93 func (f *fakeCertificateProvider) SNIToNumericIdentity(sni string) (identity.NumericIdentity, error) { 94 suffix := "." + "spiffe.cilium" 95 if !strings.HasSuffix(sni, suffix) { 96 return 0, fmt.Errorf("SNI %s does not belong to our trust domain", sni) 97 } 98 99 idStr := strings.TrimSuffix(sni, suffix) 100 return identity.ParseNumericIdentity(idStr) 101 } 102 103 func (f *fakeCertificateProvider) SubscribeToRotatedIdentities() <-chan certs.CertificateRotationEvent { 104 return nil 105 } 106 107 func (f *fakeCertificateProvider) Status() *models.Status { 108 return nil 109 } 110 111 func generateTestCertificates(t *testing.T) (map[string]*x509.Certificate, map[string]*ecdsa.PrivateKey, *x509.CertPool) { 112 caPool := x509.NewCertPool() 113 114 caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 115 if err != nil { 116 t.Fatalf("failed to generate CA key: %v", err) 117 } 118 caCert := &x509.Certificate{ 119 Subject: pkix.Name{CommonName: "ca"}, 120 NotAfter: time.Now().Add(time.Hour), 121 IsCA: true, 122 KeyUsage: x509.KeyUsageCertSign, 123 SerialNumber: big.NewInt(1), 124 BasicConstraintsValid: true, 125 } 126 // sign the CA certificate 127 caCertBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey) 128 if err != nil { 129 t.Fatalf("failed to sign CA certificate: %v", err) 130 } 131 caCert, err = x509.ParseCertificate(caCertBytes) 132 if err != nil { 133 t.Fatalf("failed to parse CA certificate: %v", err) 134 } 135 caPool.AddCert(caCert) 136 137 // sign two SPIFFE like certificates 138 leafCerts := make(map[string]*x509.Certificate) 139 leafPrivKeys := make(map[string]*ecdsa.PrivateKey) 140 141 for i := 1000; i <= 1002; i++ { 142 certURL, err := url.Parse(fmt.Sprintf("spiffe://spiffe.cilium/identity/%d", i)) 143 if err != nil { 144 t.Fatalf("failed to parse URL: %v", err) 145 } 146 leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 147 if err != nil { 148 t.Fatalf("failed to generate leaf key: %v", err) 149 } 150 leafCert := &x509.Certificate{ 151 NotAfter: time.Now().Add(time.Hour), 152 URIs: []*url.URL{certURL}, 153 KeyUsage: x509.KeyUsageDigitalSignature, 154 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, 155 SerialNumber: big.NewInt(int64(i)), 156 } 157 leafCertBytes, err := x509.CreateCertificate(rand.Reader, leafCert, caCert, &leafKey.PublicKey, caKey) 158 if err != nil { 159 t.Fatalf("failed to sign leaf certificate: %v", err) 160 } 161 leafCert, err = x509.ParseCertificate(leafCertBytes) 162 if err != nil { 163 t.Fatalf("failed to parse leaf certificate: %v", err) 164 } 165 leafCerts[certURL.String()] = leafCert 166 leafPrivKeys[certURL.String()] = leafKey 167 } 168 169 return leafCerts, leafPrivKeys, caPool 170 } 171 172 func Test_mutualAuthHandler_verifyPeerCertificate(t *testing.T) { 173 certMap, keyMap, caPool := generateTestCertificates(t) 174 certMapOtherCA, _, _ := generateTestCertificates(t) 175 type args struct { 176 id *identity.NumericIdentity 177 caBundle *x509.CertPool 178 verifiedChains [][]*x509.Certificate 179 } 180 tests := []struct { 181 name string 182 args args 183 want *time.Time 184 wantErr bool 185 }{ 186 { 187 name: "valid certificate with SNI to match identity", 188 args: args{ 189 id: &id1000, 190 caBundle: caPool, 191 verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}}, 192 }, 193 want: &certMap["spiffe://spiffe.cilium/identity/1000"].NotAfter, 194 wantErr: false, 195 }, 196 { 197 name: "valid certificate with no identity provided", 198 args: args{ 199 id: nil, 200 caBundle: caPool, 201 verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}}, 202 }, 203 want: &certMap["spiffe://spiffe.cilium/identity/1000"].NotAfter, 204 wantErr: false, 205 }, 206 { 207 name: "error on invalid certificate because incorrect identity provided", 208 args: args{ 209 id: &id1001, 210 caBundle: caPool, 211 verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}}, 212 }, 213 want: nil, 214 wantErr: true, 215 }, 216 { 217 name: "error on invalid certificate signed by other CA", 218 args: args{ 219 id: &id1000, 220 caBundle: caPool, 221 verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}}, 222 }, 223 want: nil, 224 wantErr: true, 225 }, 226 { 227 name: "error on invalid certificate signed by other CA with no identity provided", 228 args: args{ 229 id: nil, 230 caBundle: caPool, 231 verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}}, 232 }, 233 want: nil, 234 wantErr: true, 235 }, { 236 name: "error on no certificates in verifiedChains", 237 args: args{ 238 id: nil, 239 caBundle: caPool, 240 verifiedChains: [][]*x509.Certificate{}, 241 }, 242 want: nil, 243 wantErr: true, 244 }, 245 { 246 name: "error on empty caBundle provided", 247 args: args{ 248 id: nil, 249 caBundle: x509.NewCertPool(), 250 verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}}, 251 }, 252 want: nil, 253 wantErr: true, 254 }, 255 } 256 for _, tt := range tests { 257 t.Run(tt.name, func(t *testing.T) { 258 m := &mutualAuthHandler{ 259 cfg: MutualAuthConfig{MutualAuthListenerPort: 1234}, 260 log: logrus.New(), 261 cert: &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap}, 262 } 263 got, err := m.verifyPeerCertificate(tt.args.id, tt.args.caBundle, tt.args.verifiedChains) 264 if (err != nil) != tt.wantErr { 265 t.Errorf("mutualAuthHandler.verifyPeerCertificate() error = %v, wantErr %v", err, tt.wantErr) 266 return 267 } 268 if !reflect.DeepEqual(got, tt.want) { 269 t.Errorf("mutualAuthHandler.verifyPeerCertificate() = %v, want %v", got, tt.want) 270 } 271 }) 272 } 273 } 274 275 func Test_mutualAuthHandler_GetCertificateForIncomingConnection(t *testing.T) { 276 certMap, keyMap, caPool := generateTestCertificates(t) 277 type args struct { 278 info *tls.ClientHelloInfo 279 } 280 tests := []struct { 281 name string 282 args args 283 wantURI string 284 wantErr bool 285 }{ 286 { 287 name: "valid certificate with SNI to match identity", 288 args: args{ 289 info: &tls.ClientHelloInfo{ 290 ServerName: "1000.spiffe.cilium", 291 }, 292 }, 293 wantURI: "spiffe://spiffe.cilium/identity/1000", 294 wantErr: false, 295 }, 296 { 297 name: "no certificate for non existing endpoint identity", 298 args: args{ 299 info: &tls.ClientHelloInfo{ 300 ServerName: "1002.spiffe.cilium", 301 }, 302 }, 303 wantErr: true, 304 }, 305 { 306 name: "no certificate for non existing security identity", 307 args: args{ 308 info: &tls.ClientHelloInfo{ 309 ServerName: "9999.spiffe.cilium", 310 }, 311 }, 312 wantErr: true, 313 }, 314 { 315 name: "no certificate for random non existing domain", 316 args: args{ 317 info: &tls.ClientHelloInfo{ 318 ServerName: "www.example.com", 319 }, 320 }, 321 wantErr: true, 322 }, 323 } 324 for _, tt := range tests { 325 t.Run(tt.name, func(t *testing.T) { 326 m := &mutualAuthHandler{ 327 cfg: MutualAuthConfig{MutualAuthListenerPort: 1234}, 328 log: logrus.New(), 329 cert: &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap}, 330 endpointManager: &fakeEndpointGetter{}, 331 } 332 got, err := m.GetCertificateForIncomingConnection(tt.args.info) 333 if (err != nil) != tt.wantErr { 334 t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() error = %v, wantErr %v", err, tt.wantErr) 335 return 336 } 337 if !tt.wantErr { 338 if got.Leaf == nil { 339 t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() leaf certificate is nil") 340 } 341 if len(got.Leaf.URIs) == 0 { 342 t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() leaf certificate has no URIs") 343 } 344 gotURI := got.Leaf.URIs[0].String() 345 if !reflect.DeepEqual(gotURI, tt.wantURI) { 346 t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() = %v, want %v", got, tt.wantURI) 347 } 348 } 349 350 }) 351 } 352 } 353 354 func Test_mutualAuthHandler_authenticate(t *testing.T) { 355 certMap, keyMap, caPool := generateTestCertificates(t) 356 357 mAuthHandler := &mutualAuthHandler{ 358 cfg: MutualAuthConfig{MutualAuthListenerPort: getRandomOpenPort(t)}, 359 log: logrus.New(), 360 cert: &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap}, 361 endpointManager: &fakeEndpointGetter{}, 362 } 363 mAuthHandler.onStart(context.Background()) 364 defer mAuthHandler.onStop(context.Background()) 365 366 var lowestExpirationTime time.Time 367 for _, cert := range certMap { 368 if lowestExpirationTime.IsZero() || cert.NotAfter.Before(lowestExpirationTime) { 369 lowestExpirationTime = cert.NotAfter 370 } 371 } 372 373 type args struct { 374 ar *authRequest 375 } 376 tests := []struct { 377 name string 378 args args 379 want *authResponse 380 wantErr bool 381 }{ 382 { 383 name: "authenticate two valid identities", 384 args: args{ 385 ar: &authRequest{ 386 localIdentity: id1000, 387 remoteIdentity: id1001, 388 remoteNodeIP: GetLoopBackIP(t), 389 }, 390 }, 391 want: &authResponse{ 392 expirationTime: lowestExpirationTime, 393 }, 394 }, 395 { 396 name: "error on authenticate when remote identity is not valid", 397 args: args{ 398 ar: &authRequest{ 399 localIdentity: id1000, 400 remoteIdentity: idbad1, 401 remoteNodeIP: GetLoopBackIP(t), 402 }, 403 }, 404 wantErr: true, 405 }, 406 { 407 name: "error on authenticate when local identity is not valid", 408 args: args{ 409 ar: &authRequest{ 410 localIdentity: idbad1, 411 remoteIdentity: id1001, 412 remoteNodeIP: GetLoopBackIP(t), 413 }, 414 }, 415 wantErr: true, 416 }, 417 { 418 name: "error on authenticate when auth request is bad", 419 args: args{ 420 ar: &authRequest{ 421 localIdentity: id1000, 422 // all other fields are intentionally left blank 423 }, 424 }, 425 wantErr: true, 426 }, 427 } 428 for _, tt := range tests { 429 t.Run(tt.name, func(t *testing.T) { 430 got, err := mAuthHandler.authenticate(tt.args.ar) 431 if (err != nil) != tt.wantErr { 432 t.Errorf("mutualAuthHandler.authenticate() error = %v, wantErr %v", err, tt.wantErr) 433 return 434 } 435 if !reflect.DeepEqual(got, tt.want) { 436 t.Errorf("mutualAuthHandler.authenticate() = %v, want %v", got, tt.want) 437 } 438 }) 439 } 440 } 441 442 func getRandomOpenPort(t *testing.T) int { 443 l, err := net.Listen("tcp", ":0") 444 if err != nil { 445 t.Fatalf("failed to get random open port: %v", err) 446 } 447 defer l.Close() 448 addr := l.Addr().(*net.TCPAddr) 449 return addr.Port 450 } 451 452 func GetLoopBackIP(t *testing.T) string { 453 addrs, err := net.InterfaceAddrs() 454 if err != nil { 455 t.Fatalf("failed to get interface addresses: %v", err) 456 } 457 for _, address := range addrs { 458 if ipnet, ok := address.(*net.IPNet); ok && ipnet.IP.IsLoopback() { 459 return ipnet.IP.String() 460 } 461 } 462 463 t.Fatalf("failed to get loopback IP") 464 return "" 465 }