github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/vmss.go (about)

     1  package cloudprovider
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/Azure/aad-pod-identity/pkg/config"
    10  	"github.com/Azure/aad-pod-identity/pkg/metrics"
    11  	"github.com/Azure/aad-pod-identity/pkg/stats"
    12  	"github.com/Azure/aad-pod-identity/version"
    13  
    14  	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
    15  	"github.com/Azure/go-autorest/autorest"
    16  	"github.com/Azure/go-autorest/autorest/adal"
    17  	"github.com/Azure/go-autorest/autorest/azure"
    18  	"k8s.io/klog/v2"
    19  )
    20  
    21  // VMSSClient is used to interact with Azure virtual machine scale sets.
    22  type VMSSClient struct {
    23  	client   compute.VirtualMachineScaleSetsClient
    24  	reporter *metrics.Reporter
    25  	// ARM throttling configures.
    26  	retryAfterReader time.Time
    27  	retryAfterWriter time.Time
    28  }
    29  
    30  // VMSSClientInt is the interface used by "cloudprovider" for interacting with Azure vmss
    31  type VMSSClientInt interface {
    32  	Get(rgName, name string) (compute.VirtualMachineScaleSet, error)
    33  	UpdateIdentities(rg, vmssName string, vmu compute.VirtualMachineScaleSet) error
    34  }
    35  
    36  // NewVMSSClient creates a new vmss client.
    37  func NewVMSSClient(config config.AzureConfig, spt *adal.ServicePrincipalToken) (*VMSSClient, error) {
    38  	client := compute.NewVirtualMachineScaleSetsClient(config.SubscriptionID)
    39  
    40  	azureEnv, err := azure.EnvironmentFromName(config.Cloud)
    41  	if err != nil {
    42  		return nil, fmt.Errorf("failed to get cloud environment, error: %+v", err)
    43  	}
    44  	client.BaseURI = azureEnv.ResourceManagerEndpoint
    45  	client.Authorizer = autorest.NewBearerAuthorizer(spt)
    46  	client.PollingDelay = 5 * time.Second
    47  	err = client.AddToUserAgent(version.GetUserAgent("MIC", version.MICVersion))
    48  	if err != nil {
    49  		return nil, fmt.Errorf("failed to add MIC to user agent, error: %+v", err)
    50  	}
    51  
    52  	reporter, err := metrics.NewReporter()
    53  	if err != nil {
    54  		return nil, fmt.Errorf("failed to create reporter for metrics, error: %+v", err)
    55  	}
    56  
    57  	return &VMSSClient{
    58  		client:   client,
    59  		reporter: reporter,
    60  	}, nil
    61  }
    62  
    63  // UpdateIdentities updates the user assigned identities for the provided node
    64  func (c *VMSSClient) UpdateIdentities(rg, vmssName string, vmss compute.VirtualMachineScaleSet) (err error) {
    65  	// if provisioning state is nil, we keep backward compatibility and proceed with the operation
    66  	if vmss.ProvisioningState != nil && *vmss.ProvisioningState == string(compute.ProvisioningStateDeleting) {
    67  		return fmt.Errorf("failed to update identities for %s in %s, vmss is in %s provisioning state", vmssName, rg, *vmss.ProvisioningState)
    68  	}
    69  
    70  	var future compute.VirtualMachineScaleSetsUpdateFuture
    71  	ctx := context.Background()
    72  	begin := time.Now()
    73  	defer func() {
    74  		if err != nil {
    75  			merr := c.reporter.ReportCloudProviderOperationError(metrics.UpdateVMSSOperationName)
    76  			if merr != nil {
    77  				klog.Warningf("failed to report metrics, error: %+v", merr)
    78  			}
    79  			return
    80  		}
    81  		merr := c.reporter.ReportCloudProviderOperationDuration(metrics.UpdateVMSSOperationName, time.Since(begin))
    82  		if merr != nil {
    83  			klog.Warningf("failed to report metrics, error: %+v", merr)
    84  		}
    85  	}()
    86  
    87  	// Report errors if the client is throttled.
    88  	if c.retryAfterWriter.After(time.Now()) {
    89  		return fmt.Errorf("VMSSUpdate client throttled, retry after: %v", c.retryAfterWriter)
    90  	}
    91  
    92  	hasUpdated := false
    93  	remainingIDs := vmss.Identity.UserAssignedIdentities
    94  	for !hasUpdated || len(remainingIDs) > 0 {
    95  		hasUpdated = true
    96  		vmss.Identity.UserAssignedIdentities, remainingIDs = truncateVMSSIdentities(remainingIDs)
    97  		if future, err = c.client.Update(ctx, rg, vmssName, compute.VirtualMachineScaleSetUpdate{Identity: vmss.Identity}); err != nil {
    98  			resp := future.Response()
    99  			// Update RetryAfterWriter so that no more requests would be sent until RetryAfter expires.
   100  			c.retryAfterWriter = time.Now().Add(getRetryAfter(resp))
   101  			return fmt.Errorf("failed to update identities for %s in %s, error: %+v", vmssName, rg, err)
   102  		}
   103  		if err = future.WaitForCompletionRef(ctx, c.client.Client); err != nil {
   104  			return fmt.Errorf("failed to wait for identity update completion for vmss %s in resource group %s, error: %+v", vmssName, rg, err)
   105  		}
   106  		stats.Increment(stats.TotalPatchCalls, 1)
   107  		stats.AggregateConcurrent(stats.CloudPatch, begin, time.Now())
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  // Get gets the passed in vmss.
   114  func (c *VMSSClient) Get(rgName string, vmssName string) (_ compute.VirtualMachineScaleSet, err error) {
   115  	ctx := context.Background()
   116  	begin := time.Now()
   117  	defer func() {
   118  		if err != nil {
   119  			merr := c.reporter.ReportCloudProviderOperationError(metrics.GetVmssOperationName)
   120  			if merr != nil {
   121  				klog.Warningf("failed to report metrics, error: %+v", merr)
   122  			}
   123  			return
   124  		}
   125  		merr := c.reporter.ReportCloudProviderOperationDuration(metrics.GetVmssOperationName, time.Since(begin))
   126  		if merr != nil {
   127  			klog.Warningf("failed to report metrics, error: %+v", merr)
   128  		}
   129  	}()
   130  
   131  	// Report errors if the client is throttled.
   132  	if c.retryAfterReader.After(time.Now()) {
   133  		return compute.VirtualMachineScaleSet{}, fmt.Errorf("VMSSGet client throttled, retry after: %v", c.retryAfterReader)
   134  	}
   135  
   136  	vmss, err := c.client.Get(ctx, rgName, vmssName)
   137  	if err != nil {
   138  		resp := vmss.Response.Response
   139  		// Update RetryAfterReader so that no more requests would be sent until RetryAfter expires.
   140  		c.retryAfterReader = time.Now().Add(getRetryAfter(resp))
   141  		return compute.VirtualMachineScaleSet{}, fmt.Errorf("failed to get vmss %s in resource group %s, error: %+v", vmssName, rgName, err)
   142  	}
   143  	stats.Increment(stats.TotalGetCalls, 1)
   144  	stats.AggregateConcurrent(stats.CloudGet, begin, time.Now())
   145  	return vmss, nil
   146  }
   147  
   148  // vmssIdentityHolder implements `IdentityHolder` for vmss resources.
   149  type vmssIdentityHolder struct {
   150  	vmss *compute.VirtualMachineScaleSet
   151  }
   152  
   153  func (h *vmssIdentityHolder) IdentityInfo() IdentityInfo {
   154  	if h.vmss.Identity == nil {
   155  		return nil
   156  	}
   157  	return &vmssIdentityInfo{h.vmss.Identity}
   158  }
   159  
   160  func (h *vmssIdentityHolder) ResetIdentity() IdentityInfo {
   161  	h.vmss.Identity = &compute.VirtualMachineScaleSetIdentity{
   162  		Type:                   compute.ResourceIdentityTypeUserAssigned,
   163  		UserAssignedIdentities: make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue),
   164  	}
   165  	return h.IdentityInfo()
   166  }
   167  
   168  type vmssIdentityInfo struct {
   169  	info *compute.VirtualMachineScaleSetIdentity
   170  }
   171  
   172  func (i *vmssIdentityInfo) GetUserIdentityList() []string {
   173  	var ids []string
   174  	if i.info == nil {
   175  		return ids
   176  	}
   177  	for id := range i.info.UserAssignedIdentities {
   178  		ids = append(ids, id)
   179  	}
   180  	return ids
   181  }
   182  
   183  func (i *vmssIdentityInfo) SetUserIdentities(ids map[string]bool) bool {
   184  	if i.info.UserAssignedIdentities == nil {
   185  		i.info.UserAssignedIdentities = make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue)
   186  	}
   187  
   188  	nodeList := make(map[string]bool)
   189  	// add all current existing ids
   190  	for id := range i.info.UserAssignedIdentities {
   191  		id = strings.ToLower(id)
   192  		nodeList[id] = true
   193  	}
   194  
   195  	// add and remove the new list of identities keeping the same type as before
   196  	userAssignedIdentities := make(map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue)
   197  	for id, add := range ids {
   198  		id = strings.ToLower(id)
   199  		_, exists := nodeList[id]
   200  		// already exists on node and want to remove existing identity
   201  		if exists && !add {
   202  			userAssignedIdentities[id] = nil
   203  			delete(nodeList, id)
   204  		}
   205  		// doesn't exist on the node and want to add new identity
   206  		if !exists && add {
   207  			userAssignedIdentities[id] = &compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{}
   208  			nodeList[id] = true
   209  		}
   210  		// exists and add - will already be in the nodeList and no need to patch for it
   211  		// not exists and delete - no need to patch it as it already doesn't exist
   212  	}
   213  
   214  	// all identities are the node are to be removed
   215  	if len(nodeList) == 0 {
   216  		i.info.UserAssignedIdentities = nil
   217  		if i.info.Type == compute.ResourceIdentityTypeSystemAssignedUserAssigned {
   218  			i.info.Type = compute.ResourceIdentityTypeSystemAssigned
   219  		} else {
   220  			i.info.Type = compute.ResourceIdentityTypeNone
   221  		}
   222  		return true
   223  	}
   224  
   225  	i.info.Type = getUpdatedResourceIdentityType(i.info.Type)
   226  	i.info.UserAssignedIdentities = userAssignedIdentities
   227  	return len(i.info.UserAssignedIdentities) > 0
   228  }
   229  
   230  func (i *vmssIdentityInfo) RemoveUserIdentity(delID string) bool {
   231  	delID = strings.ToLower(delID)
   232  	if i.info.UserAssignedIdentities != nil {
   233  		if _, ok := i.info.UserAssignedIdentities[delID]; ok {
   234  			delete(i.info.UserAssignedIdentities, delID)
   235  			return true
   236  		}
   237  	}
   238  
   239  	return false
   240  }