github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/cloudprovider_test.go (about)

     1  package cloudprovider
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"reflect"
     7  	"sort"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/Azure/aad-pod-identity/pkg/config"
    13  	"github.com/Azure/aad-pod-identity/pkg/retry"
    14  
    15  	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
    16  	"github.com/Azure/go-autorest/autorest/azure"
    17  	corev1 "k8s.io/api/core/v1"
    18  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    19  )
    20  
    21  func TestParseResourceID(t *testing.T) {
    22  	type testCase struct {
    23  		desc   string
    24  		testID string
    25  		expect azure.Resource
    26  		xErr   bool
    27  	}
    28  
    29  	notNested := "/subscriptions/asdf/resourceGroups/qwerty/providers/testCompute/myComputeObjectType/testComputeResource"
    30  	nested := "/subscriptions/asdf/resourceGroups/qwerty/providers/testCompute/myComputeObjectType/testComputeResource/someNestedResource/myNestedResource"
    31  
    32  	for _, c := range []testCase{
    33  		{"empty string", "", azure.Resource{}, true},
    34  		{"just a string", "asdf", azure.Resource{}, true},
    35  		{"partial match", "/subscriptions/asdf/resourceGroups/qwery", azure.Resource{}, true},
    36  		{"nested", nested, azure.Resource{
    37  			SubscriptionID: "asdf",
    38  			ResourceGroup:  "qwerty",
    39  			Provider:       "testCompute",
    40  			ResourceName:   "testComputeResource",
    41  			ResourceType:   "myComputeObjectType",
    42  		}, false},
    43  		{"not nested", notNested, azure.Resource{
    44  			SubscriptionID: "asdf",
    45  			ResourceGroup:  "qwerty",
    46  			Provider:       "testCompute",
    47  			ResourceName:   "testComputeResource",
    48  			ResourceType:   "myComputeObjectType",
    49  		}, false},
    50  	} {
    51  		t.Run(c.desc, func(t *testing.T) {
    52  			r, err := ParseResourceID(c.testID)
    53  			if (err != nil) != c.xErr {
    54  				t.Fatalf("expected err==%v, got: %v", c.xErr, err)
    55  			}
    56  			if !reflect.DeepEqual(r, c.expect) {
    57  				t.Fatalf("resource does not match expected:\nexpected:\n\t%+v\ngot:\n\t%+v", c.expect, r)
    58  			}
    59  		})
    60  	}
    61  }
    62  func TestSimple(t *testing.T) {
    63  	vmProvider := "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachines/node3"
    64  	vmssProvider := "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/node4/virtualMachines/0"
    65  
    66  	for _, cfg := range []config.AzureConfig{
    67  		{},
    68  		{VMType: "vmss"},
    69  		{VMType: "vm"},
    70  	} {
    71  		desc := cfg.VMType
    72  		if desc == "" {
    73  			desc = "default"
    74  		}
    75  		t.Run(desc, func(t *testing.T) {
    76  			cloudClient := NewTestCloudClient(cfg)
    77  
    78  			node0 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node0"}}
    79  			node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node1"}}
    80  			node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node2"}}
    81  			node3 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node3-0"}, Spec: corev1.NodeSpec{ProviderID: vmProvider}}
    82  			node4 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node4-vmss0000000"}, Spec: corev1.NodeSpec{ProviderID: vmssProvider}}
    83  
    84  			err := cloudClient.UpdateUserMSI([]string{"ID0", "ID0again"}, []string{}, node0.Name, false)
    85  			if err != nil {
    86  				t.Errorf("Couldn't update MSI: %v", err)
    87  			}
    88  			err = cloudClient.UpdateUserMSI([]string{"ID1"}, []string{}, node1.Name, false)
    89  			if err != nil {
    90  				t.Errorf("Couldn't update MSI: %v", err)
    91  			}
    92  			err = cloudClient.UpdateUserMSI([]string{"ID2"}, []string{}, node2.Name, false)
    93  			if err != nil {
    94  				t.Errorf("Couldn't update MSI: %v", err)
    95  			}
    96  			err = cloudClient.UpdateUserMSI([]string{"ID3"}, []string{}, node3.Name, false)
    97  			if err != nil {
    98  				t.Errorf("Couldn't update MSI: %v", err)
    99  			}
   100  			err = cloudClient.UpdateUserMSI([]string{"ID4"}, []string{}, node4.Name, true)
   101  			if err != nil {
   102  				t.Errorf("Couldn't update MSI: %v", err)
   103  			}
   104  
   105  			testMSI := []string{"ID0", "ID0again"}
   106  			if !cloudClient.CompareMSI(node0.Name, false, testMSI) {
   107  				cloudClient.PrintMSI(t)
   108  				t.Error("MSI mismatch")
   109  			}
   110  
   111  			err = cloudClient.UpdateUserMSI([]string{}, []string{"ID0"}, node0.Name, false)
   112  			if err != nil {
   113  				t.Errorf("Couldn't update MSI: %v", err)
   114  			}
   115  			err = cloudClient.UpdateUserMSI([]string{}, []string{"ID2"}, node2.Name, false)
   116  			if err != nil {
   117  				t.Errorf("Couldn't update MSI: %v", err)
   118  			}
   119  
   120  			testMSI = []string{"ID0again"}
   121  			if !cloudClient.CompareMSI(node0.Name, false, testMSI) {
   122  				cloudClient.PrintMSI(t)
   123  				t.Error("MSI mismatch")
   124  			}
   125  			testMSI = []string{}
   126  			if !cloudClient.CompareMSI(node2.Name, false, testMSI) {
   127  				cloudClient.PrintMSI(t)
   128  				t.Error("MSI mismatch")
   129  			}
   130  
   131  			testMSI = []string{"ID3"}
   132  			if !cloudClient.CompareMSI(node3.Name, false, testMSI) {
   133  				cloudClient.PrintMSI(t)
   134  				t.Error("MSI mismatch")
   135  			}
   136  
   137  			testMSI = []string{"ID4"}
   138  			if !cloudClient.CompareMSI(node4.Name, true, testMSI) {
   139  				cloudClient.PrintMSI(t)
   140  				t.Error("MSI mismatch")
   141  			}
   142  
   143  			// test the UpdateUserMSI interface
   144  			err = cloudClient.UpdateUserMSI([]string{"ID1", "ID2", "ID3"}, []string{"ID0again"}, node0.Name, false)
   145  			if err != nil {
   146  				t.Errorf("Couldn't update MSI: %v", err)
   147  			}
   148  
   149  			testMSI = []string{"ID1", "ID2", "ID3"}
   150  			if !cloudClient.CompareMSI(node0.Name, false, testMSI) {
   151  				cloudClient.PrintMSI(t)
   152  				t.Error("MSI mismatch")
   153  			}
   154  
   155  			err = cloudClient.UpdateUserMSI(nil, []string{"ID3"}, node3.Name, false)
   156  			if err != nil {
   157  				t.Errorf("Couldn't update MSI: %v", err)
   158  			}
   159  
   160  			testMSI = []string{}
   161  			if !cloudClient.CompareMSI(node3.Name, false, testMSI) {
   162  				cloudClient.PrintMSI(t)
   163  				t.Error("MSI mismatch")
   164  			}
   165  
   166  			err = cloudClient.UpdateUserMSI([]string{"ID3"}, nil, node4.Name, true)
   167  			if err != nil {
   168  				t.Error("Couldn't update MSI")
   169  			}
   170  
   171  			testMSI = []string{"ID4", "ID3"}
   172  			if !cloudClient.CompareMSI(node4.Name, true, testMSI) {
   173  				cloudClient.PrintMSI(t)
   174  				t.Error("MSI mismatch")
   175  			}
   176  
   177  			err = cloudClient.UpdateUserMSI([]string{"ID3"}, []string{"ID3"}, node4.Name, true)
   178  			if err != nil {
   179  				t.Errorf("Couldn't update MSI: %v", err)
   180  			}
   181  
   182  			testMSI = []string{"ID4", "ID3"}
   183  			if !cloudClient.CompareMSI(node4.Name, true, testMSI) {
   184  				cloudClient.PrintMSI(t)
   185  				t.Error("MSI mismatch")
   186  			}
   187  
   188  			// when no add or remove identities, then GET and PATCH should be skipped
   189  			err = cloudClient.UpdateUserMSI(nil, nil, node4.Name, true)
   190  			if err != nil {
   191  				t.Errorf("Couldn't update MSI: %v", err)
   192  			}
   193  
   194  			testMSI = []string{"ID4", "ID3"}
   195  			if !cloudClient.CompareMSI(node4.Name, true, testMSI) {
   196  				cloudClient.PrintMSI(t)
   197  				t.Error("MSI mismatch")
   198  			}
   199  		})
   200  	}
   201  }
   202  
   203  func TestExtractIdentitiesFromError(t *testing.T) {
   204  	testCases := []struct {
   205  		err                  error
   206  		expectedErroneousIDs []string
   207  	}{
   208  		{
   209  			err: errors.New(`on the linked scope(s) '/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1' or the linked scope(s) are invalid`),
   210  			expectedErroneousIDs: []string{
   211  				"/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1",
   212  			},
   213  		},
   214  		{
   215  			err: errors.New(`on the linked scope(s) '/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1,/subscriptions/xxxxxxxx-4321-8765-xxxx-xxxxxxxxxxxx/resourcegroups/rg-4567/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-2' or the linked scope(s) are invalid`),
   216  			expectedErroneousIDs: []string{
   217  				"/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1",
   218  				"/subscriptions/xxxxxxxx-4321-8765-xxxx-xxxxxxxxxxxx/resourcegroups/rg-4567/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-2",
   219  			},
   220  		},
   221  		{
   222  			err:                  errors.New(`error message`),
   223  			expectedErroneousIDs: []string{},
   224  		},
   225  		{
   226  			err:                  nil,
   227  			expectedErroneousIDs: []string{},
   228  		},
   229  	}
   230  
   231  	for _, tc := range testCases {
   232  		actual := extractIdentitiesFromError(tc.err)
   233  		if len(tc.expectedErroneousIDs) != len(actual) {
   234  			t.Fatalf("expected to extract %d identity, but got %d", len(tc.expectedErroneousIDs), len(actual))
   235  		}
   236  
   237  		if !isSliceEqual(actual, tc.expectedErroneousIDs) {
   238  			t.Fatalf("expected %v to be extracted from the error message, but got %v", tc.expectedErroneousIDs, actual)
   239  		}
   240  	}
   241  }
   242  
// TestCloudClient wraps a real *Client whose VM and VMSS clients are in-memory
// fakes, and keeps typed handles to those fakes so tests can inspect and
// manipulate the identities they track.
type TestCloudClient struct {
	*Client
	// testVMClient and testVMSSClient are the same fakes that NewTestCloudClient
	// installs as Client.VMClient and Client.VMSSClient; they are kept here with
	// their concrete types for test validation purposes.
	testVMClient   *TestVMClient
	testVMSSClient *TestVMSSClient
}
   249  
   250  type TestVMClient struct {
   251  	*VMClient
   252  	nodeMap map[string]*compute.VirtualMachine
   253  	nodeIDs map[string]map[string]bool
   254  	err     *error
   255  }
   256  
   257  func (c *TestVMClient) SetError(err error) {
   258  	c.err = &err
   259  }
   260  
   261  func (c *TestVMClient) UnsetError() {
   262  	c.err = nil
   263  }
   264  
   265  func (c *TestVMClient) Get(rgName string, nodeName string) (compute.VirtualMachine, error) {
   266  	stored := c.nodeMap[nodeName]
   267  	if stored == nil {
   268  		vm := new(compute.VirtualMachine)
   269  		vm.Identity = &compute.VirtualMachineIdentity{}
   270  		c.nodeMap[nodeName] = vm
   271  		c.nodeIDs[nodeName] = make(map[string]bool)
   272  		return *vm, nil
   273  	}
   274  
   275  	storedIDs := c.nodeIDs[nodeName]
   276  	newVMIdentity := make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue)
   277  	for id := range storedIDs {
   278  		newVMIdentity[id] = &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{}
   279  	}
   280  	stored.Identity.UserAssignedIdentities = newVMIdentity
   281  	return *stored, nil
   282  }
   283  
   284  func (c *TestVMClient) UpdateIdentities(rg, nodeName string, vm compute.VirtualMachine) error {
   285  	if c.err != nil {
   286  		// Only return the error once
   287  		defer c.UnsetError()
   288  		return *c.err
   289  	}
   290  
   291  	if vm.Identity != nil && vm.Identity.UserAssignedIdentities != nil {
   292  		for k, v := range vm.Identity.UserAssignedIdentities {
   293  			if v == nil {
   294  				delete(c.nodeIDs[nodeName], k)
   295  			} else {
   296  				c.nodeIDs[nodeName][k] = true
   297  			}
   298  		}
   299  	}
   300  	if vm.Identity != nil && vm.Identity.UserAssignedIdentities == nil {
   301  		for k := range c.nodeIDs[nodeName] {
   302  			delete(c.nodeIDs[nodeName], k)
   303  		}
   304  	}
   305  
   306  	c.nodeMap[nodeName] = &vm
   307  	return nil
   308  }
   309  
   310  func (c *TestVMClient) ListMSI() map[string]*[]string {
   311  	ret := make(map[string]*[]string)
   312  
   313  	for key, val := range c.nodeMap {
   314  		var ids []string
   315  		for k := range val.Identity.UserAssignedIdentities {
   316  			ids = append(ids, k)
   317  		}
   318  		ret[key] = &ids
   319  	}
   320  	return ret
   321  }
   322  
   323  func (c *TestVMClient) CompareMSI(nodeName string, expectedUserIDs []string) bool {
   324  	stored := c.nodeMap[nodeName]
   325  	if stored == nil || stored.Identity == nil {
   326  		return false
   327  	}
   328  
   329  	var actualUserIDs []string
   330  	for k := range c.nodeIDs[nodeName] {
   331  		actualUserIDs = append(actualUserIDs, k)
   332  	}
   333  	if actualUserIDs == nil {
   334  		if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone { // Validate that we have reset the resource type as none.
   335  			return true
   336  		}
   337  		return false
   338  	}
   339  
   340  	return isSliceEqual(actualUserIDs, expectedUserIDs)
   341  }
   342  
   343  type TestVMSSClient struct {
   344  	*VMSSClient
   345  	nodeMap map[string]*compute.VirtualMachineScaleSet
   346  	nodeIDs map[string]map[string]bool
   347  	err     *error
   348  }
   349  
   350  func (c *TestVMSSClient) SetError(err error) {
   351  	c.err = &err
   352  }
   353  
   354  func (c *TestVMSSClient) UnsetError() {
   355  	c.err = nil
   356  }
   357  
   358  func (c *TestVMSSClient) Get(rgName string, nodeName string) (compute.VirtualMachineScaleSet, error) {
   359  	stored := c.nodeMap[nodeName]
   360  	if stored == nil {
   361  		vmss := new(compute.VirtualMachineScaleSet)
   362  		vmss.Identity = &compute.VirtualMachineScaleSetIdentity{}
   363  		c.nodeMap[nodeName] = vmss
   364  		c.nodeIDs[nodeName] = make(map[string]bool)
   365  		return *vmss, nil
   366  	}
   367  
   368  	storedIDs := c.nodeIDs[nodeName]
   369  	newVMSSIdentity := make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue)
   370  	for id := range storedIDs {
   371  		newVMSSIdentity[id] = &compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{}
   372  	}
   373  	stored.Identity.UserAssignedIdentities = newVMSSIdentity
   374  	return *stored, nil
   375  }
   376  
   377  func (c *TestVMSSClient) UpdateIdentities(rg, nodeName string, vmss compute.VirtualMachineScaleSet) error {
   378  	if c.err != nil {
   379  		// Only return the error once
   380  		defer c.UnsetError()
   381  		return *c.err
   382  	}
   383  	if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities != nil {
   384  		for k, v := range vmss.Identity.UserAssignedIdentities {
   385  			if v == nil {
   386  				delete(c.nodeIDs[nodeName], k)
   387  			} else {
   388  				c.nodeIDs[nodeName][k] = true
   389  			}
   390  		}
   391  	}
   392  	if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities == nil {
   393  		for k := range c.nodeIDs[nodeName] {
   394  			delete(c.nodeIDs[nodeName], k)
   395  		}
   396  	}
   397  
   398  	c.nodeMap[nodeName] = &vmss
   399  	return nil
   400  }
   401  
   402  func (c *TestVMSSClient) ListMSI() map[string]*[]string {
   403  	ret := make(map[string]*[]string)
   404  
   405  	for key, val := range c.nodeMap {
   406  		var ids []string
   407  		for k := range val.Identity.UserAssignedIdentities {
   408  			ids = append(ids, k)
   409  		}
   410  		ret[key] = &ids
   411  	}
   412  	return ret
   413  }
   414  
   415  func (c *TestVMSSClient) CompareMSI(nodeName string, expectedUserIDs []string) bool {
   416  	stored := c.nodeMap[nodeName]
   417  	if stored == nil || stored.Identity == nil {
   418  		return false
   419  	}
   420  
   421  	var actualUserIDs []string
   422  	for k := range c.nodeIDs[nodeName] {
   423  		actualUserIDs = append(actualUserIDs, k)
   424  	}
   425  
   426  	if actualUserIDs == nil {
   427  		// Validate that we have reset the resource type as none.
   428  		if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone {
   429  			return true
   430  		}
   431  		return false
   432  	}
   433  
   434  	if len(actualUserIDs) != len(expectedUserIDs) {
   435  		return false
   436  	}
   437  
   438  	return isSliceEqual(actualUserIDs, expectedUserIDs)
   439  }
   440  
   441  func (c *TestCloudClient) ListMSI() map[string]*[]string {
   442  	vmssLs := c.testVMSSClient.ListMSI()
   443  	vmLs := c.testVMClient.ListMSI()
   444  
   445  	if vmssLs == nil {
   446  		return vmLs
   447  	}
   448  	if vmLs == nil {
   449  		return vmssLs
   450  	}
   451  
   452  	for k, v := range vmLs {
   453  		if v == nil {
   454  			continue
   455  		}
   456  		orig := vmssLs[k]
   457  		if orig == nil {
   458  			vmssLs[k] = v
   459  			continue
   460  		}
   461  
   462  		updated := *orig
   463  		updated = append(updated, *v...)
   464  		vmssLs[k] = &updated
   465  	}
   466  	return vmssLs
   467  }
   468  
   469  func (c *TestCloudClient) CompareMSI(name string, isvmss bool, userIDs []string) bool {
   470  	if isvmss {
   471  		return c.testVMSSClient.CompareMSI(name, userIDs)
   472  	}
   473  	return c.testVMClient.CompareMSI(name, userIDs)
   474  }
   475  
   476  func (c *TestCloudClient) PrintMSI(t *testing.T) {
   477  	t.Helper()
   478  	for key, val := range c.ListMSI() {
   479  		t.Logf("\nNode name: %s\n", key)
   480  		if val != nil {
   481  			for i, id := range *val {
   482  				t.Logf("%d) %s\n", i, id)
   483  			}
   484  		}
   485  	}
   486  }
   487  
   488  func (c *TestCloudClient) SetError(err error) {
   489  	c.testVMClient.SetError(err)
   490  	c.testVMSSClient.SetError(err)
   491  }
   492  
   493  func NewTestVMClient() *TestVMClient {
   494  	nodeMap := make(map[string]*compute.VirtualMachine)
   495  	nodeIDs := make(map[string]map[string]bool)
   496  	vmClient := &VMClient{}
   497  
   498  	return &TestVMClient{
   499  		vmClient,
   500  		nodeMap,
   501  		nodeIDs,
   502  		nil,
   503  	}
   504  }
   505  
   506  func NewTestVMSSClient() *TestVMSSClient {
   507  	nodeMap := make(map[string]*compute.VirtualMachineScaleSet)
   508  	nodeIDs := make(map[string]map[string]bool)
   509  	vmssClient := &VMSSClient{}
   510  
   511  	return &TestVMSSClient{
   512  		vmssClient,
   513  		nodeMap,
   514  		nodeIDs,
   515  		nil,
   516  	}
   517  }
   518  
   519  func NewTestCloudClient(cfg config.AzureConfig) *TestCloudClient {
   520  	vmClient := NewTestVMClient()
   521  	vmssClient := NewTestVMSSClient()
   522  	retryClient := retry.NewRetryClient(2, 0)
   523  	cloudClient := &Client{
   524  		Config:      cfg,
   525  		VMClient:    vmClient,
   526  		VMSSClient:  vmssClient,
   527  		RetryClient: retryClient,
   528  	}
   529  
   530  	return &TestCloudClient{
   531  		cloudClient,
   532  		vmClient,
   533  		vmssClient,
   534  	}
   535  }
   536  
   537  func isSliceEqual(s1, s2 []string) bool {
   538  	if len(s1) != len(s2) {
   539  		return false
   540  	}
   541  	sort.Strings(s1)
   542  	sort.Strings(s2)
   543  	for i := range s1 {
   544  		if !strings.EqualFold(s1[i], s2[i]) {
   545  			return false
   546  		}
   547  	}
   548  	return true
   549  }
   550  
   551  func TestGetRetryAfter(t *testing.T) {
   552  	cases := []struct {
   553  		desc               string
   554  		resp               *http.Response
   555  		expectedRetryAfter time.Duration
   556  	}{
   557  		{
   558  			desc:               "response is nil",
   559  			expectedRetryAfter: 0,
   560  		},
   561  		{
   562  			desc:               "no Retry-After header in the response",
   563  			resp:               &http.Response{},
   564  			expectedRetryAfter: 0,
   565  		},
   566  		{
   567  			desc:               "Retry-After in response is unknown format",
   568  			resp:               &http.Response{Header: http.Header{"Retry-After": []string{time.Now().Add(180 * time.Second).Format(time.RFC822)}}},
   569  			expectedRetryAfter: 0,
   570  		},
   571  		{
   572  			desc:               "Retry-After in response is 180",
   573  			resp:               &http.Response{Header: http.Header{"Retry-After": []string{"180"}}},
   574  			expectedRetryAfter: 3 * time.Minute,
   575  		},
   576  		{
   577  			desc:               "Retry-After in response is in RFC1123 format",
   578  			resp:               &http.Response{Header: http.Header{"Retry-After": []string{time.Now().Add(180 * time.Second).Format(time.RFC1123)}}},
   579  			expectedRetryAfter: 3 * time.Minute,
   580  		},
   581  	}
   582  
   583  	for _, tc := range cases {
   584  		t.Run(tc.desc, func(t *testing.T) {
   585  			retryAfterDuration := getRetryAfter(tc.resp)
   586  			if tc.expectedRetryAfter != retryAfterDuration.Round(time.Minute) {
   587  				t.Fatalf("expected retry after to be: %v, got: %v", tc.expectedRetryAfter, retryAfterDuration)
   588  			}
   589  		})
   590  	}
   591  }
   592  
   593  func TestGetClusterIdentity(t *testing.T) {
   594  	cases := []struct {
   595  		desc             string
   596  		config           config.AzureConfig
   597  		expectedClientID string
   598  	}{
   599  		{
   600  			desc: "cluster using service principal",
   601  			config: config.AzureConfig{
   602  				ClientID:               "clientid",
   603  				ClientSecret:           "clientsecret",
   604  				UserAssignedIdentityID: "",
   605  			},
   606  			expectedClientID: "",
   607  		},
   608  		{
   609  			desc: "cluster using system-assigned managed identity",
   610  			config: config.AzureConfig{
   611  				ClientID:               "msi",
   612  				ClientSecret:           "msi",
   613  				UserAssignedIdentityID: "",
   614  			},
   615  			expectedClientID: "",
   616  		},
   617  		{
   618  			desc: "cluster using user-assigned managed identity",
   619  			config: config.AzureConfig{
   620  				ClientID:               "msi",
   621  				ClientSecret:           "msi",
   622  				UserAssignedIdentityID: "userAssignedIdentityID",
   623  			},
   624  			expectedClientID: "userAssignedIdentityID",
   625  		},
   626  	}
   627  
   628  	for _, tc := range cases {
   629  		t.Run(tc.desc, func(t *testing.T) {
   630  			client := NewTestCloudClient(tc.config)
   631  			actualClientID := client.GetClusterIdentity()
   632  			if tc.expectedClientID != actualClientID {
   633  				t.Fatalf("expected clientID: %s, got: %s", tc.expectedClientID, actualClientID)
   634  			}
   635  		})
   636  	}
   637  }