github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/vm.go (about) 1 package cloudprovider 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 "time" 8 9 "github.com/Azure/aad-pod-identity/pkg/config" 10 "github.com/Azure/aad-pod-identity/pkg/metrics" 11 "github.com/Azure/aad-pod-identity/pkg/stats" 12 "github.com/Azure/aad-pod-identity/version" 13 14 "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute" 15 "github.com/Azure/go-autorest/autorest" 16 "github.com/Azure/go-autorest/autorest/adal" 17 "github.com/Azure/go-autorest/autorest/azure" 18 "k8s.io/klog/v2" 19 ) 20 21 // VMClient client for VirtualMachines 22 type VMClient struct { 23 client compute.VirtualMachinesClient 24 reporter *metrics.Reporter 25 // ARM throttling configures. 26 retryAfterReader time.Time 27 retryAfterWriter time.Time 28 } 29 30 // VMClientInt is the interface used by "cloudprovider" for interacting with Azure vmas 31 type VMClientInt interface { 32 Get(rgName string, nodeName string) (compute.VirtualMachine, error) 33 UpdateIdentities(rg, nodeName string, vmu compute.VirtualMachine) error 34 } 35 36 // NewVirtualMachinesClient creates a new vm client. 37 func NewVirtualMachinesClient(config config.AzureConfig, spt *adal.ServicePrincipalToken) (*VMClient, error) { 38 client := compute.NewVirtualMachinesClient(config.SubscriptionID) 39 40 azureEnv, err := azure.EnvironmentFromName(config.Cloud) 41 if err != nil { 42 return nil, fmt.Errorf("failed to get cloud environment, error: %+v", err) 43 } 44 client.BaseURI = azureEnv.ResourceManagerEndpoint 45 client.Authorizer = autorest.NewBearerAuthorizer(spt) 46 client.PollingDelay = 5 * time.Second 47 err = client.AddToUserAgent(version.GetUserAgent("MIC", version.MICVersion)) 48 if err != nil { 49 return nil, fmt.Errorf("failed to add MIC to user agent, error: %+v", err) 50 } 51 52 reporter, err := metrics.NewReporter() 53 if err != nil { 54 return nil, fmt.Errorf("failed to create reporter for metrics, error: %+v", err) 55 } 56 57 return &VMClient{ 58 client: client, 59 reporter: reporter, 60 }, nil 61 } 62 63 // Get gets the passed in vm. 64 func (c *VMClient) Get(rgName string, nodeName string) (_ compute.VirtualMachine, err error) { 65 ctx := context.Background() 66 begin := time.Now() 67 defer func() { 68 if err != nil { 69 merr := c.reporter.ReportCloudProviderOperationError(metrics.GetVMOperationName) 70 if merr != nil { 71 klog.Warningf("failed to report metrics, error: %+v", merr) 72 } 73 return 74 } 75 merr := c.reporter.ReportCloudProviderOperationDuration(metrics.GetVMOperationName, time.Since(begin)) 76 if merr != nil { 77 klog.Warningf("failed to report metrics, error: %+v", merr) 78 } 79 }() 80 81 // Report errors if the client is throttled. 82 if c.retryAfterReader.After(time.Now()) { 83 return compute.VirtualMachine{}, fmt.Errorf("VMGet client throttled, retry after: %v", c.retryAfterReader) 84 } 85 86 vm, err := c.client.Get(ctx, rgName, nodeName, "") 87 if err != nil { 88 resp := vm.Response.Response 89 // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. 90 c.retryAfterReader = time.Now().Add(getRetryAfter(resp)) 91 return compute.VirtualMachine{}, fmt.Errorf("failed to get vm %s in resource group %s, error: %+v", nodeName, rgName, err) 92 } 93 stats.Increment(stats.TotalGetCalls, 1) 94 stats.AggregateConcurrent(stats.CloudGet, begin, time.Now()) 95 return vm, nil 96 } 97 98 // UpdateIdentities updates the user assigned identities for the provided node 99 func (c *VMClient) UpdateIdentities(rg, nodeName string, vm compute.VirtualMachine) (err error) { 100 // if provisioning state is nil, we keep backward compatibility and proceed with the operation 101 if vm.ProvisioningState != nil && *vm.ProvisioningState == string(compute.ProvisioningStateDeleting) { 102 return fmt.Errorf("failed to update identities for %s in %s, vm is in '%s' provisioning state", nodeName, rg, *vm.ProvisioningState) 103 } 104 105 var future compute.VirtualMachinesUpdateFuture 106 ctx := context.Background() 107 begin := time.Now() 108 defer func() { 109 if err != nil { 110 merr := c.reporter.ReportCloudProviderOperationError(metrics.UpdateVMOperationName) 111 if merr != nil { 112 klog.Warningf("failed to report metrics, error: %+v", merr) 113 } 114 return 115 } 116 merr := c.reporter.ReportCloudProviderOperationDuration(metrics.UpdateVMOperationName, time.Since(begin)) 117 if merr != nil { 118 klog.Warningf("failed to report metrics, error: %+v", merr) 119 } 120 }() 121 122 // Report errors if the client is throttled. 123 if c.retryAfterWriter.After(time.Now()) { 124 return fmt.Errorf("VMUpdate client throttled, retry after: %v", c.retryAfterWriter) 125 } 126 127 hasUpdated := false 128 remainingIDs := vm.Identity.UserAssignedIdentities 129 for !hasUpdated || len(remainingIDs) > 0 { 130 hasUpdated = true 131 vm.Identity.UserAssignedIdentities, remainingIDs = truncateVMIdentities(remainingIDs) 132 if future, err = c.client.Update(ctx, rg, nodeName, compute.VirtualMachineUpdate{Identity: vm.Identity}); err != nil { 133 resp := future.Response() 134 // Update RetryAfterWriter so that no more requests would be sent until RetryAfter expires. 135 c.retryAfterWriter = time.Now().Add(getRetryAfter(resp)) 136 return fmt.Errorf("failed to update identities for %s in %s, error: %+v", nodeName, rg, err) 137 } 138 if err = future.WaitForCompletionRef(ctx, c.client.Client); err != nil { 139 return fmt.Errorf("failed to wait for identity update completion for vm %s in resource group %s, error: %+v", nodeName, rg, err) 140 } 141 stats.Increment(stats.TotalPatchCalls, 1) 142 stats.AggregateConcurrent(stats.CloudPatch, begin, time.Now()) 143 } 144 145 return nil 146 } 147 148 type vmIdentityHolder struct { 149 vm *compute.VirtualMachine 150 } 151 152 func (h *vmIdentityHolder) IdentityInfo() IdentityInfo { 153 if h.vm.Identity == nil { 154 return nil 155 } 156 return &vmIdentityInfo{h.vm.Identity} 157 } 158 159 func (h *vmIdentityHolder) ResetIdentity() IdentityInfo { 160 h.vm.Identity = &compute.VirtualMachineIdentity{} 161 return h.IdentityInfo() 162 } 163 164 type vmIdentityInfo struct { 165 info *compute.VirtualMachineIdentity 166 } 167 168 func (i *vmIdentityInfo) GetUserIdentityList() []string { 169 var ids []string 170 if i.info == nil { 171 return ids 172 } 173 for id := range i.info.UserAssignedIdentities { 174 ids = append(ids, id) 175 } 176 return ids 177 } 178 179 func (i *vmIdentityInfo) SetUserIdentities(ids map[string]bool) bool { 180 if i.info.UserAssignedIdentities == nil { 181 i.info.UserAssignedIdentities = make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue) 182 } 183 184 nodeList := make(map[string]bool) 185 // add all current existing ids 186 for id := range i.info.UserAssignedIdentities { 187 id = strings.ToLower(id) 188 nodeList[id] = true 189 } 190 191 // add and remove the new list of identities keeping the same type as before 192 userAssignedIdentities := make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue) 193 for id, add := range ids { 194 id = strings.ToLower(id) 195 _, exists := nodeList[id] 196 // already exists on node and want to remove existing identity 197 if exists && !add { 198 userAssignedIdentities[id] = nil 199 delete(nodeList, id) 200 } 201 // doesn't exist on the node and want to add new identity 202 if !exists && add { 203 userAssignedIdentities[id] = &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{} 204 nodeList[id] = true 205 } 206 // exists and add - will already be in the nodeList and no need to patch for it 207 // not exists and delete - no need to patch it as it already doesn't exist 208 } 209 210 // all identities are the node are to be removed 211 if len(nodeList) == 0 { 212 i.info.UserAssignedIdentities = nil 213 if i.info.Type == compute.ResourceIdentityTypeSystemAssignedUserAssigned { 214 i.info.Type = compute.ResourceIdentityTypeSystemAssigned 215 } else { 216 i.info.Type = compute.ResourceIdentityTypeNone 217 } 218 return true 219 } 220 221 i.info.Type = getUpdatedResourceIdentityType(i.info.Type) 222 i.info.UserAssignedIdentities = userAssignedIdentities 223 return len(i.info.UserAssignedIdentities) > 0 224 } 225 226 func (i *vmIdentityInfo) RemoveUserIdentity(delID string) bool { 227 delID = strings.ToLower(delID) 228 if i.info.UserAssignedIdentities != nil { 229 if _, ok := i.info.UserAssignedIdentities[delID]; ok { 230 delete(i.info.UserAssignedIdentities, delID) 231 return true 232 } 233 } 234 235 return false 236 }