github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/vm.go (about)

     1  package cloudprovider
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/Azure/aad-pod-identity/pkg/config"
    10  	"github.com/Azure/aad-pod-identity/pkg/metrics"
    11  	"github.com/Azure/aad-pod-identity/pkg/stats"
    12  	"github.com/Azure/aad-pod-identity/version"
    13  
    14  	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
    15  	"github.com/Azure/go-autorest/autorest"
    16  	"github.com/Azure/go-autorest/autorest/adal"
    17  	"github.com/Azure/go-autorest/autorest/azure"
    18  	"k8s.io/klog/v2"
    19  )
    20  
    21  // VMClient client for VirtualMachines
    22  type VMClient struct {
    23  	client   compute.VirtualMachinesClient
    24  	reporter *metrics.Reporter
    25  	// ARM throttling configures.
    26  	retryAfterReader time.Time
    27  	retryAfterWriter time.Time
    28  }
    29  
    30  // VMClientInt is the interface used by "cloudprovider" for interacting with Azure vmas
    31  type VMClientInt interface {
    32  	Get(rgName string, nodeName string) (compute.VirtualMachine, error)
    33  	UpdateIdentities(rg, nodeName string, vmu compute.VirtualMachine) error
    34  }
    35  
    36  // NewVirtualMachinesClient creates a new vm client.
    37  func NewVirtualMachinesClient(config config.AzureConfig, spt *adal.ServicePrincipalToken) (*VMClient, error) {
    38  	client := compute.NewVirtualMachinesClient(config.SubscriptionID)
    39  
    40  	azureEnv, err := azure.EnvironmentFromName(config.Cloud)
    41  	if err != nil {
    42  		return nil, fmt.Errorf("failed to get cloud environment, error: %+v", err)
    43  	}
    44  	client.BaseURI = azureEnv.ResourceManagerEndpoint
    45  	client.Authorizer = autorest.NewBearerAuthorizer(spt)
    46  	client.PollingDelay = 5 * time.Second
    47  	err = client.AddToUserAgent(version.GetUserAgent("MIC", version.MICVersion))
    48  	if err != nil {
    49  		return nil, fmt.Errorf("failed to add MIC to user agent, error: %+v", err)
    50  	}
    51  
    52  	reporter, err := metrics.NewReporter()
    53  	if err != nil {
    54  		return nil, fmt.Errorf("failed to create reporter for metrics, error: %+v", err)
    55  	}
    56  
    57  	return &VMClient{
    58  		client:   client,
    59  		reporter: reporter,
    60  	}, nil
    61  }
    62  
    63  // Get gets the passed in vm.
    64  func (c *VMClient) Get(rgName string, nodeName string) (_ compute.VirtualMachine, err error) {
    65  	ctx := context.Background()
    66  	begin := time.Now()
    67  	defer func() {
    68  		if err != nil {
    69  			merr := c.reporter.ReportCloudProviderOperationError(metrics.GetVMOperationName)
    70  			if merr != nil {
    71  				klog.Warningf("failed to report metrics, error: %+v", merr)
    72  			}
    73  			return
    74  		}
    75  		merr := c.reporter.ReportCloudProviderOperationDuration(metrics.GetVMOperationName, time.Since(begin))
    76  		if merr != nil {
    77  			klog.Warningf("failed to report metrics, error: %+v", merr)
    78  		}
    79  	}()
    80  
    81  	// Report errors if the client is throttled.
    82  	if c.retryAfterReader.After(time.Now()) {
    83  		return compute.VirtualMachine{}, fmt.Errorf("VMGet client throttled, retry after: %v", c.retryAfterReader)
    84  	}
    85  
    86  	vm, err := c.client.Get(ctx, rgName, nodeName, "")
    87  	if err != nil {
    88  		resp := vm.Response.Response
    89  		// Update RetryAfterReader so that no more requests would be sent until RetryAfter expires.
    90  		c.retryAfterReader = time.Now().Add(getRetryAfter(resp))
    91  		return compute.VirtualMachine{}, fmt.Errorf("failed to get vm %s in resource group %s, error: %+v", nodeName, rgName, err)
    92  	}
    93  	stats.Increment(stats.TotalGetCalls, 1)
    94  	stats.AggregateConcurrent(stats.CloudGet, begin, time.Now())
    95  	return vm, nil
    96  }
    97  
    98  // UpdateIdentities updates the user assigned identities for the provided node
    99  func (c *VMClient) UpdateIdentities(rg, nodeName string, vm compute.VirtualMachine) (err error) {
   100  	// if provisioning state is nil, we keep backward compatibility and proceed with the operation
   101  	if vm.ProvisioningState != nil && *vm.ProvisioningState == string(compute.ProvisioningStateDeleting) {
   102  		return fmt.Errorf("failed to update identities for %s in %s, vm is in '%s' provisioning state", nodeName, rg, *vm.ProvisioningState)
   103  	}
   104  
   105  	var future compute.VirtualMachinesUpdateFuture
   106  	ctx := context.Background()
   107  	begin := time.Now()
   108  	defer func() {
   109  		if err != nil {
   110  			merr := c.reporter.ReportCloudProviderOperationError(metrics.UpdateVMOperationName)
   111  			if merr != nil {
   112  				klog.Warningf("failed to report metrics, error: %+v", merr)
   113  			}
   114  			return
   115  		}
   116  		merr := c.reporter.ReportCloudProviderOperationDuration(metrics.UpdateVMOperationName, time.Since(begin))
   117  		if merr != nil {
   118  			klog.Warningf("failed to report metrics, error: %+v", merr)
   119  		}
   120  	}()
   121  
   122  	// Report errors if the client is throttled.
   123  	if c.retryAfterWriter.After(time.Now()) {
   124  		return fmt.Errorf("VMUpdate client throttled, retry after: %v", c.retryAfterWriter)
   125  	}
   126  
   127  	hasUpdated := false
   128  	remainingIDs := vm.Identity.UserAssignedIdentities
   129  	for !hasUpdated || len(remainingIDs) > 0 {
   130  		hasUpdated = true
   131  		vm.Identity.UserAssignedIdentities, remainingIDs = truncateVMIdentities(remainingIDs)
   132  		if future, err = c.client.Update(ctx, rg, nodeName, compute.VirtualMachineUpdate{Identity: vm.Identity}); err != nil {
   133  			resp := future.Response()
   134  			// Update RetryAfterWriter so that no more requests would be sent until RetryAfter expires.
   135  			c.retryAfterWriter = time.Now().Add(getRetryAfter(resp))
   136  			return fmt.Errorf("failed to update identities for %s in %s, error: %+v", nodeName, rg, err)
   137  		}
   138  		if err = future.WaitForCompletionRef(ctx, c.client.Client); err != nil {
   139  			return fmt.Errorf("failed to wait for identity update completion for vm %s in resource group %s, error: %+v", nodeName, rg, err)
   140  		}
   141  		stats.Increment(stats.TotalPatchCalls, 1)
   142  		stats.AggregateConcurrent(stats.CloudPatch, begin, time.Now())
   143  	}
   144  
   145  	return nil
   146  }
   147  
   148  type vmIdentityHolder struct {
   149  	vm *compute.VirtualMachine
   150  }
   151  
   152  func (h *vmIdentityHolder) IdentityInfo() IdentityInfo {
   153  	if h.vm.Identity == nil {
   154  		return nil
   155  	}
   156  	return &vmIdentityInfo{h.vm.Identity}
   157  }
   158  
   159  func (h *vmIdentityHolder) ResetIdentity() IdentityInfo {
   160  	h.vm.Identity = &compute.VirtualMachineIdentity{}
   161  	return h.IdentityInfo()
   162  }
   163  
   164  type vmIdentityInfo struct {
   165  	info *compute.VirtualMachineIdentity
   166  }
   167  
   168  func (i *vmIdentityInfo) GetUserIdentityList() []string {
   169  	var ids []string
   170  	if i.info == nil {
   171  		return ids
   172  	}
   173  	for id := range i.info.UserAssignedIdentities {
   174  		ids = append(ids, id)
   175  	}
   176  	return ids
   177  }
   178  
   179  func (i *vmIdentityInfo) SetUserIdentities(ids map[string]bool) bool {
   180  	if i.info.UserAssignedIdentities == nil {
   181  		i.info.UserAssignedIdentities = make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue)
   182  	}
   183  
   184  	nodeList := make(map[string]bool)
   185  	// add all current existing ids
   186  	for id := range i.info.UserAssignedIdentities {
   187  		id = strings.ToLower(id)
   188  		nodeList[id] = true
   189  	}
   190  
   191  	// add and remove the new list of identities keeping the same type as before
   192  	userAssignedIdentities := make(map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue)
   193  	for id, add := range ids {
   194  		id = strings.ToLower(id)
   195  		_, exists := nodeList[id]
   196  		// already exists on node and want to remove existing identity
   197  		if exists && !add {
   198  			userAssignedIdentities[id] = nil
   199  			delete(nodeList, id)
   200  		}
   201  		// doesn't exist on the node and want to add new identity
   202  		if !exists && add {
   203  			userAssignedIdentities[id] = &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{}
   204  			nodeList[id] = true
   205  		}
   206  		// exists and add - will already be in the nodeList and no need to patch for it
   207  		// not exists and delete - no need to patch it as it already doesn't exist
   208  	}
   209  
   210  	// all identities are the node are to be removed
   211  	if len(nodeList) == 0 {
   212  		i.info.UserAssignedIdentities = nil
   213  		if i.info.Type == compute.ResourceIdentityTypeSystemAssignedUserAssigned {
   214  			i.info.Type = compute.ResourceIdentityTypeSystemAssigned
   215  		} else {
   216  			i.info.Type = compute.ResourceIdentityTypeNone
   217  		}
   218  		return true
   219  	}
   220  
   221  	i.info.Type = getUpdatedResourceIdentityType(i.info.Type)
   222  	i.info.UserAssignedIdentities = userAssignedIdentities
   223  	return len(i.info.UserAssignedIdentities) > 0
   224  }
   225  
   226  func (i *vmIdentityInfo) RemoveUserIdentity(delID string) bool {
   227  	delID = strings.ToLower(delID)
   228  	if i.info.UserAssignedIdentities != nil {
   229  		if _, ok := i.info.UserAssignedIdentities[delID]; ok {
   230  			delete(i.info.UserAssignedIdentities, delID)
   231  			return true
   232  		}
   233  	}
   234  
   235  	return false
   236  }