github.com/cilium/cilium@v1.16.2/pkg/auth/mutual_authhandler_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package auth
     5  
     6  import (
     7  	"context"
     8  	"crypto/ecdsa"
     9  	"crypto/elliptic"
    10  	"crypto/rand"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"crypto/x509/pkix"
    14  	"fmt"
    15  	"math/big"
    16  	"net"
    17  	"net/url"
    18  	"reflect"
    19  	"strings"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/sirupsen/logrus"
    24  
    25  	"github.com/cilium/cilium/api/v1/models"
    26  	"github.com/cilium/cilium/pkg/auth/certs"
    27  	"github.com/cilium/cilium/pkg/endpoint"
    28  	"github.com/cilium/cilium/pkg/identity"
    29  )
    30  
    31  var (
    32  	id1000 = identity.NumericIdentity(1000)
    33  	id1001 = identity.NumericIdentity(1001)
    34  	idbad1 = identity.NumericIdentity(9999)
    35  )
    36  
    37  type fakeEndpointGetter struct{}
    38  
    39  func (f *fakeEndpointGetter) GetEndpoints() []*endpoint.Endpoint {
    40  	ep := []*endpoint.Endpoint{}
    41  
    42  	for _, id := range []identity.NumericIdentity{id1000, id1001, idbad1} {
    43  		ep = append(ep, &endpoint.Endpoint{
    44  			SecurityIdentity: &identity.Identity{
    45  				ID: id,
    46  			},
    47  		})
    48  	}
    49  
    50  	return ep
    51  }
    52  
    53  type fakeCertificateProvider struct {
    54  	certMap    map[string]*x509.Certificate
    55  	privkeyMap map[string]*ecdsa.PrivateKey
    56  	caPool     *x509.CertPool
    57  }
    58  
    59  func (f *fakeCertificateProvider) GetTrustBundle() (*x509.CertPool, error) {
    60  	return f.caPool, nil
    61  }
    62  
    63  func (f *fakeCertificateProvider) GetCertificateForIdentity(id identity.NumericIdentity) (*tls.Certificate, error) {
    64  	uriSAN := "spiffe://spiffe.cilium/identity/" + id.String()
    65  	cert, ok := f.certMap[uriSAN]
    66  	if !ok {
    67  		return nil, fmt.Errorf("no certificate for %s", uriSAN)
    68  	}
    69  
    70  	// convert the x509 cert to tls cert
    71  	certBytes := cert.Raw
    72  	tlsCert := tls.Certificate{
    73  		Certificate: [][]byte{certBytes},
    74  		PrivateKey:  f.privkeyMap[uriSAN],
    75  		Leaf:        cert,
    76  	}
    77  	return &tlsCert, nil
    78  }
    79  
    80  func (f *fakeCertificateProvider) ValidateIdentity(id identity.NumericIdentity, cert *x509.Certificate) (bool, error) {
    81  	for _, uri := range cert.URIs {
    82  		if uri.String() == fmt.Sprintf("spiffe://spiffe.cilium/identity/%d", id) {
    83  			return true, nil
    84  		}
    85  	}
    86  	return false, nil
    87  }
    88  
    89  func (f *fakeCertificateProvider) NumericIdentityToSNI(id identity.NumericIdentity) string {
    90  	return id.String() + "." + "spiffe.cilium"
    91  }
    92  
    93  func (f *fakeCertificateProvider) SNIToNumericIdentity(sni string) (identity.NumericIdentity, error) {
    94  	suffix := "." + "spiffe.cilium"
    95  	if !strings.HasSuffix(sni, suffix) {
    96  		return 0, fmt.Errorf("SNI %s does not belong to our trust domain", sni)
    97  	}
    98  
    99  	idStr := strings.TrimSuffix(sni, suffix)
   100  	return identity.ParseNumericIdentity(idStr)
   101  }
   102  
   103  func (f *fakeCertificateProvider) SubscribeToRotatedIdentities() <-chan certs.CertificateRotationEvent {
   104  	return nil
   105  }
   106  
   107  func (f *fakeCertificateProvider) Status() *models.Status {
   108  	return nil
   109  }
   110  
   111  func generateTestCertificates(t *testing.T) (map[string]*x509.Certificate, map[string]*ecdsa.PrivateKey, *x509.CertPool) {
   112  	caPool := x509.NewCertPool()
   113  
   114  	caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   115  	if err != nil {
   116  		t.Fatalf("failed to generate CA key: %v", err)
   117  	}
   118  	caCert := &x509.Certificate{
   119  		Subject:               pkix.Name{CommonName: "ca"},
   120  		NotAfter:              time.Now().Add(time.Hour),
   121  		IsCA:                  true,
   122  		KeyUsage:              x509.KeyUsageCertSign,
   123  		SerialNumber:          big.NewInt(1),
   124  		BasicConstraintsValid: true,
   125  	}
   126  	// sign the CA certificate
   127  	caCertBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey)
   128  	if err != nil {
   129  		t.Fatalf("failed to sign CA certificate: %v", err)
   130  	}
   131  	caCert, err = x509.ParseCertificate(caCertBytes)
   132  	if err != nil {
   133  		t.Fatalf("failed to parse CA certificate: %v", err)
   134  	}
   135  	caPool.AddCert(caCert)
   136  
   137  	// sign two SPIFFE like certificates
   138  	leafCerts := make(map[string]*x509.Certificate)
   139  	leafPrivKeys := make(map[string]*ecdsa.PrivateKey)
   140  
   141  	for i := 1000; i <= 1002; i++ {
   142  		certURL, err := url.Parse(fmt.Sprintf("spiffe://spiffe.cilium/identity/%d", i))
   143  		if err != nil {
   144  			t.Fatalf("failed to parse URL: %v", err)
   145  		}
   146  		leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   147  		if err != nil {
   148  			t.Fatalf("failed to generate leaf key: %v", err)
   149  		}
   150  		leafCert := &x509.Certificate{
   151  			NotAfter:     time.Now().Add(time.Hour),
   152  			URIs:         []*url.URL{certURL},
   153  			KeyUsage:     x509.KeyUsageDigitalSignature,
   154  			ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
   155  			SerialNumber: big.NewInt(int64(i)),
   156  		}
   157  		leafCertBytes, err := x509.CreateCertificate(rand.Reader, leafCert, caCert, &leafKey.PublicKey, caKey)
   158  		if err != nil {
   159  			t.Fatalf("failed to sign leaf certificate: %v", err)
   160  		}
   161  		leafCert, err = x509.ParseCertificate(leafCertBytes)
   162  		if err != nil {
   163  			t.Fatalf("failed to parse leaf certificate: %v", err)
   164  		}
   165  		leafCerts[certURL.String()] = leafCert
   166  		leafPrivKeys[certURL.String()] = leafKey
   167  	}
   168  
   169  	return leafCerts, leafPrivKeys, caPool
   170  }
   171  
   172  func Test_mutualAuthHandler_verifyPeerCertificate(t *testing.T) {
   173  	certMap, keyMap, caPool := generateTestCertificates(t)
   174  	certMapOtherCA, _, _ := generateTestCertificates(t)
   175  	type args struct {
   176  		id             *identity.NumericIdentity
   177  		caBundle       *x509.CertPool
   178  		verifiedChains [][]*x509.Certificate
   179  	}
   180  	tests := []struct {
   181  		name    string
   182  		args    args
   183  		want    *time.Time
   184  		wantErr bool
   185  	}{
   186  		{
   187  			name: "valid certificate with SNI to match identity",
   188  			args: args{
   189  				id:             &id1000,
   190  				caBundle:       caPool,
   191  				verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}},
   192  			},
   193  			want:    &certMap["spiffe://spiffe.cilium/identity/1000"].NotAfter,
   194  			wantErr: false,
   195  		},
   196  		{
   197  			name: "valid certificate with no identity provided",
   198  			args: args{
   199  				id:             nil,
   200  				caBundle:       caPool,
   201  				verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}},
   202  			},
   203  			want:    &certMap["spiffe://spiffe.cilium/identity/1000"].NotAfter,
   204  			wantErr: false,
   205  		},
   206  		{
   207  			name: "error on invalid certificate because incorrect identity provided",
   208  			args: args{
   209  				id:             &id1001,
   210  				caBundle:       caPool,
   211  				verifiedChains: [][]*x509.Certificate{{certMap["spiffe://spiffe.cilium/identity/1000"]}},
   212  			},
   213  			want:    nil,
   214  			wantErr: true,
   215  		},
   216  		{
   217  			name: "error on invalid certificate signed by other CA",
   218  			args: args{
   219  				id:             &id1000,
   220  				caBundle:       caPool,
   221  				verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}},
   222  			},
   223  			want:    nil,
   224  			wantErr: true,
   225  		},
   226  		{
   227  			name: "error on invalid certificate signed by other CA with no identity provided",
   228  			args: args{
   229  				id:             nil,
   230  				caBundle:       caPool,
   231  				verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}},
   232  			},
   233  			want:    nil,
   234  			wantErr: true,
   235  		}, {
   236  			name: "error on no certificates in verifiedChains",
   237  			args: args{
   238  				id:             nil,
   239  				caBundle:       caPool,
   240  				verifiedChains: [][]*x509.Certificate{},
   241  			},
   242  			want:    nil,
   243  			wantErr: true,
   244  		},
   245  		{
   246  			name: "error on empty caBundle provided",
   247  			args: args{
   248  				id:             nil,
   249  				caBundle:       x509.NewCertPool(),
   250  				verifiedChains: [][]*x509.Certificate{{certMapOtherCA["spiffe://spiffe.cilium/identity/1000"]}},
   251  			},
   252  			want:    nil,
   253  			wantErr: true,
   254  		},
   255  	}
   256  	for _, tt := range tests {
   257  		t.Run(tt.name, func(t *testing.T) {
   258  			m := &mutualAuthHandler{
   259  				cfg:  MutualAuthConfig{MutualAuthListenerPort: 1234},
   260  				log:  logrus.New(),
   261  				cert: &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap},
   262  			}
   263  			got, err := m.verifyPeerCertificate(tt.args.id, tt.args.caBundle, tt.args.verifiedChains)
   264  			if (err != nil) != tt.wantErr {
   265  				t.Errorf("mutualAuthHandler.verifyPeerCertificate() error = %v, wantErr %v", err, tt.wantErr)
   266  				return
   267  			}
   268  			if !reflect.DeepEqual(got, tt.want) {
   269  				t.Errorf("mutualAuthHandler.verifyPeerCertificate() = %v, want %v", got, tt.want)
   270  			}
   271  		})
   272  	}
   273  }
   274  
   275  func Test_mutualAuthHandler_GetCertificateForIncomingConnection(t *testing.T) {
   276  	certMap, keyMap, caPool := generateTestCertificates(t)
   277  	type args struct {
   278  		info *tls.ClientHelloInfo
   279  	}
   280  	tests := []struct {
   281  		name    string
   282  		args    args
   283  		wantURI string
   284  		wantErr bool
   285  	}{
   286  		{
   287  			name: "valid certificate with SNI to match identity",
   288  			args: args{
   289  				info: &tls.ClientHelloInfo{
   290  					ServerName: "1000.spiffe.cilium",
   291  				},
   292  			},
   293  			wantURI: "spiffe://spiffe.cilium/identity/1000",
   294  			wantErr: false,
   295  		},
   296  		{
   297  			name: "no certificate for non existing endpoint identity",
   298  			args: args{
   299  				info: &tls.ClientHelloInfo{
   300  					ServerName: "1002.spiffe.cilium",
   301  				},
   302  			},
   303  			wantErr: true,
   304  		},
   305  		{
   306  			name: "no certificate for non existing security identity",
   307  			args: args{
   308  				info: &tls.ClientHelloInfo{
   309  					ServerName: "9999.spiffe.cilium",
   310  				},
   311  			},
   312  			wantErr: true,
   313  		},
   314  		{
   315  			name: "no certificate for random non existing domain",
   316  			args: args{
   317  				info: &tls.ClientHelloInfo{
   318  					ServerName: "www.example.com",
   319  				},
   320  			},
   321  			wantErr: true,
   322  		},
   323  	}
   324  	for _, tt := range tests {
   325  		t.Run(tt.name, func(t *testing.T) {
   326  			m := &mutualAuthHandler{
   327  				cfg:             MutualAuthConfig{MutualAuthListenerPort: 1234},
   328  				log:             logrus.New(),
   329  				cert:            &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap},
   330  				endpointManager: &fakeEndpointGetter{},
   331  			}
   332  			got, err := m.GetCertificateForIncomingConnection(tt.args.info)
   333  			if (err != nil) != tt.wantErr {
   334  				t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() error = %v, wantErr %v", err, tt.wantErr)
   335  				return
   336  			}
   337  			if !tt.wantErr {
   338  				if got.Leaf == nil {
   339  					t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() leaf certificate is nil")
   340  				}
   341  				if len(got.Leaf.URIs) == 0 {
   342  					t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() leaf certificate has no URIs")
   343  				}
   344  				gotURI := got.Leaf.URIs[0].String()
   345  				if !reflect.DeepEqual(gotURI, tt.wantURI) {
   346  					t.Errorf("mutualAuthHandler.GetCertificateForIncomingConnection() = %v, want %v", got, tt.wantURI)
   347  				}
   348  			}
   349  
   350  		})
   351  	}
   352  }
   353  
   354  func Test_mutualAuthHandler_authenticate(t *testing.T) {
   355  	certMap, keyMap, caPool := generateTestCertificates(t)
   356  
   357  	mAuthHandler := &mutualAuthHandler{
   358  		cfg:             MutualAuthConfig{MutualAuthListenerPort: getRandomOpenPort(t)},
   359  		log:             logrus.New(),
   360  		cert:            &fakeCertificateProvider{certMap: certMap, caPool: caPool, privkeyMap: keyMap},
   361  		endpointManager: &fakeEndpointGetter{},
   362  	}
   363  	mAuthHandler.onStart(context.Background())
   364  	defer mAuthHandler.onStop(context.Background())
   365  
   366  	var lowestExpirationTime time.Time
   367  	for _, cert := range certMap {
   368  		if lowestExpirationTime.IsZero() || cert.NotAfter.Before(lowestExpirationTime) {
   369  			lowestExpirationTime = cert.NotAfter
   370  		}
   371  	}
   372  
   373  	type args struct {
   374  		ar *authRequest
   375  	}
   376  	tests := []struct {
   377  		name    string
   378  		args    args
   379  		want    *authResponse
   380  		wantErr bool
   381  	}{
   382  		{
   383  			name: "authenticate two valid identities",
   384  			args: args{
   385  				ar: &authRequest{
   386  					localIdentity:  id1000,
   387  					remoteIdentity: id1001,
   388  					remoteNodeIP:   GetLoopBackIP(t),
   389  				},
   390  			},
   391  			want: &authResponse{
   392  				expirationTime: lowestExpirationTime,
   393  			},
   394  		},
   395  		{
   396  			name: "error on authenticate when remote identity is not valid",
   397  			args: args{
   398  				ar: &authRequest{
   399  					localIdentity:  id1000,
   400  					remoteIdentity: idbad1,
   401  					remoteNodeIP:   GetLoopBackIP(t),
   402  				},
   403  			},
   404  			wantErr: true,
   405  		},
   406  		{
   407  			name: "error on  authenticate when local identity is not valid",
   408  			args: args{
   409  				ar: &authRequest{
   410  					localIdentity:  idbad1,
   411  					remoteIdentity: id1001,
   412  					remoteNodeIP:   GetLoopBackIP(t),
   413  				},
   414  			},
   415  			wantErr: true,
   416  		},
   417  		{
   418  			name: "error on authenticate when auth request is bad",
   419  			args: args{
   420  				ar: &authRequest{
   421  					localIdentity: id1000,
   422  					// all other fields are intentionally left blank
   423  				},
   424  			},
   425  			wantErr: true,
   426  		},
   427  	}
   428  	for _, tt := range tests {
   429  		t.Run(tt.name, func(t *testing.T) {
   430  			got, err := mAuthHandler.authenticate(tt.args.ar)
   431  			if (err != nil) != tt.wantErr {
   432  				t.Errorf("mutualAuthHandler.authenticate() error = %v, wantErr %v", err, tt.wantErr)
   433  				return
   434  			}
   435  			if !reflect.DeepEqual(got, tt.want) {
   436  				t.Errorf("mutualAuthHandler.authenticate() = %v, want %v", got, tt.want)
   437  			}
   438  		})
   439  	}
   440  }
   441  
   442  func getRandomOpenPort(t *testing.T) int {
   443  	l, err := net.Listen("tcp", ":0")
   444  	if err != nil {
   445  		t.Fatalf("failed to get random open port: %v", err)
   446  	}
   447  	defer l.Close()
   448  	addr := l.Addr().(*net.TCPAddr)
   449  	return addr.Port
   450  }
   451  
   452  func GetLoopBackIP(t *testing.T) string {
   453  	addrs, err := net.InterfaceAddrs()
   454  	if err != nil {
   455  		t.Fatalf("failed to get interface addresses: %v", err)
   456  	}
   457  	for _, address := range addrs {
   458  		if ipnet, ok := address.(*net.IPNet); ok && ipnet.IP.IsLoopback() {
   459  			return ipnet.IP.String()
   460  		}
   461  	}
   462  
   463  	t.Fatalf("failed to get loopback IP")
   464  	return ""
   465  }