github.com/Azure/aad-pod-identity@v1.8.17/pkg/mic/mic_test.go (about)

     1  package mic
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	internalaadpodid "github.com/Azure/aad-pod-identity/pkg/apis/aadpodidentity"
    13  	aadpodid "github.com/Azure/aad-pod-identity/pkg/apis/aadpodidentity/v1"
    14  	cp "github.com/Azure/aad-pod-identity/pkg/cloudprovider"
    15  	"github.com/Azure/aad-pod-identity/pkg/config"
    16  	"github.com/Azure/aad-pod-identity/pkg/crd"
    17  	"github.com/Azure/aad-pod-identity/pkg/metrics"
    18  	"github.com/Azure/aad-pod-identity/pkg/retry"
    19  
    20  	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
    21  	"github.com/stretchr/testify/assert"
    22  	api "k8s.io/api/core/v1"
    23  	corev1 "k8s.io/api/core/v1"
    24  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    25  	"k8s.io/apimachinery/pkg/runtime"
    26  	"k8s.io/client-go/rest"
    27  	"k8s.io/client-go/tools/cache"
    28  	"k8s.io/klog/v2"
    29  )
    30  
    31  var (
    32  	testResourceID = "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/identity1"
    33  )
    34  
    35  /****************** CLOUD PROVIDER MOCK ****************************/
// TestCloudClient is a fake cloud provider client. It embeds a real *cp.Client
// (wired to the fake VM/VMSS clients by NewTestCloudClient) and keeps typed
// handles to those fakes so tests can inject errors and inspect state.
type TestCloudClient struct {
	*cp.Client
	// testVMClient and testVMSSClient are the same fakes installed in the
	// embedded Client; tests use them directly for validation.
	testVMClient   *TestVMClient
	testVMSSClient *TestVMSSClient
}
    42  
// TestVMClient is an in-memory fake of the VM client.
//
// nodeMap holds the last VM stored per node name. nodeIDs holds the set of
// user-assigned identity IDs currently recorded per node. err, when non-nil,
// is returned (once) by the next UpdateIdentities call. mu is held by Get,
// UpdateIdentities, ListMSI and CompareMSI while they touch these fields.
//
// NOTE(review): identity is initialised by NewTestVMClient but is not read by
// any code visible in this file.
type TestVMClient struct {
	*cp.VMClient

	mu       sync.Mutex
	nodeMap  map[string]*compute.VirtualMachine
	nodeIDs  map[string]map[string]bool
	err      *error
	identity *compute.VirtualMachineIdentity
}
    52  
    53  func (c *TestVMClient) SetError(err error) {
    54  	c.err = &err
    55  }
    56  
    57  func (c *TestVMClient) UnsetError() {
    58  	c.err = nil
    59  }
    60  
    61  func (c *TestVMClient) Get(rgName string, nodeName string) (compute.VirtualMachine, error) {
    62  	c.mu.Lock()
    63  	defer c.mu.Unlock()
    64  
    65  	stored := c.nodeMap[nodeName]
    66  	if stored == nil {
    67  		vm := new(compute.VirtualMachine)
    68  		c.nodeMap[nodeName] = vm
    69  		c.nodeIDs[nodeName] = make(map[string]bool)
    70  		vm.Identity = &compute.VirtualMachineIdentity{}
    71  		return *vm, nil
    72  	}
    73  
    74  	storedIDs := c.nodeIDs[nodeName]
    75  	newVMIdentity := make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue)
    76  	for id := range storedIDs {
    77  		newVMIdentity[id] = &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{}
    78  	}
    79  	stored.Identity.UserAssignedIdentities = newVMIdentity
    80  	return *stored, nil
    81  }
    82  
    83  func (c *TestVMClient) UpdateIdentities(rg, nodeName string, vm compute.VirtualMachine) error {
    84  	c.mu.Lock()
    85  	defer c.mu.Unlock()
    86  
    87  	if c.err != nil {
    88  		defer c.UnsetError()
    89  		return *c.err
    90  	}
    91  	if vm.Identity != nil && vm.Identity.UserAssignedIdentities != nil {
    92  		for k, v := range vm.Identity.UserAssignedIdentities {
    93  			if v == nil {
    94  				delete(c.nodeIDs[nodeName], k)
    95  			} else {
    96  				c.nodeIDs[nodeName][k] = true
    97  			}
    98  		}
    99  	}
   100  	if vm.Identity != nil && vm.Identity.UserAssignedIdentities == nil {
   101  		for k := range c.nodeIDs[nodeName] {
   102  			delete(c.nodeIDs[nodeName], k)
   103  		}
   104  	}
   105  	c.nodeMap[nodeName] = &vm
   106  	return nil
   107  }
   108  
   109  func (c *TestVMClient) ListMSI() map[string]*[]string {
   110  	c.mu.Lock()
   111  	defer c.mu.Unlock()
   112  
   113  	ret := make(map[string]*[]string)
   114  
   115  	for key, val := range c.nodeMap {
   116  		var ids []string
   117  		for k := range val.Identity.UserAssignedIdentities {
   118  			ids = append(ids, k)
   119  		}
   120  		ret[key] = &ids
   121  	}
   122  	return ret
   123  }
   124  
   125  func (c *TestVMClient) CompareMSI(nodeName string, expectedUserIDs []string) bool {
   126  	c.mu.Lock()
   127  	defer c.mu.Unlock()
   128  
   129  	stored := c.nodeMap[nodeName]
   130  	if stored == nil || stored.Identity == nil {
   131  		return false
   132  	}
   133  
   134  	var actualUserIDs []string
   135  	for k := range c.nodeIDs[nodeName] {
   136  		actualUserIDs = append(actualUserIDs, k)
   137  	}
   138  	if actualUserIDs == nil {
   139  		if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone { // Validate that we have reset the resource type as none.
   140  			return true
   141  		}
   142  		return false
   143  	}
   144  
   145  	sort.Strings(actualUserIDs)
   146  	sort.Strings(expectedUserIDs)
   147  	for i := range actualUserIDs {
   148  		if !strings.EqualFold(actualUserIDs[i], expectedUserIDs[i]) {
   149  			return false
   150  		}
   151  	}
   152  
   153  	return true
   154  }
   155  
// TestVMSSClient is an in-memory fake of the VMSS client.
//
// nodeMap holds the last scale set stored per node name. nodeIDs holds the set
// of user-assigned identity IDs currently recorded per node. err, when non-nil,
// is returned (once) by the next UpdateIdentities call. mu is held by Get,
// UpdateIdentities and CompareMSI while they touch these fields.
//
// NOTE(review): identity is initialised by NewTestVMSSClient but is not read
// by any code visible in this file.
type TestVMSSClient struct {
	*cp.VMSSClient

	mu       sync.Mutex
	nodeMap  map[string]*compute.VirtualMachineScaleSet
	nodeIDs  map[string]map[string]bool
	err      *error
	identity *compute.VirtualMachineScaleSetIdentity
}
   165  
   166  func (c *TestVMSSClient) SetError(err error) {
   167  	c.err = &err
   168  }
   169  
   170  func (c *TestVMSSClient) UnsetError() {
   171  	c.err = nil
   172  }
   173  
   174  func (c *TestVMSSClient) Get(rgName string, nodeName string) (compute.VirtualMachineScaleSet, error) {
   175  	c.mu.Lock()
   176  	defer c.mu.Unlock()
   177  
   178  	stored := c.nodeMap[nodeName]
   179  	if stored == nil {
   180  		vmss := new(compute.VirtualMachineScaleSet)
   181  		c.nodeMap[nodeName] = vmss
   182  		c.nodeIDs[nodeName] = make(map[string]bool)
   183  		vmss.Identity = &compute.VirtualMachineScaleSetIdentity{}
   184  		return *vmss, nil
   185  	}
   186  
   187  	storedIDs := c.nodeIDs[nodeName]
   188  	newVMSSIdentity := make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue)
   189  	for id := range storedIDs {
   190  		newVMSSIdentity[id] = &compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{}
   191  	}
   192  	stored.Identity.UserAssignedIdentities = newVMSSIdentity
   193  	return *stored, nil
   194  }
   195  
   196  func (c *TestVMSSClient) UpdateIdentities(rg, nodeName string, vmss compute.VirtualMachineScaleSet) error {
   197  	c.mu.Lock()
   198  	defer c.mu.Unlock()
   199  
   200  	if c.err != nil {
   201  		defer c.UnsetError()
   202  		return *c.err
   203  	}
   204  	if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities != nil {
   205  		for k, v := range vmss.Identity.UserAssignedIdentities {
   206  			if v == nil {
   207  				delete(c.nodeIDs[nodeName], k)
   208  			} else {
   209  				c.nodeIDs[nodeName][k] = true
   210  			}
   211  		}
   212  	}
   213  	if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities == nil {
   214  		for k := range c.nodeIDs[nodeName] {
   215  			delete(c.nodeIDs[nodeName], k)
   216  		}
   217  	}
   218  
   219  	c.nodeMap[nodeName] = &vmss
   220  	return nil
   221  }
   222  
   223  func (c *TestVMSSClient) ListMSI() map[string]*[]string {
   224  	ret := make(map[string]*[]string)
   225  
   226  	for key, val := range c.nodeMap {
   227  		var ids []string
   228  		for k := range val.Identity.UserAssignedIdentities {
   229  			ids = append(ids, k)
   230  		}
   231  		ret[key] = &ids
   232  	}
   233  	return ret
   234  }
   235  
   236  func (c *TestVMSSClient) CompareMSI(nodeName string, expectedUserIDs []string) bool {
   237  	c.mu.Lock()
   238  	defer c.mu.Unlock()
   239  
   240  	stored := c.nodeMap[nodeName]
   241  	if stored == nil || stored.Identity == nil {
   242  		return false
   243  	}
   244  
   245  	var actualUserIDs []string
   246  	for k := range c.nodeIDs[nodeName] {
   247  		actualUserIDs = append(actualUserIDs, k)
   248  	}
   249  
   250  	if actualUserIDs == nil {
   251  		if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone { // Validate that we have reset the resource type as none.
   252  			return true
   253  		}
   254  		return false
   255  	}
   256  
   257  	sort.Strings(actualUserIDs)
   258  	sort.Strings(expectedUserIDs)
   259  	for i := range actualUserIDs {
   260  		if !strings.EqualFold(actualUserIDs[i], expectedUserIDs[i]) {
   261  			return false
   262  		}
   263  	}
   264  
   265  	return true
   266  }
   267  
   268  func (c *TestCloudClient) GetUserMSIs(name string, isvmss bool) ([]string, error) {
   269  	var ret []string
   270  	if isvmss {
   271  		vmss, _ := c.testVMSSClient.Get("", name)
   272  		for id := range vmss.Identity.UserAssignedIdentities {
   273  			ret = append(ret, id)
   274  		}
   275  	} else {
   276  		vm, _ := c.testVMClient.Get("", name)
   277  		for id := range vm.Identity.UserAssignedIdentities {
   278  			ret = append(ret, id)
   279  		}
   280  	}
   281  
   282  	return ret, nil
   283  }
   284  
   285  func (c *TestCloudClient) ListMSI() map[string]*[]string {
   286  	if c.Client.Config.VMType == "vmss" {
   287  		return c.testVMSSClient.ListMSI()
   288  	}
   289  	return c.testVMClient.ListMSI()
   290  }
   291  
   292  func (c *TestCloudClient) CompareMSI(nodeName string, userIDs []string) bool {
   293  	if c.Client.Config.VMType == "vmss" {
   294  		return c.testVMSSClient.CompareMSI(nodeName, userIDs)
   295  	}
   296  	return c.testVMClient.CompareMSI(nodeName, userIDs)
   297  }
   298  
   299  func (c *TestCloudClient) PrintMSI() {
   300  	for key, val := range c.ListMSI() {
   301  		klog.Infof("\nnode name: %s", key)
   302  		if val != nil {
   303  			for i, id := range *val {
   304  				klog.Infof("%d) %s", i, id)
   305  			}
   306  		}
   307  	}
   308  }
   309  
   310  func (c *TestCloudClient) SetError(err error) {
   311  	c.testVMClient.SetError(err)
   312  	c.testVMSSClient.SetError(err)
   313  }
   314  
   315  func (c *TestCloudClient) UnsetError() {
   316  	c.testVMClient.UnsetError()
   317  	c.testVMSSClient.UnsetError()
   318  }
   319  
   320  func NewTestVMClient() *TestVMClient {
   321  	nodeMap := make(map[string]*compute.VirtualMachine)
   322  	nodeIDs := make(map[string]map[string]bool)
   323  	vmClient := &cp.VMClient{}
   324  	identity := &compute.VirtualMachineIdentity{
   325  		UserAssignedIdentities: make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue),
   326  	}
   327  
   328  	return &TestVMClient{
   329  		VMClient: vmClient,
   330  		nodeMap:  nodeMap,
   331  		nodeIDs:  nodeIDs,
   332  		identity: identity,
   333  	}
   334  }
   335  
   336  func NewTestVMSSClient() *TestVMSSClient {
   337  	nodeMap := make(map[string]*compute.VirtualMachineScaleSet)
   338  	nodeIDs := make(map[string]map[string]bool)
   339  	vmssClient := &cp.VMSSClient{}
   340  	identity := &compute.VirtualMachineScaleSetIdentity{
   341  		UserAssignedIdentities: make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue),
   342  	}
   343  
   344  	return &TestVMSSClient{
   345  		VMSSClient: vmssClient,
   346  		nodeMap:    nodeMap,
   347  		nodeIDs:    nodeIDs,
   348  		identity:   identity,
   349  	}
   350  }
   351  
   352  func NewTestCloudClient(cfg config.AzureConfig) *TestCloudClient {
   353  	vmClient := NewTestVMClient()
   354  	vmssClient := NewTestVMSSClient()
   355  	retryClient := retry.NewRetryClient(2, 0)
   356  	cloudClient := &cp.Client{
   357  		Config:      cfg,
   358  		VMClient:    vmClient,
   359  		VMSSClient:  vmssClient,
   360  		RetryClient: retryClient,
   361  	}
   362  
   363  	return &TestCloudClient{
   364  		cloudClient,
   365  		vmClient,
   366  		vmssClient,
   367  	}
   368  }
   369  
   370  /****************** POD MOCK ****************************/
// TestPodClient is an in-memory fake pod client. mu guards pods, which holds
// the pods added with AddPod and not yet removed with DeletePod.
type TestPodClient struct {
	mu   sync.Mutex
	pods []*corev1.Pod
}
   375  
   376  func NewTestPodClient() *TestPodClient {
   377  	var pods []*corev1.Pod
   378  	return &TestPodClient{
   379  		pods: pods,
   380  	}
   381  }
   382  
   383  func (c *TestPodClient) Start(exit <-chan struct{}) {
   384  	klog.Info("start called from the test interface")
   385  }
   386  
   387  func (c *TestPodClient) GetPods() ([]*corev1.Pod, error) {
   388  	// TODO: Add label matching. For now we add only pods which we want to add.
   389  	c.mu.Lock()
   390  	defer c.mu.Unlock()
   391  
   392  	pods := make([]*corev1.Pod, len(c.pods))
   393  	copy(pods, c.pods)
   394  
   395  	return pods, nil
   396  }
   397  
   398  func (c *TestPodClient) AddPod(podName, podNs, nodeName, binding string) {
   399  	labels := make(map[string]string)
   400  	labels[aadpodid.CRDLabelKey] = binding
   401  	pod := &corev1.Pod{
   402  		ObjectMeta: metav1.ObjectMeta{
   403  			Name:      podName,
   404  			Namespace: podNs,
   405  			Labels:    labels,
   406  		},
   407  		Spec: corev1.PodSpec{
   408  			NodeName: nodeName,
   409  		},
   410  	}
   411  
   412  	c.mu.Lock()
   413  	defer c.mu.Unlock()
   414  	c.pods = append(c.pods, pod)
   415  }
   416  
   417  func (c *TestPodClient) DeletePod(podName, podNs string) {
   418  	var newPods []*corev1.Pod
   419  	changed := false
   420  
   421  	c.mu.Lock()
   422  	defer c.mu.Unlock()
   423  
   424  	for _, pod := range c.pods {
   425  		if pod.Name == podName && pod.Namespace == podNs {
   426  			changed = true
   427  			continue
   428  		} else {
   429  			newPods = append(newPods, pod)
   430  		}
   431  	}
   432  	if changed {
   433  		c.pods = newPods
   434  	}
   435  }
   436  
   437  /****************** CRD MOCK ****************************/
   438  
// TestCrdClient is an in-memory fake of the CRD client.
//
// idMap and bindingMap are keyed by getIDKey(namespace, name); assignedIDMap
// is keyed by the assigned identity's Name. err, when non-nil, is returned by
// RemoveAssignedIdentity. mu guards the maps.
//
// NOTE(review): NewTestCrdClient leaves the embedded *crd.Client nil, so any
// crd.Client method not overridden here would run with a nil receiver.
type TestCrdClient struct {
	*crd.Client
	mu            sync.Mutex
	assignedIDMap map[string]*internalaadpodid.AzureAssignedIdentity
	bindingMap    map[string]*aadpodid.AzureIdentityBinding
	idMap         map[string]*aadpodid.AzureIdentity
	err           *error
}
   447  
   448  func NewTestCrdClient(config *rest.Config) *TestCrdClient {
   449  	return &TestCrdClient{
   450  		assignedIDMap: make(map[string]*internalaadpodid.AzureAssignedIdentity),
   451  		bindingMap:    make(map[string]*aadpodid.AzureIdentityBinding),
   452  		idMap:         make(map[string]*aadpodid.AzureIdentity),
   453  	}
   454  }
   455  
   456  func (c *TestCrdClient) Start(exit <-chan struct{}) {
   457  }
   458  
   459  func (c *TestCrdClient) SyncCache(exit <-chan struct{}, initial bool, cacheSyncs ...cache.InformerSynced) {
   460  
   461  }
   462  
   463  func (c *TestCrdClient) SyncCacheAll(exit <-chan struct{}, initial bool) {
   464  
   465  }
   466  
   467  func (c *TestCrdClient) CreateCrdWatchers(eventCh chan internalaadpodid.EventType) error {
   468  	return nil
   469  }
   470  
   471  func (c *TestCrdClient) RemoveAssignedIdentity(assignedIdentity *internalaadpodid.AzureAssignedIdentity) error {
   472  	c.mu.Lock()
   473  	defer c.mu.Unlock()
   474  
   475  	if c.err != nil {
   476  		return *c.err
   477  	}
   478  	delete(c.assignedIDMap, assignedIdentity.Name)
   479  	return nil
   480  }
   481  
   482  // This function is not used currently
   483  // TODO: consider remove
   484  func (c *TestCrdClient) CreateAssignedIdentity(assignedIdentity *internalaadpodid.AzureAssignedIdentity) error {
   485  	assignedIdentityToStore := *assignedIdentity // Make a copy to store in the map.
   486  	c.mu.Lock()
   487  	c.assignedIDMap[assignedIdentity.Name] = &assignedIdentityToStore
   488  	c.mu.Unlock()
   489  	return nil
   490  }
   491  
   492  func (c *TestCrdClient) UpdateAssignedIdentity(assignedIdentity *internalaadpodid.AzureAssignedIdentity) error {
   493  	assignedIdentityToStore := *assignedIdentity // Make a copy to store in the map.
   494  	c.mu.Lock()
   495  	c.assignedIDMap[assignedIdentity.Name] = &assignedIdentityToStore
   496  	c.mu.Unlock()
   497  	return nil
   498  }
   499  
   500  func (c *TestCrdClient) UpdateAzureAssignedIdentityStatus(assignedIdentity *internalaadpodid.AzureAssignedIdentity, status string) error {
   501  	assignedIdentity.Status.Status = status
   502  	assignedIdentityToStore := *assignedIdentity // Make a copy to store in the map.
   503  	c.mu.Lock()
   504  	c.assignedIDMap[assignedIdentity.Name] = &assignedIdentityToStore
   505  	c.mu.Unlock()
   506  	return nil
   507  }
   508  
   509  func (c *TestCrdClient) CreateBinding(name, ns, idName, selector, resourceVersion string) {
   510  	binding := &aadpodid.AzureIdentityBinding{
   511  		ObjectMeta: metav1.ObjectMeta{
   512  			Name:            name,
   513  			Namespace:       ns,
   514  			ResourceVersion: resourceVersion,
   515  		},
   516  		Spec: aadpodid.AzureIdentityBindingSpec{
   517  			AzureIdentity: idName,
   518  			Selector:      selector,
   519  		},
   520  	}
   521  	c.mu.Lock()
   522  	c.bindingMap[getIDKey(ns, name)] = binding
   523  	c.mu.Unlock()
   524  }
   525  
   526  func (c *TestCrdClient) CreateID(idName, ns string, t aadpodid.IdentityType, rID, cID string, cp *api.SecretReference, tID, adRID, adEpt, resourceVersion string) {
   527  	id := &aadpodid.AzureIdentity{
   528  		ObjectMeta: metav1.ObjectMeta{
   529  			Name:            idName,
   530  			Namespace:       ns,
   531  			ResourceVersion: resourceVersion,
   532  		},
   533  		Spec: aadpodid.AzureIdentitySpec{
   534  			Type:         t,
   535  			ResourceID:   rID,
   536  			ClientID:     cID,
   537  			TenantID:     tID,
   538  			ADResourceID: adRID,
   539  			ADEndpoint:   adEpt,
   540  		},
   541  	}
   542  	c.mu.Lock()
   543  	c.idMap[getIDKey(ns, idName)] = id
   544  	c.mu.Unlock()
   545  }
   546  
   547  func (c *TestCrdClient) ListIds() (*[]internalaadpodid.AzureIdentity, error) {
   548  	idList := make([]internalaadpodid.AzureIdentity, 0)
   549  	c.mu.Lock()
   550  	for _, v := range c.idMap {
   551  		currID := aadpodid.ConvertV1IdentityToInternalIdentity(*v)
   552  		idList = append(idList, currID)
   553  	}
   554  	c.mu.Unlock()
   555  	return &idList, nil
   556  }
   557  
   558  func (c *TestCrdClient) ListBindings() (*[]internalaadpodid.AzureIdentityBinding, error) {
   559  	bindingList := make([]internalaadpodid.AzureIdentityBinding, 0)
   560  	c.mu.Lock()
   561  	for _, v := range c.bindingMap {
   562  		newBinding := aadpodid.ConvertV1BindingToInternalBinding(*v)
   563  		bindingList = append(bindingList, newBinding)
   564  	}
   565  	c.mu.Unlock()
   566  	return &bindingList, nil
   567  }
   568  
   569  func (c *TestCrdClient) ListAssignedIDs() (*[]internalaadpodid.AzureAssignedIdentity, error) {
   570  	assignedIDList := make([]internalaadpodid.AzureAssignedIdentity, 0)
   571  	c.mu.Lock()
   572  	for _, v := range c.assignedIDMap {
   573  		assignedIDList = append(assignedIDList, *v)
   574  	}
   575  	c.mu.Unlock()
   576  	return &assignedIDList, nil
   577  }
   578  
   579  func (c *TestCrdClient) ListAssignedIDsInMap() (map[string]internalaadpodid.AzureAssignedIdentity, error) {
   580  	assignedIDMap := make(map[string]internalaadpodid.AzureAssignedIdentity)
   581  	c.mu.Lock()
   582  	for k, v := range c.assignedIDMap {
   583  		assignedIDMap[k] = *v
   584  	}
   585  	c.mu.Unlock()
   586  	return assignedIDMap, nil
   587  }
   588  
   589  func (c *Client) ListPodIds(podns, podname string) (map[string][]internalaadpodid.AzureIdentity, error) {
   590  	return map[string][]internalaadpodid.AzureIdentity{}, nil
   591  }
   592  
   593  func (c *Client) ListPodIdentityExceptions(ns string) (*[]internalaadpodid.AzurePodIdentityException, error) {
   594  	return nil, nil
   595  }
   596  
   597  func (c *TestCrdClient) SetError(err error) {
   598  	c.err = &err
   599  }
   600  
   601  func (c *TestCrdClient) UnsetError() {
   602  	c.err = nil
   603  }
   604  
   605  func (c *TestCrdClient) waitForAssignedIDs(count int) bool {
   606  	i := 0
   607  	for i < 10 {
   608  		time.Sleep(1 * time.Second)
   609  
   610  		assignedIDs, err := c.ListAssignedIDs()
   611  		if err != nil {
   612  			return false
   613  		}
   614  		if len(*assignedIDs) == count {
   615  			return true
   616  		}
   617  		i++
   618  	}
   619  	return false
   620  }
   621  
   622  /************************ NODE MOCK *************************************/
   623  
// TestNodeClient is an in-memory fake node client holding nodes by name.
// mu guards nodes.
type TestNodeClient struct {
	mu    sync.Mutex
	nodes map[string]*corev1.Node
}
   628  
   629  func NewTestNodeClient() *TestNodeClient {
   630  	return &TestNodeClient{nodes: make(map[string]*corev1.Node)}
   631  }
   632  
   633  func (c *TestNodeClient) Get(name string) (*corev1.Node, error) {
   634  	c.mu.Lock()
   635  	defer c.mu.Unlock()
   636  
   637  	node, exists := c.nodes[name]
   638  	if !exists {
   639  		return nil, errors.New("node not found")
   640  	}
   641  	return node, nil
   642  }
   643  
   644  func (c *TestNodeClient) Delete(name string) {
   645  	c.mu.Lock()
   646  	delete(c.nodes, name)
   647  	c.mu.Unlock()
   648  }
   649  
   650  func (c *TestNodeClient) Start(<-chan struct{}) {}
   651  
   652  func (c *TestNodeClient) AddNode(name string, opts ...func(*corev1.Node)) {
   653  	c.mu.Lock()
   654  	defer c.mu.Unlock()
   655  
   656  	n := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: name}, Spec: corev1.NodeSpec{
   657  		ProviderID: "azure:///subscriptions/testSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachines/" + name,
   658  	}}
   659  	for _, o := range opts {
   660  		o(n)
   661  	}
   662  	c.nodes[name] = n
   663  }
   664  
   665  /************************ EVENT RECORDER MOCK *************************************/
// LastEvent is the type, reason and message of an event captured by
// TestEventRecorder.Event.
type LastEvent struct {
	Type    string
	Reason  string
	Message string
}
   671  
// TestEventRecorder is a fake event recorder. Event overwrites lastEvent
// (guarded by mu) and then sends on eventChannel, which WaitForEvents
// consumes. The *f variants are no-ops.
//
// Callers must initialise both fields before use: Event dereferences
// lastEvent, and a send on a nil eventChannel would block forever.
type TestEventRecorder struct {
	mu        sync.Mutex
	lastEvent *LastEvent

	eventChannel chan bool
}
   678  
   679  func (c *TestEventRecorder) WaitForEvents(expectedCount int) bool {
   680  	count := 0
   681  	for {
   682  		select {
   683  		case <-c.eventChannel:
   684  			count++
   685  			if expectedCount == count {
   686  				return true
   687  			}
   688  		case <-time.After(2 * time.Minute):
   689  			return false
   690  		}
   691  	}
   692  }
   693  
   694  func (c *TestEventRecorder) Event(object runtime.Object, t string, r string, message string) {
   695  	c.mu.Lock()
   696  
   697  	c.lastEvent.Type = t
   698  	c.lastEvent.Reason = r
   699  	c.lastEvent.Message = message
   700  
   701  	c.mu.Unlock()
   702  
   703  	c.eventChannel <- true
   704  }
   705  
   706  func (c *TestEventRecorder) Validate(e *LastEvent) bool {
   707  	c.mu.Lock()
   708  
   709  	t := c.lastEvent.Type
   710  	r := c.lastEvent.Reason
   711  	m := c.lastEvent.Message
   712  
   713  	c.mu.Unlock()
   714  
   715  	if t != e.Type || r != e.Reason || m != e.Message {
   716  		klog.Errorf("event mismatch. expected - (t:%s, r:%s, m:%s). got - (t:%s, r:%s, m:%s)", e.Type, e.Reason, e.Message, t, r, m)
   717  		return false
   718  	}
   719  	return true
   720  }
   721  
   722  func (c *TestEventRecorder) Eventf(object runtime.Object, t string, r string, messageFmt string, args ...interface{}) {
   723  
   724  }
   725  
   726  func (c *TestEventRecorder) PastEventf(object runtime.Object, timestamp metav1.Time, t string, m1 string, messageFmt string, args ...interface{}) {
   727  
   728  }
   729  
   730  func (c *TestEventRecorder) AnnotatedEventf(object runtime.Object, annotations map[string]string, eventtype, reason, messageFmt string, args ...interface{}) {
   731  
   732  }
   733  
   734  /************************ MIC MOCK *************************************/
   735  func NewMICTestClient(eventCh chan internalaadpodid.EventType,
   736  	cpClient *TestCloudClient,
   737  	crdClient *TestCrdClient,
   738  	podClient *TestPodClient,
   739  	nodeClient *TestNodeClient,
   740  	eventRecorder *TestEventRecorder, isNamespaced bool,
   741  	createDeleteBatch int64,
   742  	immutableUserMSIs map[string]bool) *TestMICClient {
   743  
   744  	reporter, _ := metrics.NewReporter()
   745  
   746  	realMICClient := &Client{
   747  		CloudClient:                         cpClient,
   748  		CRDClient:                           crdClient,
   749  		EventRecorder:                       eventRecorder,
   750  		PodClient:                           podClient,
   751  		EventChannel:                        eventCh,
   752  		NodeClient:                          nodeClient,
   753  		syncRetryInterval:                   120 * time.Second,
   754  		IsNamespaced:                        isNamespaced,
   755  		createDeleteBatch:                   createDeleteBatch,
   756  		ImmutableUserMSIsMap:                immutableUserMSIs,
   757  		Reporter:                            reporter,
   758  		identityAssignmentReconcileInterval: 3 * time.Minute,
   759  	}
   760  
   761  	return &TestMICClient{
   762  		realMICClient,
   763  	}
   764  }
   765  
// TestMICClient wraps the real MIC Client so that test-only helpers such as
// testRunSync can be attached to it.
type TestMICClient struct {
	*Client
}
   769  
   770  /************************ UNIT TEST *************************************/
   771  
// TestMapMICClient_1 checks convertIDListToMap:
//   - identities are keyed by namespace and name (via getIDKey);
//   - a user-assigned identity whose ResourceID is not a valid resource ID
//     (testazid3) is dropped;
//   - a service-principal identity, which has no ResourceID (testazid5), is kept.
func TestMapMICClient_1(t *testing.T) {
	idList := []internalaadpodid.AzureIdentity{
		{
			ObjectMeta: metav1.ObjectMeta{
				Name:      "testazid1",
				Namespace: "default",
			},
			Spec: internalaadpodid.AzureIdentitySpec{
				ResourceID: testResourceID,
			},
		},
		{
			ObjectMeta: metav1.ObjectMeta{
				Name:      "testazid2",
				Namespace: "ns00",
			},
			Spec: internalaadpodid.AzureIdentitySpec{
				ResourceID: testResourceID,
			},
		},
		{
			ObjectMeta: metav1.ObjectMeta{
				Name:      "testazid3",
				Namespace: "default",
			},
			Spec: internalaadpodid.AzureIdentitySpec{
				// Deliberately not a resource ID path.
				ResourceID: "testResourceID",
			},
		},
		{
			ObjectMeta: metav1.ObjectMeta{
				Name:      "testazid5",
				Namespace: "default",
			},
			Spec: internalaadpodid.AzureIdentitySpec{
				Type:     internalaadpodid.ServicePrincipal,
				TenantID: "tenantid",
				ClientID: "clientid",
			},
		},
	}

	// A zero TestMICClient (nil embedded *Client) is used here.
	// NOTE(review): this relies on convertIDListToMap not dereferencing its
	// receiver.
	micClient := &TestMICClient{}
	idMap, err := micClient.convertIDListToMap(idList)
	if err != nil {
		t.Fatalf("expected err to be nil, got: %+v", err)
	}

	tests := []struct {
		name        string
		idName      string
		idNamespace string
		shouldExist bool
	}{
		{
			name:        "default/testazid1 exists",
			idName:      "testazid1",
			idNamespace: "default",
			shouldExist: true,
		},
		{
			name:        "ns00/testazid2 in ns00 ns exists",
			idName:      "testazid2",
			idNamespace: "ns00",
			shouldExist: true,
		},
		{
			name:        "default/testazid3 doesn't exist as resource id invalid",
			idName:      "testazid3",
			idNamespace: "default",
			shouldExist: false,
		},
		{
			name:        "default/testazid4 doesn't exist",
			idName:      "testazid4",
			idNamespace: "default",
			shouldExist: false,
		},
		{
			name:        "ns00/testazid1 doesn't exist",
			idName:      "testazid1",
			idNamespace: "ns00",
			shouldExist: false,
		},
		{
			name:        "default/testazid5 for type 1 does exist",
			idName:      "testazid5",
			idNamespace: "default",
			shouldExist: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			azureID, idPresent := idMap[getIDKey(test.idNamespace, test.idName)]
			if test.shouldExist != idPresent {
				t.Fatalf("expected exist: %v, but identity %s/%s in map exist: %v",
					test.shouldExist, test.idNamespace, test.idName, idPresent)
			}
			if test.shouldExist && (azureID.Name != test.idName || azureID.Namespace != test.idNamespace) {
				t.Fatalf("expected identity %s/%s, got %s/%s", test.idNamespace, test.idName, azureID.Namespace, azureID.Name)
			}
		})
	}
}
   877  
   878  func (c *TestMICClient) testRunSync() func(t *testing.T) {
   879  	done := make(chan struct{})
   880  	exit := make(chan struct{})
   881  	var closeOnce sync.Once
   882  
   883  	go func() {
   884  		c.Sync(exit)
   885  		close(done)
   886  	}()
   887  
   888  	return func(t *testing.T) {
   889  		t.Helper()
   890  
   891  		closeOnce.Do(func() {
   892  			close(exit)
   893  		})
   894  
   895  		timeout := time.NewTimer(30 * time.Second)
   896  		defer timeout.Stop()
   897  
   898  		select {
   899  		case <-done:
   900  		case <-timeout.C:
   901  			t.Fatal("timeout waiting for sync to exit")
   902  		}
   903  	}
   904  }
   905  
   906  func TestSimpleMICClient(t *testing.T) {
   907  	eventCh := make(chan internalaadpodid.EventType, 100)
   908  	cloudClient := NewTestCloudClient(config.AzureConfig{})
   909  	crdClient := NewTestCrdClient(nil)
   910  	podClient := NewTestPodClient()
   911  	nodeClient := NewTestNodeClient()
   912  	var evtRecorder TestEventRecorder
   913  	evtRecorder.lastEvent = new(LastEvent)
   914  	evtRecorder.eventChannel = make(chan bool, 100)
   915  
   916  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
   917  
   918  	crdClient.CreateID("test-id", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
   919  	crdClient.CreateBinding("testbinding", "default", "test-id", "test-select", "")
   920  
   921  	nodeClient.AddNode("test-node")
   922  	podClient.AddPod("test-pod", "default", "test-node", "test-select")
   923  
   924  	eventCh <- internalaadpodid.PodCreated
   925  	defer micClient.testRunSync()(t)
   926  
   927  	evtRecorder.WaitForEvents(1)
   928  	if !crdClient.waitForAssignedIDs(1) {
   929  		t.Fatalf("expected len of assigned identities to be 0")
   930  	}
   931  	listAssignedIDs, err := crdClient.ListAssignedIDs()
   932  	if err != nil {
   933  		t.Fatalf("list assigned ids failed , error: %+v", err)
   934  	}
   935  
   936  	assignedID := (*listAssignedIDs)[0]
   937  	if !(assignedID.Spec.Pod == "test-pod" && assignedID.Spec.PodNamespace == "default" && assignedID.Spec.NodeName == "test-node" &&
   938  		assignedID.Spec.AzureBindingRef.Name == "testbinding" && assignedID.Spec.AzureIdentityRef.Name == "test-id") {
   939  		t.Fatalf("assigned ID spec: %v mismatch", assignedID)
   940  	}
   941  
   942  	// Test2: Remove assigned id event test
   943  	podClient.DeletePod("test-pod", "default")
   944  	eventCh <- internalaadpodid.PodDeleted
   945  	if !crdClient.waitForAssignedIDs(0) {
   946  		t.Fatalf("expected len of assigned identities to be 0")
   947  	}
   948  
   949  	// Test3: Error from cloud provider event test
   950  	err = errors.New("error returned from cloud provider")
   951  	cloudClient.SetError(err)
   952  
   953  	podClient.AddPod("test-pod", "default", "test-node", "test-select")
   954  	eventCh <- internalaadpodid.PodCreated
   955  	evtRecorder.WaitForEvents(1)
   956  	if !crdClient.waitForAssignedIDs(1) {
   957  		t.Fatalf("expected len of assigned identities to be 1")
   958  	}
   959  	listAssignedIDs, err = crdClient.ListAssignedIDs()
   960  	if err != nil {
   961  		t.Fatalf("list assigned ids failed , error: %+v", err)
   962  	}
   963  	if (*listAssignedIDs)[0].Status.Status != aadpodid.AssignedIDCreated {
   964  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDCreated, (*listAssignedIDs)[0].Status.Status)
   965  	}
   966  }
   967  
   968  func TestUpdateAssignedIdentities(t *testing.T) {
   969  	eventCh := make(chan internalaadpodid.EventType, 100)
   970  	cloudClient := NewTestCloudClient(config.AzureConfig{})
   971  	crdClient := NewTestCrdClient(nil)
   972  	podClient := NewTestPodClient()
   973  	nodeClient := NewTestNodeClient()
   974  	var evtRecorder TestEventRecorder
   975  	evtRecorder.lastEvent = new(LastEvent)
   976  	evtRecorder.eventChannel = make(chan bool, 100)
   977  
   978  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
   979  
   980  	crdClient.CreateID("test-id", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "rv1")
   981  	crdClient.CreateBinding("testbinding", "default", "test-id", "test-select", "")
   982  
   983  	nodeClient.AddNode("test-node")
   984  	podClient.AddPod("test-pod", "default", "test-node", "test-select")
   985  
   986  	eventCh <- internalaadpodid.PodCreated
   987  	defer micClient.testRunSync()(t)
   988  
   989  	evtRecorder.WaitForEvents(1)
   990  	if !crdClient.waitForAssignedIDs(1) {
   991  		t.Fatalf("expected len of assigned identities to be 1")
   992  	}
   993  	listAssignedIDs, err := crdClient.ListAssignedIDs()
   994  	if err != nil {
   995  		t.Fatalf("list assigned ids failed , error: %+v", err)
   996  	}
   997  
   998  	assignedID := (*listAssignedIDs)[0]
   999  	if !(assignedID.Spec.Pod == "test-pod" && assignedID.Spec.PodNamespace == "default" && assignedID.Spec.NodeName == "test-node" &&
  1000  		assignedID.Spec.AzureBindingRef.Name == "testbinding" && assignedID.Spec.AzureIdentityRef.Name == "test-id" &&
  1001  		assignedID.Spec.AzureIdentityRef.ResourceVersion == "rv1" && assignedID.Spec.AzureIdentityRef.Spec.ClientID == "test-user-msi-clientid") {
  1002  		t.Fatalf("assigned ID spec: %v mismatch", assignedID)
  1003  	}
  1004  
  1005  	newResourceID := testResourceID + "-new"
  1006  	crdClient.CreateID("test-id", "default", aadpodid.UserAssignedMSI, newResourceID, "test-user-msi-clientid", nil, "", "", "", "changedrv2")
  1007  	crdClient.CreateID("test-id-2", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "rv2")
  1008  	crdClient.CreateBinding("testbinding2", "default", "test-id-2", "test-select", "")
  1009  
  1010  	eventCh <- internalaadpodid.IdentityUpdated
  1011  	eventCh <- internalaadpodid.IdentityCreated
  1012  	eventCh <- internalaadpodid.BindingCreated
  1013  
  1014  	evtRecorder.WaitForEvents(1)
  1015  	if !crdClient.waitForAssignedIDs(2) {
  1016  		t.Fatalf("expected len of assigned identities to be 2")
  1017  	}
  1018  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1019  	if err != nil {
  1020  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1021  	}
  1022  	// check updated assigned identity has the right resource id
  1023  	if listAssignedIDs != nil {
  1024  		for _, assignedID := range *listAssignedIDs {
  1025  			if assignedID.Name != "test-pod-default-test-id" {
  1026  				continue
  1027  			}
  1028  			if !(assignedID.Spec.Pod == "test-pod" && assignedID.Spec.PodNamespace == "default" && assignedID.Spec.NodeName == "test-node" &&
  1029  				assignedID.Spec.AzureBindingRef.Name == "testbinding" && assignedID.Spec.AzureIdentityRef.Name == "test-id" &&
  1030  				assignedID.Spec.AzureIdentityRef.ResourceVersion == "changedrv2" && assignedID.Spec.AzureIdentityRef.Spec.ClientID == "test-user-msi-clientid" &&
  1031  				assignedID.Spec.AzureIdentityRef.Spec.ResourceID == newResourceID) {
  1032  				t.Fatalf("assigned ID spec: %v mismatch", assignedID)
  1033  			}
  1034  		}
  1035  	}
  1036  
  1037  	// test pod with same name moving to a new node
  1038  	// the nodename label should be updated to test-node2
  1039  	nodeClient.AddNode("test-node2")
  1040  	podClient.DeletePod("test-pod", "default")
  1041  	podClient.AddPod("test-pod", "default", "test-node2", "test-select")
  1042  
  1043  	eventCh <- internalaadpodid.PodUpdated
  1044  
  1045  	evtRecorder.WaitForEvents(1)
  1046  	if !crdClient.waitForAssignedIDs(2) {
  1047  		t.Fatalf("expected len of assigned identities to be 2")
  1048  	}
  1049  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1050  	if err != nil {
  1051  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1052  	}
  1053  	// check updated assigned identity has the right resource id
  1054  	if listAssignedIDs != nil {
  1055  		for _, assignedID := range *listAssignedIDs {
  1056  			if assignedID.Name != "test-pod-default-test-id" {
  1057  				continue
  1058  			}
  1059  			if assignedID.ObjectMeta.Labels["nodename"] != "test-node2" {
  1060  				t.Fatalf("expected node name: test-node2, got: %s", assignedID.ObjectMeta.Labels["nodename"])
  1061  			}
  1062  		}
  1063  	}
  1064  }
  1065  
  1066  func TestAddUpdateDel(t *testing.T) {
  1067  	eventCh := make(chan internalaadpodid.EventType, 100)
  1068  	cloudClient := NewTestCloudClient(config.AzureConfig{})
  1069  	crdClient := NewTestCrdClient(nil)
  1070  	podClient := NewTestPodClient()
  1071  	nodeClient := NewTestNodeClient()
  1072  	var evtRecorder TestEventRecorder
  1073  	evtRecorder.lastEvent = new(LastEvent)
  1074  	evtRecorder.eventChannel = make(chan bool, 100)
  1075  
  1076  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1077  
  1078  	crdClient.CreateID("test-id-0", "default", aadpodid.UserAssignedMSI, fmt.Sprintf("%s-%d", testResourceID, 0), "test-user-msi-clientid-0", nil, "", "", "", "rv-0")
  1079  	crdClient.CreateBinding("testbinding-0", "default", "test-id-0", "test-select-0", "")
  1080  
  1081  	crdClient.CreateID("test-id-1", "default", aadpodid.UserAssignedMSI, fmt.Sprintf("%s-%d", testResourceID, 1), "test-user-msi-clientid-1", nil, "", "", "", "rv-1")
  1082  	crdClient.CreateBinding("testbinding-1", "default", "test-id-1", "test-select-1", "")
  1083  
  1084  	crdClient.CreateID("test-id-2", "default", aadpodid.UserAssignedMSI, fmt.Sprintf("%s-%d", testResourceID, 2), "test-user-msi-clientid-2", nil, "", "", "", "rv-2")
  1085  	crdClient.CreateBinding("testbinding-2", "default", "test-id-2", "test-select-2", "")
  1086  
  1087  	nodeClient.AddNode("test-node")
  1088  	podClient.AddPod("test-pod-0", "default", "test-node", "test-select-0")
  1089  	podClient.AddPod("test-pod-1", "default", "test-node", "test-select-1")
  1090  	podClient.AddPod("test-pod-2", "default", "test-node", "test-select-2")
  1091  
  1092  	eventCh <- internalaadpodid.PodCreated
  1093  	eventCh <- internalaadpodid.PodCreated
  1094  	eventCh <- internalaadpodid.PodCreated
  1095  
  1096  	defer micClient.testRunSync()(t)
  1097  
  1098  	evtRecorder.WaitForEvents(3)
  1099  	if !crdClient.waitForAssignedIDs(3) {
  1100  		t.Fatalf("expected len of assigned identities to be 3")
  1101  	}
  1102  
  1103  	crdClient.CreateID("test-id-0", "default", aadpodid.UserAssignedMSI, fmt.Sprintf("%s-%d", testResourceID, 4), "test-user-msi-clientid-4", nil, "", "", "", "updated-rv-0")
  1104  	crdClient.CreateID("test-id-3", "default", aadpodid.UserAssignedMSI, fmt.Sprintf("%s-%d", testResourceID, 3), "test-user-msi-clientid-3", nil, "", "", "", "rv-3")
  1105  	crdClient.CreateBinding("testbinding-3", "default", "test-id-3", "test-select-2", "")
  1106  	podClient.DeletePod("test-pod-1", "default")
  1107  
  1108  	eventCh <- internalaadpodid.IdentityCreated
  1109  	eventCh <- internalaadpodid.BindingCreated
  1110  	eventCh <- internalaadpodid.IdentityUpdated
  1111  	eventCh <- internalaadpodid.PodDeleted
  1112  
  1113  	evtRecorder.WaitForEvents(2)
  1114  	if !crdClient.waitForAssignedIDs(3) {
  1115  		t.Fatalf("expected len of assigned identities to be 1")
  1116  	}
  1117  	listAssignedIDs, err := crdClient.ListAssignedIDs()
  1118  	if err != nil {
  1119  		t.Fatalf("failed to list assigned ids, error: %+v", err)
  1120  	}
  1121  	// check the updated identity has the correct azureid ref
  1122  	for _, assignedID := range *listAssignedIDs {
  1123  		if assignedID.Name != "test-pod-0-default-test-id-0" {
  1124  			continue
  1125  		}
  1126  		if !(assignedID.Spec.Pod == "test-pod-0" && assignedID.Spec.PodNamespace == "default" && assignedID.Spec.NodeName == "test-node" &&
  1127  			assignedID.Spec.AzureBindingRef.Name == "testbinding-0" && assignedID.Spec.AzureIdentityRef.Name == "test-id-0" &&
  1128  			assignedID.Spec.AzureIdentityRef.ResourceVersion == "updated-rv-0" && assignedID.Spec.AzureIdentityRef.Spec.ClientID == "test-user-msi-clientid-4" &&
  1129  			assignedID.Spec.AzureIdentityRef.Spec.ResourceID == fmt.Sprintf("%s-%d", testResourceID, 4)) {
  1130  			t.Fatalf("azure identity spec mismatch")
  1131  		}
  1132  	}
  1133  }
  1134  
// TestAddDelMICClient exercises the MIC when pod creations and deletions are
// queued at the same time. Two pods (test-pod2, test-pod4) first get assigned
// identities on test-node2. Both pods are then deleted while a third pod
// (test-pod3), with its own identity and binding, is created on the same node.
// After a restarted sync loop, only test-pod3's assigned identity must remain.
func TestAddDelMICClient(t *testing.T) {
	eventCh := make(chan internalaadpodid.EventType, 100)
	cloudClient := NewTestCloudClient(config.AzureConfig{})
	crdClient := NewTestCrdClient(nil)
	podClient := NewTestPodClient()
	nodeClient := NewTestNodeClient()
	var evtRecorder TestEventRecorder
	evtRecorder.lastEvent = new(LastEvent)
	evtRecorder.eventChannel = make(chan bool, 100)

	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)

	// Test to add and delete at the same time.
	// Add a pod, identity and binding.
	crdClient.CreateID("test-id2", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding2", "default", "test-id2", "test-select2", "")

	nodeClient.AddNode("test-node2")
	podClient.AddPod("test-pod2", "default", "test-node2", "test-select2")

	// A second pod on the same node with its own identity and binding.
	crdClient.CreateID("test-id4", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding4", "default", "test-id4", "test-select4", "")
	podClient.AddPod("test-pod4", "default", "test-node2", "test-select4")

	eventCh <- internalaadpodid.PodCreated
	eventCh <- internalaadpodid.PodCreated

	// Start the first sync loop.
	// NOTE(review): stopSync1 is called explicitly further below and again by
	// this defer, so it presumably tolerates being called twice — confirm in
	// testRunSync.
	stopSync1 := micClient.testRunSync()
	defer stopSync1(t)

	if !evtRecorder.WaitForEvents(2) {
		t.Fatalf("Timeout waiting for mic sync cycles")
	}
	if !crdClient.waitForAssignedIDs(2) {
		t.Fatalf("expected len of assigned identities to be 2")
	}

	// Delete both pods
	podClient.DeletePod("test-pod2", "default")
	podClient.DeletePod("test-pod4", "default")

	// Add a new pod, with different id and binding on the same node.
	crdClient.CreateID("test-id3", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding3", "default", "test-id3", "test-select3", "")
	podClient.AddPod("test-pod3", "default", "test-node2", "test-select3")

	// Queue the creation and both deletions together.
	eventCh <- internalaadpodid.PodCreated
	eventCh <- internalaadpodid.PodDeleted
	eventCh <- internalaadpodid.PodDeleted

	// Stop the first sync loop and start a second one right away; testRunSync()
	// is evaluated now and only the returned stop function is deferred.
	stopSync1(t)
	defer micClient.testRunSync()(t)

	if !crdClient.waitForAssignedIDs(1) {
		t.Fatalf("expected len of assigned identities to be 1")
	}
	listAssignedIDs, err := crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned failed")
	}

	// The only survivor must be test-pod3's assignment for test-id3.
	assignedID := (*listAssignedIDs)[0].Name
	expectedID := "test-pod3-default-test-id3"
	if assignedID != expectedID {
		t.Fatalf("Expected %s. Got: %s", expectedID, assignedID)
	}
}
  1202  
  1203  func TestMicAddDelVMSS(t *testing.T) {
  1204  	eventCh := make(chan internalaadpodid.EventType, 100)
  1205  	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
  1206  	crdClient := NewTestCrdClient(nil)
  1207  	podClient := NewTestPodClient()
  1208  	nodeClient := NewTestNodeClient()
  1209  	var evtRecorder TestEventRecorder
  1210  	evtRecorder.lastEvent = new(LastEvent)
  1211  	evtRecorder.eventChannel = make(chan bool, 100)
  1212  
  1213  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1214  
  1215  	// Test to add and delete at the same time.
  1216  	// Add a pod, identity and binding.
  1217  	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
  1218  	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")
  1219  
  1220  	nodeClient.AddNode("test-node1", func(n *corev1.Node) {
  1221  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss1/virtualMachines/0"
  1222  	})
  1223  
  1224  	nodeClient.AddNode("test-node2", func(n *corev1.Node) {
  1225  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss1/virtualMachines/1"
  1226  	})
  1227  
  1228  	nodeClient.AddNode("test-node3", func(n *corev1.Node) {
  1229  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss2/virtualMachines/0"
  1230  	})
  1231  
  1232  	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")
  1233  	podClient.AddPod("test-pod2", "default", "test-node2", "test-select1")
  1234  	podClient.AddPod("test-pod3", "default", "test-node3", "test-select1")
  1235  
  1236  	defer micClient.testRunSync()(t)
  1237  
  1238  	eventCh <- internalaadpodid.PodCreated
  1239  	eventCh <- internalaadpodid.PodCreated
  1240  	eventCh <- internalaadpodid.PodCreated
  1241  	if !evtRecorder.WaitForEvents(3) {
  1242  		t.Fatalf("Timeout waiting for mic sync cycles")
  1243  	}
  1244  	if !crdClient.waitForAssignedIDs(3) {
  1245  		t.Fatalf("expected len of assigned identities to be 3")
  1246  	}
  1247  
  1248  	if !cloudClient.CompareMSI("testvmss1", []string{testResourceID}) {
  1249  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1250  	}
  1251  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1252  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1253  	}
  1254  
  1255  	podClient.DeletePod("test-pod1", "default")
  1256  	eventCh <- internalaadpodid.PodDeleted
  1257  
  1258  	if !crdClient.waitForAssignedIDs(2) {
  1259  		t.Fatalf("expected len of assigned identities to be 2")
  1260  	}
  1261  	if !cloudClient.CompareMSI("testvmss1", []string{testResourceID}) {
  1262  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1263  	}
  1264  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1265  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1266  	}
  1267  
  1268  	podClient.DeletePod("test-pod2", "default")
  1269  	eventCh <- internalaadpodid.PodDeleted
  1270  
  1271  	if !crdClient.waitForAssignedIDs(1) {
  1272  		t.Fatalf("expected len of assigned identities to be 1")
  1273  	}
  1274  	if !cloudClient.CompareMSI("testvmss1", []string{}) {
  1275  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1276  	}
  1277  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1278  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1279  	}
  1280  }
  1281  
// TestMICStateFlow walks assigned identities through their status transitions
// while failures are injected:
//   - test-pod1 is assigned, and its assigned identity is in Assigned state;
//   - test-pod1 is deleted while the cloud client returns an error, and the
//     assigned identity is expected to stay in Assigned state;
//   - test-pod2 is created while the CRD client returns an error, after which
//     test-pod1's assigned identity is expected in UnAssigned state and
//     test-pod2's in Assigned state;
//   - after the CRD error is cleared and test-pod2 is deleted, no assigned
//     identities may remain.
func TestMICStateFlow(t *testing.T) {
	eventCh := make(chan internalaadpodid.EventType, 100)
	cloudClient := NewTestCloudClient(config.AzureConfig{})
	crdClient := NewTestCrdClient(nil)
	podClient := NewTestPodClient()
	nodeClient := NewTestNodeClient()
	var evtRecorder TestEventRecorder
	evtRecorder.lastEvent = new(LastEvent)
	evtRecorder.eventChannel = make(chan bool, 100)

	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)

	// Add a pod, identity and binding.
	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")

	nodeClient.AddNode("test-node1")
	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")

	eventCh <- internalaadpodid.PodCreated
	defer micClient.testRunSync()(t)

	if !evtRecorder.WaitForEvents(1) {
		t.Fatalf("Timeout waiting for mic sync cycles")
	}
	if !crdClient.waitForAssignedIDs(1) {
		t.Fatalf("expected len of assigned identities to be 1")
	}
	listAssignedIDs, err := crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned ids failed , error: %+v", err)
	}
	if !((*listAssignedIDs)[0].Status.Status == aadpodid.AssignedIDAssigned) {
		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
	}

	// delete the pod, simulate failure in cloud calls on trying to un-assign identity from node
	podClient.DeletePod("test-pod1", "default")
	// Inject an error into the cloud client (not the CRD client).
	// NOTE(review): which cloud operations this error affects is decided by
	// TestCloudClient.SetError — confirm it only hits the removal path.
	cloudClient.SetError(errors.New("error removing identity from node"))
	// Seed the mocked VM identity with testResourceID so the node appears to
	// still carry the identity.
	cloudClient.testVMClient.identity = &compute.VirtualMachineIdentity{
		UserAssignedIdentities: map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue{
			testResourceID: {},
		},
	}

	// The failed un-assignment must leave the assigned identity in place and
	// still in Assigned state.
	eventCh <- internalaadpodid.PodDeleted
	if !crdClient.waitForAssignedIDs(1) {
		t.Fatalf("expected len of assigned identities to be 1")
	}
	listAssignedIDs, err = crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned ids failed , error: %+v", err)
	}
	if !((*listAssignedIDs)[0].Status.Status == aadpodid.AssignedIDAssigned) {
		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
	}

	// Make CRD client calls fail from here on, until UnsetError below.
	crdClient.SetError(errors.New("error from crd client"))

	// add new pod, this time the old assigned identity which is in Assigned state should be tried to delete
	// simulate failure on kube api call to delete crd
	crdClient.CreateID("test-id2", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid2", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding2", "default", "test-id2", "test-select2", "")

	nodeClient.AddNode("test-node2")
	podClient.AddPod("test-pod2", "default", "test-node2", "test-select2")

	eventCh <- internalaadpodid.PodCreated
	if !evtRecorder.WaitForEvents(1) {
		t.Fatalf("Timeout waiting for mic sync cycles")
	}
	if !crdClient.waitForAssignedIDs(2) {
		t.Fatalf("expected len of assigned identities to be 2")
	}
	listAssignedIDs, err = crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned ids failed , error: %+v", err)
	}
	// test-pod1's assignment is expected to have moved to UnAssigned (its CRD
	// deletion failed), and test-pod2's to be Assigned.
	for _, assignedID := range *listAssignedIDs {
		if assignedID.Spec.Pod == "test-pod1" {
			if assignedID.Status.Status != aadpodid.AssignedIDUnAssigned {
				t.Fatalf("Expected status to be: %s. Got: %s", aadpodid.AssignedIDUnAssigned, assignedID.Status.Status)
			}
		}
		if assignedID.Spec.Pod == "test-pod2" {
			if assignedID.Status.Status != aadpodid.AssignedIDAssigned {
				t.Fatalf("Expected status to be: %s. Got: %s", aadpodid.AssignedIDAssigned, assignedID.Status.Status)
			}
		}
	}
	crdClient.UnsetError()

	// delete pod2 and everything should be cleaned up now
	podClient.DeletePod("test-pod2", "default")
	eventCh <- internalaadpodid.PodDeleted
	if !crdClient.waitForAssignedIDs(0) {
		t.Fatalf("expected len of assigned identities to be 0")
	}
}
  1382  
  1383  func TestForceNamespaced(t *testing.T) {
  1384  	eventCh := make(chan internalaadpodid.EventType, 100)
  1385  	cloudClient := NewTestCloudClient(config.AzureConfig{})
  1386  	crdClient := NewTestCrdClient(nil)
  1387  	podClient := NewTestPodClient()
  1388  	nodeClient := NewTestNodeClient()
  1389  	var evtRecorder TestEventRecorder
  1390  	evtRecorder.lastEvent = new(LastEvent)
  1391  	evtRecorder.eventChannel = make(chan bool, 100)
  1392  
  1393  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, true, 4, nil)
  1394  
  1395  	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "idrv1")
  1396  	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "bindingrv1")
  1397  
  1398  	nodeClient.AddNode("test-node1")
  1399  	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")
  1400  
  1401  	eventCh <- internalaadpodid.PodCreated
  1402  	defer micClient.testRunSync()(t)
  1403  
  1404  	if !evtRecorder.WaitForEvents(1) {
  1405  		t.Fatalf("Timeout waiting for mic sync cycles")
  1406  	}
  1407  	if !crdClient.waitForAssignedIDs(1) {
  1408  		t.Fatalf("expected len of assigned identities to be 1")
  1409  	}
  1410  	listAssignedIDs, err := crdClient.ListAssignedIDs()
  1411  	if err != nil {
  1412  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1413  	}
  1414  	if !((*listAssignedIDs)[0].Status.Status == aadpodid.AssignedIDAssigned) {
  1415  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
  1416  	}
  1417  
  1418  	crdClient.CreateID("test-id1", "default2", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "idrv2")
  1419  	crdClient.CreateBinding("testbinding1", "default2", "test-id1", "test-select1", "bindingrv2")
  1420  	podClient.AddPod("test-pod2", "default2", "test-node1", "test-select1")
  1421  
  1422  	eventCh <- internalaadpodid.IdentityCreated
  1423  	eventCh <- internalaadpodid.BindingCreated
  1424  	eventCh <- internalaadpodid.PodCreated
  1425  
  1426  	if !evtRecorder.WaitForEvents(1) {
  1427  		t.Fatalf("Timeout waiting for mic sync cycles")
  1428  	}
  1429  	if !crdClient.waitForAssignedIDs(2) {
  1430  		t.Fatalf("expected len of assigned identities to be 2")
  1431  	}
  1432  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1433  	if err != nil {
  1434  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1435  	}
  1436  
  1437  	for _, assignedID := range *listAssignedIDs {
  1438  		if !(assignedID.Status.Status == aadpodid.AssignedIDAssigned) {
  1439  			t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
  1440  		}
  1441  	}
  1442  }
  1443  
  1444  func TestSyncRetryLoop(t *testing.T) {
  1445  	eventCh := make(chan internalaadpodid.EventType, 100)
  1446  	cloudClient := NewTestCloudClient(config.AzureConfig{})
  1447  	crdClient := NewTestCrdClient(nil)
  1448  	podClient := NewTestPodClient()
  1449  	nodeClient := NewTestNodeClient()
  1450  	var evtRecorder TestEventRecorder
  1451  	evtRecorder.lastEvent = new(LastEvent)
  1452  	evtRecorder.eventChannel = make(chan bool, 100)
  1453  
  1454  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1455  	syncRetryInterval, err := time.ParseDuration("5s")
  1456  	if err != nil {
  1457  		t.Fatalf("error parsing duration: %v", err)
  1458  	}
  1459  	micClient.syncRetryInterval = syncRetryInterval
  1460  
  1461  	// Add a pod, identity and binding.
  1462  	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
  1463  	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")
  1464  
  1465  	nodeClient.AddNode("test-node1")
  1466  	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")
  1467  
  1468  	eventCh <- internalaadpodid.PodCreated
  1469  	defer micClient.testRunSync()(t)
  1470  
  1471  	if !evtRecorder.WaitForEvents(1) {
  1472  		t.Fatalf("Timeout waiting for mic sync cycles")
  1473  	}
  1474  	if !crdClient.waitForAssignedIDs(1) {
  1475  		t.Fatalf("expected len of assigned identities to be 1")
  1476  	}
  1477  	listAssignedIDs, err := crdClient.ListAssignedIDs()
  1478  	if err != nil {
  1479  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1480  	}
  1481  	if !((*listAssignedIDs)[0].Status.Status == aadpodid.AssignedIDAssigned) {
  1482  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
  1483  	}
  1484  
  1485  	// delete the pod, simulate failure in cloud calls on trying to un-assign identity from node
  1486  	podClient.DeletePod("test-pod1", "default")
  1487  	cloudClient.SetError(errors.New("error removing identity from node"))
  1488  	cloudClient.testVMClient.identity = &compute.VirtualMachineIdentity{
  1489  		UserAssignedIdentities: map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue{
  1490  			testResourceID: {},
  1491  		},
  1492  	}
  1493  
  1494  	eventCh <- internalaadpodid.PodDeleted
  1495  	if !crdClient.waitForAssignedIDs(1) {
  1496  		t.Fatalf("expected len of assigned identities to be 1")
  1497  	}
  1498  
  1499  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1500  	if err != nil {
  1501  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1502  	}
  1503  	if !((*listAssignedIDs)[0].Status.Status == aadpodid.AssignedIDAssigned) {
  1504  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[0].Status.Status)
  1505  	}
  1506  
  1507  	// mic should automatically retry and delete assigned identity
  1508  	if !crdClient.waitForAssignedIDs(0) {
  1509  		t.Fatalf("expected len of assigned identities to be 0")
  1510  	}
  1511  }
  1512  
// TestSyncNodeNotFound assigns one identity to 10 pods spread over 10 distinct
// nodes (test-node0..test-node9). It then deletes test-node5..test-node9
// together with their pods and adds one more pod.
//
// All 6 remaining assigned identities (test-pod0..test-pod4 plus test-podx) are
// expected to be in Assigned state, i.e. syncing is expected to keep working
// after some nodes can no longer be found.
func TestSyncNodeNotFound(t *testing.T) {
	eventCh := make(chan internalaadpodid.EventType, 100)
	cloudClient := NewTestCloudClient(config.AzureConfig{})
	crdClient := NewTestCrdClient(nil)
	podClient := NewTestPodClient()
	nodeClient := NewTestNodeClient()
	var evtRecorder TestEventRecorder
	evtRecorder.lastEvent = new(LastEvent)
	evtRecorder.eventChannel = make(chan bool, 100)

	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)

	// Add a pod, identity and binding.
	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")

	// One node and one pod per index, all selecting test-select1.
	for i := 0; i < 10; i++ {
		nodeClient.AddNode(fmt.Sprintf("test-node%d", i))
		podClient.AddPod(fmt.Sprintf("test-pod%d", i), "default", fmt.Sprintf("test-node%d", i), "test-select1")
		eventCh <- internalaadpodid.PodCreated
	}

	defer micClient.testRunSync()(t)

	if !evtRecorder.WaitForEvents(10) {
		t.Fatalf("Timeout waiting for mic sync cycles")
	}
	if !crdClient.waitForAssignedIDs(10) {
		t.Fatalf("expected len of assigned identities to be 10")
	}
	listAssignedIDs, err := crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned ids failed , error: %+v", err)
	}
	for i := range *listAssignedIDs {
		if !((*listAssignedIDs)[i].Status.Status == aadpodid.AssignedIDAssigned) {
			t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[i].Status.Status)
		}
	}

	// delete 5 nodes
	for i := 5; i < 10; i++ {
		nodeClient.Delete(fmt.Sprintf("test-node%d", i))
		podClient.DeletePod(fmt.Sprintf("test-pod%d", i), "default")
		eventCh <- internalaadpodid.PodDeleted
	}

	// NOTE(review): test-nodex is added but test-podx is scheduled on
	// test-node1, so test-nodex is never used — confirm whether the pod was
	// meant to land on test-nodex.
	nodeClient.AddNode("test-nodex")
	podClient.AddPod("test-podx", "default", "test-node1", "test-select1")
	eventCh <- internalaadpodid.PodCreated

	if !evtRecorder.WaitForEvents(1) {
		t.Fatalf("Timeout waiting for mic sync cycles")
	}
	if !crdClient.waitForAssignedIDs(6) {
		t.Fatalf("expected len of assigned identities to be 6")
	}
	listAssignedIDs, err = crdClient.ListAssignedIDs()
	if err != nil {
		t.Fatalf("list assigned ids failed , error: %+v", err)
	}
	for i := range *listAssignedIDs {
		if !((*listAssignedIDs)[i].Status.Status == aadpodid.AssignedIDAssigned) {
			t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, (*listAssignedIDs)[i].Status.Status)
		}
	}
}
  1580  
  1581  func TestProcessingTimeForScale(t *testing.T) {
  1582  	eventCh := make(chan internalaadpodid.EventType, 20000)
  1583  	cloudClient := NewTestCloudClient(config.AzureConfig{})
  1584  	crdClient := NewTestCrdClient(nil)
  1585  	podClient := NewTestPodClient()
  1586  	nodeClient := NewTestNodeClient()
  1587  	var evtRecorder TestEventRecorder
  1588  	evtRecorder.lastEvent = new(LastEvent)
  1589  	evtRecorder.eventChannel = make(chan bool, 20000)
  1590  
  1591  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1592  
  1593  	// Add a pod, identity and binding.
  1594  	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
  1595  	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")
  1596  
  1597  	nodeClient.AddNode("test-node1")
  1598  	for i := 0; i < 20000; i++ {
  1599  		podClient.AddPod(fmt.Sprintf("test-pod%d", i), "default", "test-node1", "test-select1")
  1600  	}
  1601  	eventCh <- internalaadpodid.PodCreated
  1602  
  1603  	defer micClient.testRunSync()(t)
  1604  
  1605  	if !evtRecorder.WaitForEvents(20000) {
  1606  		t.Fatalf("Timeout waiting for mic sync cycles")
  1607  	}
  1608  
  1609  	listAssignedIDs, err := crdClient.ListAssignedIDs()
  1610  	if err != nil {
  1611  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1612  	}
  1613  	if !(len(*listAssignedIDs) == 20000) {
  1614  		t.Fatalf("expected assigned identities len: %d, got: %d", 20000, len(*listAssignedIDs))
  1615  	}
  1616  
  1617  	for i := 10000; i < 20000; i++ {
  1618  		podClient.DeletePod(fmt.Sprintf("test-pod%d", i), "default")
  1619  	}
  1620  	eventCh <- internalaadpodid.PodDeleted
  1621  
  1622  	if !crdClient.waitForAssignedIDs(10000) {
  1623  		t.Fatalf("expected len of assigned identities to be 10000")
  1624  	}
  1625  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1626  	if err != nil {
  1627  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1628  	}
  1629  	if !(len(*listAssignedIDs) == 10000) {
  1630  		t.Fatalf("expected assigned identities len: %d, got: %d", 10000, len(*listAssignedIDs))
  1631  	}
  1632  }
  1633  
  1634  func TestSyncExit(t *testing.T) {
  1635  	eventCh := make(chan internalaadpodid.EventType)
  1636  	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
  1637  	crdClient := NewTestCrdClient(nil)
  1638  	podClient := NewTestPodClient()
  1639  	nodeClient := NewTestNodeClient()
  1640  	var evtRecorder TestEventRecorder
  1641  	evtRecorder.lastEvent = new(LastEvent)
  1642  	evtRecorder.eventChannel = make(chan bool)
  1643  
  1644  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1645  
  1646  	micClient.testRunSync()(t)
  1647  }
  1648  
  1649  func TestMicAddDelVMSSWithImmutableIdentities(t *testing.T) {
  1650  	eventCh := make(chan internalaadpodid.EventType, 100)
  1651  	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
  1652  	crdClient := NewTestCrdClient(nil)
  1653  	podClient := NewTestPodClient()
  1654  	nodeClient := NewTestNodeClient()
  1655  	var evtRecorder TestEventRecorder
  1656  	evtRecorder.lastEvent = new(LastEvent)
  1657  	evtRecorder.eventChannel = make(chan bool, 100)
  1658  	var immutableUserMSIs = map[string]bool{
  1659  		"zero-test":              true,
  1660  		"test-user-msi-clientid": true,
  1661  	}
  1662  
  1663  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, immutableUserMSIs)
  1664  
  1665  	// Test to add and delete at the same time.
  1666  	// Add a pod, identity and binding.
  1667  	crdClient.CreateID("test-id1", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid", nil, "", "", "", "")
  1668  	crdClient.CreateBinding("testbinding1", "default", "test-id1", "test-select1", "")
  1669  
  1670  	nodeClient.AddNode("test-node1", func(n *corev1.Node) {
  1671  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss1/virtualMachines/0"
  1672  	})
  1673  
  1674  	nodeClient.AddNode("test-node2", func(n *corev1.Node) {
  1675  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss1/virtualMachines/1"
  1676  	})
  1677  
  1678  	nodeClient.AddNode("test-node3", func(n *corev1.Node) {
  1679  		n.Spec.ProviderID = "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/testvmss2/virtualMachines/0"
  1680  	})
  1681  
  1682  	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")
  1683  	podClient.AddPod("test-pod2", "default", "test-node2", "test-select1")
  1684  	podClient.AddPod("test-pod3", "default", "test-node3", "test-select1")
  1685  
  1686  	defer micClient.testRunSync()(t)
  1687  
  1688  	eventCh <- internalaadpodid.PodCreated
  1689  	eventCh <- internalaadpodid.PodCreated
  1690  	eventCh <- internalaadpodid.PodCreated
  1691  	if !evtRecorder.WaitForEvents(3) {
  1692  		t.Fatalf("Timeout waiting for mic sync cycles")
  1693  	}
  1694  	if !crdClient.waitForAssignedIDs(3) {
  1695  		t.Fatalf("expected len of assigned identities to be 3")
  1696  	}
  1697  	if !cloudClient.CompareMSI("testvmss1", []string{testResourceID}) {
  1698  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1699  	}
  1700  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1701  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1702  	}
  1703  
  1704  	podClient.DeletePod("test-pod1", "default")
  1705  	eventCh <- internalaadpodid.PodDeleted
  1706  
  1707  	if !crdClient.waitForAssignedIDs(2) {
  1708  		t.Fatalf("expected len of assigned identities to be 2")
  1709  	}
  1710  	if !cloudClient.CompareMSI("testvmss1", []string{testResourceID}) {
  1711  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1712  	}
  1713  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1714  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1715  	}
  1716  
  1717  	podClient.DeletePod("test-pod2", "default")
  1718  	eventCh <- internalaadpodid.PodDeleted
  1719  
  1720  	if !crdClient.waitForAssignedIDs(1) {
  1721  		t.Fatalf("expected len of assigned identities to be 1")
  1722  	}
  1723  	if !cloudClient.CompareMSI("testvmss1", []string{testResourceID}) {
  1724  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss1"])
  1725  	}
  1726  	if !cloudClient.CompareMSI("testvmss2", []string{testResourceID}) {
  1727  		t.Fatalf("missing identity: %+v", cloudClient.ListMSI()["testvmss2"])
  1728  	}
  1729  }
  1730  
  1731  func TestCloudProviderRetryLoop(t *testing.T) {
  1732  	eventCh := make(chan internalaadpodid.EventType, 100)
  1733  	cloudClient := NewTestCloudClient(config.AzureConfig{})
  1734  	cloudClient.RetryClient.RegisterRetriableErrors("KnownError")
  1735  	crdClient := NewTestCrdClient(nil)
  1736  	podClient := NewTestPodClient()
  1737  	nodeClient := NewTestNodeClient()
  1738  	var evtRecorder TestEventRecorder
  1739  	evtRecorder.lastEvent = new(LastEvent)
  1740  	evtRecorder.eventChannel = make(chan bool, 100)
  1741  
  1742  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1743  	defer micClient.testRunSync()(t)
  1744  
  1745  	erroneousTestResourceID := strings.Replace(testResourceID, "identity1", "erroneousIdentity", -1)
  1746  	cloudClient.SetError(fmt.Errorf("KnownError: '%s' is erroneous", erroneousTestResourceID))
  1747  	crdClient.CreateID("test-id-1", "default", aadpodid.UserAssignedMSI, erroneousTestResourceID, "test-user-msi-clientid-1", nil, "", "", "", "")
  1748  	crdClient.CreateBinding("test-binding-1", "default", "test-id-1", "test-select-1", "")
  1749  	crdClient.CreateID("test-id-2", "default", aadpodid.UserAssignedMSI, testResourceID, "test-user-msi-clientid-2", nil, "", "", "", "")
  1750  	crdClient.CreateBinding("test-binding-2", "default", "test-id-2", "test-select-2", "")
  1751  
  1752  	nodeClient.AddNode("test-node-1")
  1753  	podClient.AddPod("test-pod-1", "default", "test-node-1", "test-select-1")
  1754  	podClient.AddPod("test-pod-2", "default", "test-node-1", "test-select-2")
  1755  
  1756  	eventCh <- internalaadpodid.PodCreated
  1757  	if !evtRecorder.WaitForEvents(1) {
  1758  		t.Fatalf("Timeout waiting for mic sync cycles")
  1759  	}
  1760  
  1761  	if !crdClient.waitForAssignedIDs(2) {
  1762  		t.Fatalf("expected len of assigned identities to be 2")
  1763  	}
  1764  
  1765  	listAssignedIDs, err := crdClient.ListAssignedIDs()
  1766  	if err != nil {
  1767  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1768  	}
  1769  
  1770  	assignedID := findAssignedIDByName("test-pod-1-default-test-id-1", listAssignedIDs)
  1771  	// Not in assigned state since the identity is erroneous
  1772  	if assignedID.Status.Status != aadpodid.AssignedIDCreated {
  1773  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDCreated, assignedID.Status.Status)
  1774  	}
  1775  
  1776  	assignedID = findAssignedIDByName("test-pod-2-default-test-id-2", listAssignedIDs)
  1777  	if assignedID.Status.Status != aadpodid.AssignedIDAssigned {
  1778  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, assignedID.Status.Status)
  1779  	}
  1780  
  1781  	podClient.DeletePod("test-pod-2", "default")
  1782  	cloudClient.SetError(fmt.Errorf("KnownError: '%s' is erroneous", testResourceID))
  1783  
  1784  	eventCh <- internalaadpodid.PodDeleted
  1785  	if !evtRecorder.WaitForEvents(1) {
  1786  		t.Fatalf("Timeout waiting for mic sync cycles")
  1787  	}
  1788  
  1789  	if !crdClient.waitForAssignedIDs(2) {
  1790  		t.Fatalf("expected len of assigned identities to be 2")
  1791  	}
  1792  
  1793  	listAssignedIDs, err = crdClient.ListAssignedIDs()
  1794  	if err != nil {
  1795  		t.Fatalf("list assigned ids failed , error: %+v", err)
  1796  	}
  1797  
  1798  	assignedID = findAssignedIDByName("test-pod-1-default-test-id-1", listAssignedIDs)
  1799  	if assignedID.Status.Status != aadpodid.AssignedIDAssigned {
  1800  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, assignedID.Status.Status)
  1801  	}
  1802  
  1803  	assignedID = findAssignedIDByName("test-pod-2-default-test-id-2", listAssignedIDs)
  1804  	// Should still be assigned since the cloud client encountered an error
  1805  	// when unassigning the identity from the underlying node
  1806  	if assignedID.Status.Status != aadpodid.AssignedIDAssigned {
  1807  		t.Fatalf("expected status to be %s, got: %s", aadpodid.AssignedIDAssigned, assignedID.Status.Status)
  1808  	}
  1809  }
  1810  
  1811  func TestGenerateIdentityAssignmentStateVM(t *testing.T) {
  1812  	eventCh := make(chan internalaadpodid.EventType)
  1813  	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
  1814  	crdClient := NewTestCrdClient(nil)
  1815  	podClient := NewTestPodClient()
  1816  	nodeClient := NewTestNodeClient()
  1817  	var evtRecorder TestEventRecorder
  1818  	evtRecorder.lastEvent = new(LastEvent)
  1819  	evtRecorder.eventChannel = make(chan bool)
  1820  
  1821  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1822  	currentState, desiredState, isVMSSMap, err := micClient.generateIdentityAssignmentState()
  1823  	assert.Empty(t, currentState)
  1824  	assert.Empty(t, desiredState)
  1825  	assert.Empty(t, isVMSSMap)
  1826  	assert.NoError(t, err)
  1827  
  1828  	nodeClient.AddNode("node-0", func(n *corev1.Node) {
  1829  		n.Spec.ProviderID = "azure:///subscriptions/xxx/resourceGroups/xxx/providers/Microsoft.Compute/virtualMachines/node-0"
  1830  	})
  1831  
  1832  	_ = crdClient.CreateAssignedIdentity(&internalaadpodid.AzureAssignedIdentity{
  1833  		Spec: internalaadpodid.AzureAssignedIdentitySpec{
  1834  			NodeName: "node-0",
  1835  			AzureIdentityRef: &internalaadpodid.AzureIdentity{
  1836  				Spec: internalaadpodid.AzureIdentitySpec{
  1837  					ResourceID: testResourceID,
  1838  				},
  1839  			},
  1840  		},
  1841  		Status: internalaadpodid.AzureAssignedIdentityStatus{
  1842  			Status: aadpodid.AssignedIDAssigned,
  1843  		},
  1844  	})
  1845  
  1846  	// the user-assigned identity isn't assigned to a VMSS instance on Azure
  1847  	currentState, desiredState, isVMSSMap, err = micClient.generateIdentityAssignmentState()
  1848  	assert.Equal(t, currentState, map[string]map[string]bool{
  1849  		"node-0": {},
  1850  	})
  1851  	assert.Equal(t, desiredState, map[string]map[string]bool{
  1852  		"node-0": {
  1853  			testResourceID: true,
  1854  		},
  1855  	})
  1856  	assert.Equal(t, isVMSSMap, map[string]bool{
  1857  		"node-0": false,
  1858  	})
  1859  	assert.NoError(t, err)
  1860  
  1861  	// the user-assigned identity is now assigned to a VM instance on Azure
  1862  	vm, _ := cloudClient.testVMClient.Get("", "node-0")
  1863  	vm.Identity = &compute.VirtualMachineIdentity{
  1864  		UserAssignedIdentities: map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue{
  1865  			testResourceID: {},
  1866  		},
  1867  	}
  1868  	_ = cloudClient.testVMClient.UpdateIdentities("", "node-0", vm)
  1869  
  1870  	currentState, desiredState, isVMSSMap, err = micClient.generateIdentityAssignmentState()
  1871  	assert.Equal(t, currentState, map[string]map[string]bool{
  1872  		"node-0": {
  1873  			testResourceID: true,
  1874  		},
  1875  	})
  1876  	assert.Equal(t, desiredState, map[string]map[string]bool{
  1877  		"node-0": {
  1878  			testResourceID: true,
  1879  		},
  1880  	})
  1881  	assert.Equal(t, isVMSSMap, map[string]bool{
  1882  		"node-0": false,
  1883  	})
  1884  	assert.NoError(t, err)
  1885  }
  1886  
  1887  func TestGenerateIdentityAssignmentStateVMSS(t *testing.T) {
  1888  	eventCh := make(chan internalaadpodid.EventType)
  1889  	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
  1890  	crdClient := NewTestCrdClient(nil)
  1891  	podClient := NewTestPodClient()
  1892  	nodeClient := NewTestNodeClient()
  1893  	var evtRecorder TestEventRecorder
  1894  	evtRecorder.lastEvent = new(LastEvent)
  1895  	evtRecorder.eventChannel = make(chan bool)
  1896  
  1897  	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder, false, 4, nil)
  1898  	currentState, desiredState, isVMSSMap, err := micClient.generateIdentityAssignmentState()
  1899  	assert.Empty(t, currentState)
  1900  	assert.Empty(t, desiredState)
  1901  	assert.Empty(t, isVMSSMap)
  1902  	assert.NoError(t, err)
  1903  
  1904  	nodeClient.AddNode("node-0", func(n *corev1.Node) {
  1905  		n.Spec.ProviderID = "azure:///subscriptions/xxx/resourceGroups/xxx/providers/Microsoft.Compute/virtualMachineScaleSets/node-0/virtualMachines/0"
  1906  	})
  1907  
  1908  	_ = crdClient.CreateAssignedIdentity(&internalaadpodid.AzureAssignedIdentity{
  1909  		Spec: internalaadpodid.AzureAssignedIdentitySpec{
  1910  			NodeName: "node-0",
  1911  			AzureIdentityRef: &internalaadpodid.AzureIdentity{
  1912  				Spec: internalaadpodid.AzureIdentitySpec{
  1913  					Type:       internalaadpodid.UserAssignedMSI,
  1914  					ResourceID: testResourceID,
  1915  				},
  1916  			},
  1917  		},
  1918  		Status: internalaadpodid.AzureAssignedIdentityStatus{
  1919  			Status: aadpodid.AssignedIDAssigned,
  1920  		},
  1921  	})
  1922  
  1923  	// the user-assigned identity isn't assigned to a VMSS instance on Azure
  1924  	currentState, desiredState, isVMSSMap, err = micClient.generateIdentityAssignmentState()
  1925  	assert.Equal(t, currentState, map[string]map[string]bool{
  1926  		"node-0": {},
  1927  	})
  1928  	assert.Equal(t, desiredState, map[string]map[string]bool{
  1929  		"node-0": {
  1930  			testResourceID: true,
  1931  		},
  1932  	})
  1933  	assert.Equal(t, isVMSSMap, map[string]bool{
  1934  		"node-0": true,
  1935  	})
  1936  	assert.NoError(t, err)
  1937  
  1938  	// the user-assigned identity is now assigned to a VMSS instance on Azure
  1939  	vmss, _ := cloudClient.testVMSSClient.Get("", "node-0")
  1940  	vmss.Identity = &compute.VirtualMachineScaleSetIdentity{
  1941  		UserAssignedIdentities: map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
  1942  			testResourceID: {},
  1943  		},
  1944  	}
  1945  	_ = cloudClient.testVMSSClient.UpdateIdentities("", "node-0", vmss)
  1946  
  1947  	currentState, desiredState, isVMSSMap, err = micClient.generateIdentityAssignmentState()
  1948  	assert.Equal(t, currentState, map[string]map[string]bool{
  1949  		"node-0": {
  1950  			testResourceID: true,
  1951  		},
  1952  	})
  1953  	assert.Equal(t, desiredState, map[string]map[string]bool{
  1954  		"node-0": {
  1955  			testResourceID: true,
  1956  		},
  1957  	})
  1958  	assert.Equal(t, isVMSSMap, map[string]bool{
  1959  		"node-0": true,
  1960  	})
  1961  	assert.NoError(t, err)
  1962  }
  1963  
  1964  func TestGenerateIdentityAssignmentDiff(t *testing.T) {
  1965  	testCases := []struct {
  1966  		currentState map[string]map[string]bool
  1967  		desiredState map[string]map[string]bool
  1968  		expectedDiff map[string][]string
  1969  	}{
  1970  		{
  1971  			currentState: map[string]map[string]bool{
  1972  				"node-0": {
  1973  					"id-0": true,
  1974  				},
  1975  			},
  1976  			desiredState: map[string]map[string]bool{
  1977  				"node-0": {
  1978  					"id-0": true,
  1979  				},
  1980  			},
  1981  			expectedDiff: map[string][]string{},
  1982  		},
  1983  		{
  1984  			currentState: map[string]map[string]bool{
  1985  				"node-1": {
  1986  					"id-1": true,
  1987  				},
  1988  			},
  1989  			desiredState: map[string]map[string]bool{
  1990  				"node-0": {
  1991  					"id-0": true,
  1992  				},
  1993  				"node-1": {
  1994  					"id-0": true,
  1995  					"id-1": true,
  1996  				},
  1997  			},
  1998  			expectedDiff: map[string][]string{
  1999  				"node-0": {
  2000  					"id-0",
  2001  				},
  2002  				"node-1": {
  2003  					"id-0",
  2004  				},
  2005  			},
  2006  		},
  2007  		{
  2008  			currentState: nil,
  2009  			desiredState: map[string]map[string]bool{
  2010  				"node-0": {
  2011  					"id-0": true,
  2012  				},
  2013  			},
  2014  			expectedDiff: map[string][]string{
  2015  				"node-0": {
  2016  					"id-0",
  2017  				},
  2018  			},
  2019  		},
  2020  		{
  2021  			currentState: map[string]map[string]bool{
  2022  				"node-0": {
  2023  					"id-0": true,
  2024  				},
  2025  			},
  2026  			desiredState: nil,
  2027  			expectedDiff: map[string][]string{},
  2028  		},
  2029  	}
  2030  
  2031  	for _, tc := range testCases {
  2032  		assert.Equal(t, tc.expectedDiff, generateIdentityAssignmentDiff(tc.currentState, tc.desiredState))
  2033  	}
  2034  }
  2035  
  2036  func findAssignedIDByName(name string, assignedIDs *[]internalaadpodid.AzureAssignedIdentity) *internalaadpodid.AzureAssignedIdentity {
  2037  	for _, assignedID := range *assignedIDs {
  2038  		if assignedID.Name == name {
  2039  			return &assignedID
  2040  		}
  2041  	}
  2042  	return nil
  2043  }