github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/cloudprovider_test.go (about) 1 package cloudprovider 2 3 import ( 4 "errors" 5 "net/http" 6 "reflect" 7 "sort" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/Azure/aad-pod-identity/pkg/config" 13 "github.com/Azure/aad-pod-identity/pkg/retry" 14 15 "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute" 16 "github.com/Azure/go-autorest/autorest/azure" 17 corev1 "k8s.io/api/core/v1" 18 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 19 ) 20 21 func TestParseResourceID(t *testing.T) { 22 type testCase struct { 23 desc string 24 testID string 25 expect azure.Resource 26 xErr bool 27 } 28 29 notNested := "/subscriptions/asdf/resourceGroups/qwerty/providers/testCompute/myComputeObjectType/testComputeResource" 30 nested := "/subscriptions/asdf/resourceGroups/qwerty/providers/testCompute/myComputeObjectType/testComputeResource/someNestedResource/myNestedResource" 31 32 for _, c := range []testCase{ 33 {"empty string", "", azure.Resource{}, true}, 34 {"just a string", "asdf", azure.Resource{}, true}, 35 {"partial match", "/subscriptions/asdf/resourceGroups/qwery", azure.Resource{}, true}, 36 {"nested", nested, azure.Resource{ 37 SubscriptionID: "asdf", 38 ResourceGroup: "qwerty", 39 Provider: "testCompute", 40 ResourceName: "testComputeResource", 41 ResourceType: "myComputeObjectType", 42 }, false}, 43 {"not nested", notNested, azure.Resource{ 44 SubscriptionID: "asdf", 45 ResourceGroup: "qwerty", 46 Provider: "testCompute", 47 ResourceName: "testComputeResource", 48 ResourceType: "myComputeObjectType", 49 }, false}, 50 } { 51 t.Run(c.desc, func(t *testing.T) { 52 r, err := ParseResourceID(c.testID) 53 if (err != nil) != c.xErr { 54 t.Fatalf("expected err==%v, got: %v", c.xErr, err) 55 } 56 if !reflect.DeepEqual(r, c.expect) { 57 t.Fatalf("resource does not match expected:\nexpected:\n\t%+v\ngot:\n\t%+v", c.expect, r) 58 } 59 }) 60 } 61 } 62 func TestSimple(t *testing.T) { 63 vmProvider := "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachines/node3" 64 vmssProvider := "azure:///subscriptions/fakeSub/resourceGroups/fakeGroup/providers/Microsoft.Compute/virtualMachineScaleSets/node4/virtualMachines/0" 65 66 for _, cfg := range []config.AzureConfig{ 67 {}, 68 {VMType: "vmss"}, 69 {VMType: "vm"}, 70 } { 71 desc := cfg.VMType 72 if desc == "" { 73 desc = "default" 74 } 75 t.Run(desc, func(t *testing.T) { 76 cloudClient := NewTestCloudClient(cfg) 77 78 node0 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node0"}} 79 node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node1"}} 80 node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node2"}} 81 node3 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node3-0"}, Spec: corev1.NodeSpec{ProviderID: vmProvider}} 82 node4 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node4-vmss0000000"}, Spec: corev1.NodeSpec{ProviderID: vmssProvider}} 83 84 err := cloudClient.UpdateUserMSI([]string{"ID0", "ID0again"}, []string{}, node0.Name, false) 85 if err != nil { 86 t.Errorf("Couldn't update MSI: %v", err) 87 } 88 err = cloudClient.UpdateUserMSI([]string{"ID1"}, []string{}, node1.Name, false) 89 if err != nil { 90 t.Errorf("Couldn't update MSI: %v", err) 91 } 92 err = cloudClient.UpdateUserMSI([]string{"ID2"}, []string{}, node2.Name, false) 93 if err != nil { 94 t.Errorf("Couldn't update MSI: %v", err) 95 } 96 err = cloudClient.UpdateUserMSI([]string{"ID3"}, []string{}, node3.Name, false) 97 if err != nil { 98 t.Errorf("Couldn't update MSI: %v", err) 99 } 100 err = cloudClient.UpdateUserMSI([]string{"ID4"}, []string{}, node4.Name, true) 101 if err != nil { 102 t.Errorf("Couldn't update MSI: %v", err) 103 } 104 105 testMSI := []string{"ID0", "ID0again"} 106 if !cloudClient.CompareMSI(node0.Name, false, testMSI) { 107 cloudClient.PrintMSI(t) 108 t.Error("MSI mismatch") 109 } 110 111 err = cloudClient.UpdateUserMSI([]string{}, []string{"ID0"}, node0.Name, false) 112 if err != nil { 113 t.Errorf("Couldn't update MSI: %v", err) 114 } 115 err = cloudClient.UpdateUserMSI([]string{}, []string{"ID2"}, node2.Name, false) 116 if err != nil { 117 t.Errorf("Couldn't update MSI: %v", err) 118 } 119 120 testMSI = []string{"ID0again"} 121 if !cloudClient.CompareMSI(node0.Name, false, testMSI) { 122 cloudClient.PrintMSI(t) 123 t.Error("MSI mismatch") 124 } 125 testMSI = []string{} 126 if !cloudClient.CompareMSI(node2.Name, false, testMSI) { 127 cloudClient.PrintMSI(t) 128 t.Error("MSI mismatch") 129 } 130 131 testMSI = []string{"ID3"} 132 if !cloudClient.CompareMSI(node3.Name, false, testMSI) { 133 cloudClient.PrintMSI(t) 134 t.Error("MSI mismatch") 135 } 136 137 testMSI = []string{"ID4"} 138 if !cloudClient.CompareMSI(node4.Name, true, testMSI) { 139 cloudClient.PrintMSI(t) 140 t.Error("MSI mismatch") 141 } 142 143 // test the UpdateUserMSI interface 144 err = cloudClient.UpdateUserMSI([]string{"ID1", "ID2", "ID3"}, []string{"ID0again"}, node0.Name, false) 145 if err != nil { 146 t.Errorf("Couldn't update MSI: %v", err) 147 } 148 149 testMSI = []string{"ID1", "ID2", "ID3"} 150 if !cloudClient.CompareMSI(node0.Name, false, testMSI) { 151 cloudClient.PrintMSI(t) 152 t.Error("MSI mismatch") 153 } 154 155 err = cloudClient.UpdateUserMSI(nil, []string{"ID3"}, node3.Name, false) 156 if err != nil { 157 t.Errorf("Couldn't update MSI: %v", err) 158 } 159 160 testMSI = []string{} 161 if !cloudClient.CompareMSI(node3.Name, false, testMSI) { 162 cloudClient.PrintMSI(t) 163 t.Error("MSI mismatch") 164 } 165 166 err = cloudClient.UpdateUserMSI([]string{"ID3"}, nil, node4.Name, true) 167 if err != nil { 168 t.Error("Couldn't update MSI") 169 } 170 171 testMSI = []string{"ID4", "ID3"} 172 if !cloudClient.CompareMSI(node4.Name, true, testMSI) { 173 cloudClient.PrintMSI(t) 174 t.Error("MSI mismatch") 175 } 176 177 err = cloudClient.UpdateUserMSI([]string{"ID3"}, []string{"ID3"}, node4.Name, true) 178 if err != nil { 179 t.Errorf("Couldn't update MSI: %v", err) 180 } 181 182 testMSI = []string{"ID4", "ID3"} 183 if !cloudClient.CompareMSI(node4.Name, true, testMSI) { 184 cloudClient.PrintMSI(t) 185 t.Error("MSI mismatch") 186 } 187 188 // when no add or remove identities, then GET and PATCH should be skipped 189 err = cloudClient.UpdateUserMSI(nil, nil, node4.Name, true) 190 if err != nil { 191 t.Errorf("Couldn't update MSI: %v", err) 192 } 193 194 testMSI = []string{"ID4", "ID3"} 195 if !cloudClient.CompareMSI(node4.Name, true, testMSI) { 196 cloudClient.PrintMSI(t) 197 t.Error("MSI mismatch") 198 } 199 }) 200 } 201 } 202 203 func TestExtractIdentitiesFromError(t *testing.T) { 204 testCases := []struct { 205 err error 206 expectedErroneousIDs []string 207 }{ 208 { 209 err: errors.New(`on the linked scope(s) '/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1' or the linked scope(s) are invalid`), 210 expectedErroneousIDs: []string{ 211 "/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1", 212 }, 213 }, 214 { 215 err: errors.New(`on the linked scope(s) '/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1,/subscriptions/xxxxxxxx-4321-8765-xxxx-xxxxxxxxxxxx/resourcegroups/rg-4567/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-2' or the linked scope(s) are invalid`), 216 expectedErroneousIDs: []string{ 217 "/subscriptions/xxxxxxxx-1234-5678-xxxx-xxxxxxxxxxxx/resourcegroups/rg-1234/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-1", 218 "/subscriptions/xxxxxxxx-4321-8765-xxxx-xxxxxxxxxxxx/resourcegroups/rg-4567/providers/Microsoft.ManagedIdentity/userAssignedIdentities/user-id-2", 219 }, 220 }, 221 { 222 err: errors.New(`error message`), 223 expectedErroneousIDs: []string{}, 224 }, 225 { 226 err: nil, 227 expectedErroneousIDs: []string{}, 228 }, 229 } 230 231 for _, tc := range testCases { 232 actual := extractIdentitiesFromError(tc.err) 233 if len(tc.expectedErroneousIDs) != len(actual) { 234 t.Fatalf("expected to extract %d identity, but got %d", len(tc.expectedErroneousIDs), len(actual)) 235 } 236 237 if !isSliceEqual(actual, tc.expectedErroneousIDs) { 238 t.Fatalf("expected %v to be extracted from the error message, but got %v", tc.expectedErroneousIDs, actual) 239 } 240 } 241 } 242 243 type TestCloudClient struct { 244 *Client 245 // testVMClient is test validation purpose. 246 testVMClient *TestVMClient 247 testVMSSClient *TestVMSSClient 248 } 249 250 type TestVMClient struct { 251 *VMClient 252 nodeMap map[string]*compute.VirtualMachine 253 nodeIDs map[string]map[string]bool 254 err *error 255 } 256 257 func (c *TestVMClient) SetError(err error) { 258 c.err = &err 259 } 260 261 func (c *TestVMClient) UnsetError() { 262 c.err = nil 263 } 264 265 func (c *TestVMClient) Get(rgName string, nodeName string) (compute.VirtualMachine, error) { 266 stored := c.nodeMap[nodeName] 267 if stored == nil { 268 vm := new(compute.VirtualMachine) 269 vm.Identity = &compute.VirtualMachineIdentity{} 270 c.nodeMap[nodeName] = vm 271 c.nodeIDs[nodeName] = make(map[string]bool) 272 return *vm, nil 273 } 274 275 storedIDs := c.nodeIDs[nodeName] 276 newVMIdentity := make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue) 277 for id := range storedIDs { 278 newVMIdentity[id] = &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{} 279 } 280 stored.Identity.UserAssignedIdentities = newVMIdentity 281 return *stored, nil 282 } 283 284 func (c *TestVMClient) UpdateIdentities(rg, nodeName string, vm compute.VirtualMachine) error { 285 if c.err != nil { 286 // Only return the error once 287 defer c.UnsetError() 288 return *c.err 289 } 290 291 if vm.Identity != nil && vm.Identity.UserAssignedIdentities != nil { 292 for k, v := range vm.Identity.UserAssignedIdentities { 293 if v == nil { 294 delete(c.nodeIDs[nodeName], k) 295 } else { 296 c.nodeIDs[nodeName][k] = true 297 } 298 } 299 } 300 if vm.Identity != nil && vm.Identity.UserAssignedIdentities == nil { 301 for k := range c.nodeIDs[nodeName] { 302 delete(c.nodeIDs[nodeName], k) 303 } 304 } 305 306 c.nodeMap[nodeName] = &vm 307 return nil 308 } 309 310 func (c *TestVMClient) ListMSI() map[string]*[]string { 311 ret := make(map[string]*[]string) 312 313 for key, val := range c.nodeMap { 314 var ids []string 315 for k := range val.Identity.UserAssignedIdentities { 316 ids = append(ids, k) 317 } 318 ret[key] = &ids 319 } 320 return ret 321 } 322 323 func (c *TestVMClient) CompareMSI(nodeName string, expectedUserIDs []string) bool { 324 stored := c.nodeMap[nodeName] 325 if stored == nil || stored.Identity == nil { 326 return false 327 } 328 329 var actualUserIDs []string 330 for k := range c.nodeIDs[nodeName] { 331 actualUserIDs = append(actualUserIDs, k) 332 } 333 if actualUserIDs == nil { 334 if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone { // Validate that we have reset the resource type as none. 335 return true 336 } 337 return false 338 } 339 340 return isSliceEqual(actualUserIDs, expectedUserIDs) 341 } 342 343 type TestVMSSClient struct { 344 *VMSSClient 345 nodeMap map[string]*compute.VirtualMachineScaleSet 346 nodeIDs map[string]map[string]bool 347 err *error 348 } 349 350 func (c *TestVMSSClient) SetError(err error) { 351 c.err = &err 352 } 353 354 func (c *TestVMSSClient) UnsetError() { 355 c.err = nil 356 } 357 358 func (c *TestVMSSClient) Get(rgName string, nodeName string) (compute.VirtualMachineScaleSet, error) { 359 stored := c.nodeMap[nodeName] 360 if stored == nil { 361 vmss := new(compute.VirtualMachineScaleSet) 362 vmss.Identity = &compute.VirtualMachineScaleSetIdentity{} 363 c.nodeMap[nodeName] = vmss 364 c.nodeIDs[nodeName] = make(map[string]bool) 365 return *vmss, nil 366 } 367 368 storedIDs := c.nodeIDs[nodeName] 369 newVMSSIdentity := make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue) 370 for id := range storedIDs { 371 newVMSSIdentity[id] = &compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{} 372 } 373 stored.Identity.UserAssignedIdentities = newVMSSIdentity 374 return *stored, nil 375 } 376 377 func (c *TestVMSSClient) UpdateIdentities(rg, nodeName string, vmss compute.VirtualMachineScaleSet) error { 378 if c.err != nil { 379 // Only return the error once 380 defer c.UnsetError() 381 return *c.err 382 } 383 if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities != nil { 384 for k, v := range vmss.Identity.UserAssignedIdentities { 385 if v == nil { 386 delete(c.nodeIDs[nodeName], k) 387 } else { 388 c.nodeIDs[nodeName][k] = true 389 } 390 } 391 } 392 if vmss.Identity != nil && vmss.Identity.UserAssignedIdentities == nil { 393 for k := range c.nodeIDs[nodeName] { 394 delete(c.nodeIDs[nodeName], k) 395 } 396 } 397 398 c.nodeMap[nodeName] = &vmss 399 return nil 400 } 401 402 func (c *TestVMSSClient) ListMSI() map[string]*[]string { 403 ret := make(map[string]*[]string) 404 405 for key, val := range c.nodeMap { 406 var ids []string 407 for k := range val.Identity.UserAssignedIdentities { 408 ids = append(ids, k) 409 } 410 ret[key] = &ids 411 } 412 return ret 413 } 414 415 func (c *TestVMSSClient) CompareMSI(nodeName string, expectedUserIDs []string) bool { 416 stored := c.nodeMap[nodeName] 417 if stored == nil || stored.Identity == nil { 418 return false 419 } 420 421 var actualUserIDs []string 422 for k := range c.nodeIDs[nodeName] { 423 actualUserIDs = append(actualUserIDs, k) 424 } 425 426 if actualUserIDs == nil { 427 // Validate that we have reset the resource type as none. 428 if len(expectedUserIDs) == 0 && stored.Identity.Type == compute.ResourceIdentityTypeNone { 429 return true 430 } 431 return false 432 } 433 434 if len(actualUserIDs) != len(expectedUserIDs) { 435 return false 436 } 437 438 return isSliceEqual(actualUserIDs, expectedUserIDs) 439 } 440 441 func (c *TestCloudClient) ListMSI() map[string]*[]string { 442 vmssLs := c.testVMSSClient.ListMSI() 443 vmLs := c.testVMClient.ListMSI() 444 445 if vmssLs == nil { 446 return vmLs 447 } 448 if vmLs == nil { 449 return vmssLs 450 } 451 452 for k, v := range vmLs { 453 if v == nil { 454 continue 455 } 456 orig := vmssLs[k] 457 if orig == nil { 458 vmssLs[k] = v 459 continue 460 } 461 462 updated := *orig 463 updated = append(updated, *v...) 464 vmssLs[k] = &updated 465 } 466 return vmssLs 467 } 468 469 func (c *TestCloudClient) CompareMSI(name string, isvmss bool, userIDs []string) bool { 470 if isvmss { 471 return c.testVMSSClient.CompareMSI(name, userIDs) 472 } 473 return c.testVMClient.CompareMSI(name, userIDs) 474 } 475 476 func (c *TestCloudClient) PrintMSI(t *testing.T) { 477 t.Helper() 478 for key, val := range c.ListMSI() { 479 t.Logf("\nNode name: %s\n", key) 480 if val != nil { 481 for i, id := range *val { 482 t.Logf("%d) %s\n", i, id) 483 } 484 } 485 } 486 } 487 488 func (c *TestCloudClient) SetError(err error) { 489 c.testVMClient.SetError(err) 490 c.testVMSSClient.SetError(err) 491 } 492 493 func NewTestVMClient() *TestVMClient { 494 nodeMap := make(map[string]*compute.VirtualMachine) 495 nodeIDs := make(map[string]map[string]bool) 496 vmClient := &VMClient{} 497 498 return &TestVMClient{ 499 vmClient, 500 nodeMap, 501 nodeIDs, 502 nil, 503 } 504 } 505 506 func NewTestVMSSClient() *TestVMSSClient { 507 nodeMap := make(map[string]*compute.VirtualMachineScaleSet) 508 nodeIDs := make(map[string]map[string]bool) 509 vmssClient := &VMSSClient{} 510 511 return &TestVMSSClient{ 512 vmssClient, 513 nodeMap, 514 nodeIDs, 515 nil, 516 } 517 } 518 519 func NewTestCloudClient(cfg config.AzureConfig) *TestCloudClient { 520 vmClient := NewTestVMClient() 521 vmssClient := NewTestVMSSClient() 522 retryClient := retry.NewRetryClient(2, 0) 523 cloudClient := &Client{ 524 Config: cfg, 525 VMClient: vmClient, 526 VMSSClient: vmssClient, 527 RetryClient: retryClient, 528 } 529 530 return &TestCloudClient{ 531 cloudClient, 532 vmClient, 533 vmssClient, 534 } 535 } 536 537 func isSliceEqual(s1, s2 []string) bool { 538 if len(s1) != len(s2) { 539 return false 540 } 541 sort.Strings(s1) 542 sort.Strings(s2) 543 for i := range s1 { 544 if !strings.EqualFold(s1[i], s2[i]) { 545 return false 546 } 547 } 548 return true 549 } 550 551 func TestGetRetryAfter(t *testing.T) { 552 cases := []struct { 553 desc string 554 resp *http.Response 555 expectedRetryAfter time.Duration 556 }{ 557 { 558 desc: "response is nil", 559 expectedRetryAfter: 0, 560 }, 561 { 562 desc: "no Retry-After header in the response", 563 resp: &http.Response{}, 564 expectedRetryAfter: 0, 565 }, 566 { 567 desc: "Retry-After in response is unknown format", 568 resp: &http.Response{Header: http.Header{"Retry-After": []string{time.Now().Add(180 * time.Second).Format(time.RFC822)}}}, 569 expectedRetryAfter: 0, 570 }, 571 { 572 desc: "Retry-After in response is 180", 573 resp: &http.Response{Header: http.Header{"Retry-After": []string{"180"}}}, 574 expectedRetryAfter: 3 * time.Minute, 575 }, 576 { 577 desc: "Retry-After in response is in RFC1123 format", 578 resp: &http.Response{Header: http.Header{"Retry-After": []string{time.Now().Add(180 * time.Second).Format(time.RFC1123)}}}, 579 expectedRetryAfter: 3 * time.Minute, 580 }, 581 } 582 583 for _, tc := range cases { 584 t.Run(tc.desc, func(t *testing.T) { 585 retryAfterDuration := getRetryAfter(tc.resp) 586 if tc.expectedRetryAfter != retryAfterDuration.Round(time.Minute) { 587 t.Fatalf("expected retry after to be: %v, got: %v", tc.expectedRetryAfter, retryAfterDuration) 588 } 589 }) 590 } 591 } 592 593 func TestGetClusterIdentity(t *testing.T) { 594 cases := []struct { 595 desc string 596 config config.AzureConfig 597 expectedClientID string 598 }{ 599 { 600 desc: "cluster using service principal", 601 config: config.AzureConfig{ 602 ClientID: "clientid", 603 ClientSecret: "clientsecret", 604 UserAssignedIdentityID: "", 605 }, 606 expectedClientID: "", 607 }, 608 { 609 desc: "cluster using system-assigned managed identity", 610 config: config.AzureConfig{ 611 ClientID: "msi", 612 ClientSecret: "msi", 613 UserAssignedIdentityID: "", 614 }, 615 expectedClientID: "", 616 }, 617 { 618 desc: "cluster using user-assigned managed identity", 619 config: config.AzureConfig{ 620 ClientID: "msi", 621 ClientSecret: "msi", 622 UserAssignedIdentityID: "userAssignedIdentityID", 623 }, 624 expectedClientID: "userAssignedIdentityID", 625 }, 626 } 627 628 for _, tc := range cases { 629 t.Run(tc.desc, func(t *testing.T) { 630 client := NewTestCloudClient(tc.config) 631 actualClientID := client.GetClusterIdentity() 632 if tc.expectedClientID != actualClientID { 633 t.Fatalf("expected clientID: %s, got: %s", tc.expectedClientID, actualClientID) 634 } 635 }) 636 } 637 }