github.com/cilium/cilium@v1.16.2/pkg/auth/manager_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package auth
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"net"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/cilium/ebpf"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/stretchr/testify/assert"
    16  	"golang.org/x/exp/maps"
    17  
    18  	"github.com/cilium/cilium/api/v1/models"
    19  	"github.com/cilium/cilium/pkg/auth/certs"
    20  	"github.com/cilium/cilium/pkg/identity"
    21  	"github.com/cilium/cilium/pkg/policy"
    22  )
    23  
    24  func Test_newAuthManager_clashingAuthHandlers(t *testing.T) {
    25  	authHandlers := []authHandler{
    26  		&alwaysFailAuthHandler{},
    27  		&alwaysFailAuthHandler{},
    28  	}
    29  
    30  	am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
    31  	assert.ErrorContains(t, err, "multiple handlers for auth type: test-always-fail")
    32  	assert.Nil(t, am)
    33  }
    34  
    35  func Test_newAuthManager(t *testing.T) {
    36  	authHandlers := []authHandler{
    37  		newAlwaysPassAuthHandler(logrus.New()),
    38  		&fakeAuthHandler{},
    39  	}
    40  
    41  	am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
    42  	assert.NoError(t, err)
    43  	assert.NotNil(t, am)
    44  
    45  	assert.Len(t, am.authHandlers, 2)
    46  }
    47  
    48  func Test_authManager_authenticate(t *testing.T) {
    49  	tests := []struct {
    50  		name              string
    51  		args              authKey
    52  		wantErr           assert.ErrorAssertionFunc
    53  		wantAuthenticated bool
    54  		wantEntries       int
    55  	}{
    56  		{
    57  			name: "missing handler for auth type",
    58  			args: authKey{
    59  				localIdentity:  1000,
    60  				remoteIdentity: 2000,
    61  				remoteNodeID:   2,
    62  				authType:       1,
    63  			},
    64  			wantErr:     assertErrorString("unknown requested auth type: spire"),
    65  			wantEntries: 0,
    66  		},
    67  		{
    68  			name: "missing node IP for node ID",
    69  			args: authKey{
    70  				localIdentity:  1000,
    71  				remoteIdentity: 2000,
    72  				remoteNodeID:   1,
    73  				authType:       2,
    74  			},
    75  			wantErr:     assertErrorString("remote node IP not available for node ID 1"),
    76  			wantEntries: 0,
    77  		},
    78  		{
    79  			name: "successful auth",
    80  			args: authKey{
    81  				localIdentity:  1000,
    82  				remoteIdentity: 2000,
    83  				remoteNodeID:   2,
    84  				authType:       100,
    85  			},
    86  			wantErr:     assert.NoError,
    87  			wantEntries: 1,
    88  		},
    89  	}
    90  	for _, tt := range tests {
    91  		t.Run(tt.name, func(t *testing.T) {
    92  			authMap := &fakeAuthMap{
    93  				entries: map[authKey]authInfo{},
    94  			}
    95  			am, err := newAuthManager(
    96  				logrus.New(),
    97  				[]authHandler{&alwaysFailAuthHandler{}, newAlwaysPassAuthHandler(logrus.New())},
    98  				authMap,
    99  				newFakeNodeIDHandler(map[uint16]string{
   100  					2: "172.18.0.2",
   101  					3: "172.18.0.3",
   102  				}),
   103  				time.Second,
   104  			)
   105  
   106  			assert.NoError(t, err)
   107  
   108  			err = am.authenticate(tt.args)
   109  			tt.wantErr(t, err)
   110  
   111  			assert.Len(t, authMap.entries, tt.wantEntries)
   112  		})
   113  	}
   114  }
   115  
   116  func Test_authManager_handleAuthRequest(t *testing.T) {
   117  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   118  
   119  	am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
   120  	assert.NoError(t, err)
   121  	assert.NotNil(t, am)
   122  
   123  	handleAuthCalled := false
   124  	am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) {
   125  		handleAuthCalled = true
   126  		assert.False(t, reAuth)
   127  		assert.Equal(t, authKey{localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 0, authType: 100}, k)
   128  	}
   129  
   130  	err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: 1000, RemoteIdentity: 2000, RemoteNodeID: 0, AuthType: 100, Pad: 0})
   131  	assert.NoError(t, err)
   132  	assert.True(t, handleAuthCalled)
   133  }
   134  
   135  func Test_authManager_handleAuthRequest_reservedRemoteIdentity(t *testing.T) {
   136  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   137  
   138  	am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
   139  	assert.NoError(t, err)
   140  	assert.NotNil(t, am)
   141  
   142  	handleAuthCalled := false
   143  	am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) {
   144  		handleAuthCalled = true
   145  	}
   146  
   147  	err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: 100, RemoteIdentity: identity.ReservedIdentityWorldIPv6.Uint32(), RemoteNodeID: 0, AuthType: 100, Pad: 0})
   148  	assert.NoError(t, err)
   149  	assert.False(t, handleAuthCalled)
   150  }
   151  
   152  func Test_authManager_handleAuthRequest_reservedLocalIdentity(t *testing.T) {
   153  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   154  
   155  	am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
   156  	assert.NoError(t, err)
   157  	assert.NotNil(t, am)
   158  
   159  	handleAuthCalled := false
   160  	am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) {
   161  		handleAuthCalled = true
   162  	}
   163  
   164  	err = am.handleAuthRequest(context.Background(), signalAuthKey{LocalIdentity: identity.ReservedIdentityWorldIPv6.Uint32(), RemoteIdentity: 100, RemoteNodeID: 0, AuthType: 100, Pad: 0})
   165  	assert.NoError(t, err)
   166  	assert.False(t, handleAuthCalled)
   167  }
   168  
   169  func Test_authManager_handleCertificateRotationEvent_Error(t *testing.T) {
   170  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   171  	aMap := &fakeAuthMap{
   172  		failGet: true,
   173  	}
   174  
   175  	am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
   176  	assert.NoError(t, err)
   177  	assert.NotNil(t, am)
   178  
   179  	err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{Identity: identity.NumericIdentity(10)})
   180  	assert.ErrorContains(t, err, "failed to get all auth map entries: failed to list entries")
   181  }
   182  
   183  func Test_authManager_handleCertificateRotationEvent(t *testing.T) {
   184  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   185  	aMap := &fakeAuthMap{
   186  		entries: map[authKey]authInfo{
   187  			{localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()},
   188  			{localIdentity: 2000, remoteIdentity: 3000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()},
   189  			{localIdentity: 3000, remoteIdentity: 4000, remoteNodeID: 1, authType: 100}: {expiration: time.Now()},
   190  		},
   191  	}
   192  
   193  	am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
   194  	assert.NoError(t, err)
   195  	assert.NotNil(t, am)
   196  
   197  	handleAuthCalled := false
   198  	am.handleAuthenticationFunc = func(_ *AuthManager, k authKey, reAuth bool) {
   199  		handleAuthCalled = true
   200  		assert.True(t, reAuth)
   201  		assert.True(t, k.localIdentity == 2000 || k.remoteIdentity == 2000)
   202  	}
   203  
   204  	err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{Identity: identity.NumericIdentity(2000)})
   205  	assert.NoError(t, err)
   206  	assert.True(t, handleAuthCalled)
   207  }
   208  
   209  func Test_authManager_handleCertificateDeletionEvent(t *testing.T) {
   210  	authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}
   211  	aMap := &fakeAuthMap{
   212  		entries: map[authKey]authInfo{
   213  			{localIdentity: 1000, remoteIdentity: 2000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()},
   214  			{localIdentity: 2000, remoteIdentity: 3000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()},
   215  			{localIdentity: 3000, remoteIdentity: 4000, remoteNodeID: 1000, authType: 100}: {expiration: time.Now()},
   216  		},
   217  	}
   218  
   219  	am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
   220  	assert.NoError(t, err)
   221  	assert.NotNil(t, am)
   222  
   223  	err = am.handleCertificateRotationEvent(context.Background(), certs.CertificateRotationEvent{
   224  		Identity: identity.NumericIdentity(2000),
   225  		Deleted:  true,
   226  	})
   227  	assert.NoError(t, err)
   228  	assert.Len(t, aMap.entries, 1)
   229  }
   230  
   231  // Fake NodeIDHandler
   232  type fakeNodeIDHandler struct {
   233  	nodeIdMappings map[uint16]string
   234  }
   235  
   236  func (r *fakeNodeIDHandler) DumpNodeIDs() []*models.NodeID {
   237  	return []*models.NodeID{}
   238  }
   239  
   240  func (r *fakeNodeIDHandler) RestoreNodeIDs() {
   241  }
   242  
   243  func newFakeNodeIDHandler(mappings map[uint16]string) *fakeNodeIDHandler {
   244  	return &fakeNodeIDHandler{
   245  		nodeIdMappings: mappings,
   246  	}
   247  }
   248  
   249  func (r *fakeNodeIDHandler) GetNodeIP(id uint16) string {
   250  	return r.nodeIdMappings[id]
   251  }
   252  
   253  func (r *fakeNodeIDHandler) GetNodeID(nodeIP net.IP) (uint16, bool) {
   254  	for id, ip := range r.nodeIdMappings {
   255  		if ip == nodeIP.String() {
   256  			return id, true
   257  		}
   258  	}
   259  
   260  	return 0, false
   261  }
   262  
   263  // Fake AuthHandler
   264  type fakeAuthHandler struct {
   265  }
   266  
   267  func (r *fakeAuthHandler) authenticate(authReq *authRequest) (*authResponse, error) {
   268  
   269  	return &authResponse{}, nil
   270  }
   271  
   272  func (r *fakeAuthHandler) authType() policy.AuthType {
   273  	return policy.AuthType(255)
   274  }
   275  
   276  func (r *fakeAuthHandler) subscribeToRotatedIdentities() <-chan certs.CertificateRotationEvent {
   277  	return nil
   278  }
   279  
   280  func (r *fakeAuthHandler) certProviderStatus() *models.Status {
   281  	return nil
   282  }
   283  
   284  // Fake AuthMap
   285  type fakeAuthMap struct {
   286  	entries    map[authKey]authInfo
   287  	failDelete bool
   288  	failGet    bool
   289  }
   290  
   291  func (r *fakeAuthMap) Delete(key authKey) error {
   292  	if r.failDelete {
   293  		return errors.New("failed to delete entry")
   294  	}
   295  
   296  	if _, ok := r.entries[key]; !ok {
   297  		return ebpf.ErrKeyNotExist
   298  	}
   299  
   300  	delete(r.entries, key)
   301  	return nil
   302  }
   303  
   304  func (r *fakeAuthMap) DeleteIf(predicate func(key authKey, info authInfo) bool) error {
   305  	if r.failDelete {
   306  		return errors.New("failed to delete entry")
   307  	}
   308  
   309  	maps.DeleteFunc(r.entries, predicate)
   310  
   311  	return nil
   312  }
   313  
   314  func (r *fakeAuthMap) All() (map[authKey]authInfo, error) {
   315  	if r.failGet {
   316  		return nil, errors.New("failed to list entries")
   317  	}
   318  
   319  	return r.entries, nil
   320  }
   321  
   322  func (r *fakeAuthMap) GetCacheInfo(key authKey) (authInfoCache, error) {
   323  	v, err := r.Get(key)
   324  
   325  	return authInfoCache{
   326  		authInfo: v,
   327  	}, err
   328  }
   329  
   330  func (r *fakeAuthMap) Get(key authKey) (authInfo, error) {
   331  	if r.failGet {
   332  		return authInfo{}, errors.New("failed to get entry")
   333  	}
   334  
   335  	v, ok := r.entries[key]
   336  	if !ok {
   337  		return authInfo{}, errors.New("authinfo not available")
   338  	}
   339  
   340  	return v, nil
   341  }
   342  
   343  func (r *fakeAuthMap) Update(key authKey, info authInfo) error {
   344  	r.entries[authKey{
   345  		localIdentity:  key.localIdentity,
   346  		remoteIdentity: key.remoteIdentity,
   347  		remoteNodeID:   key.remoteNodeID,
   348  		authType:       key.authType,
   349  	}] = authInfo{expiration: info.expiration}
   350  	return nil
   351  }
   352  
   353  func (r *fakeAuthMap) MaxEntries() uint32 {
   354  	return 1 << 8
   355  }
   356  
   357  func assertErrorString(errString string) assert.ErrorAssertionFunc {
   358  	return func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool {
   359  		return assert.EqualError(t, err, errString, msgAndArgs)
   360  	}
   361  }