github.com/Azure/aad-pod-identity@v1.8.17/pkg/cloudprovider/cloudprovider.go (about)

     1  package cloudprovider
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"os"
     7  	"path"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/Azure/aad-pod-identity/pkg/config"
    14  	"github.com/Azure/aad-pod-identity/pkg/retry"
    15  	"github.com/Azure/aad-pod-identity/pkg/utils"
    16  	"github.com/Azure/aad-pod-identity/version"
    17  
    18  	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
    19  	"github.com/Azure/go-autorest/autorest"
    20  	"github.com/Azure/go-autorest/autorest/adal"
    21  	"github.com/Azure/go-autorest/autorest/azure"
    22  	yaml "gopkg.in/yaml.v2"
    23  	"k8s.io/klog/v2"
    24  )
    25  
    26  // Client is a cloud provider client
    27  type Client struct {
    28  	VMClient    VMClientInt
    29  	VMSSClient  VMSSClientInt
    30  	RetryClient retry.ClientInt
    31  	ExtClient   compute.VirtualMachineExtensionsClient
    32  	Config      config.AzureConfig
    33  	configFile  string
    34  }
    35  
    36  // ClientInt client interface
    37  type ClientInt interface {
    38  	UpdateUserMSI(addUserAssignedMSIIDs, removeUserAssignedMSIIDs []string, name string, isvmss bool) error
    39  	GetUserMSIs(name string, isvmss bool) ([]string, error)
    40  	Init() error
    41  }
    42  
    43  const (
    44  	// Occurs when the cluster service principal / managed identity does not
    45  	// have the correct role assignment to access a user-assigned identity.
    46  	linkedAuthorizationFailed retry.RetriableError = "LinkedAuthorizationFailed"
    47  	// Occurs when the user-assigned identity does not exist.
    48  	failedIdentityOperation retry.RetriableError = "FailedIdentityOperation"
    49  	// retryAfterHeaderKey is the retry-after header key in ARM responses.
    50  	retryAfterHeaderKey = "Retry-After"
    51  )
    52  
    53  // NewCloudProvider returns a azure cloud provider client
    54  func NewCloudProvider(configFile string, updateUserMSIMaxRetry int, updateUseMSIRetryInterval time.Duration) (*Client, error) {
    55  	client := &Client{
    56  		configFile: configFile,
    57  	}
    58  	if err := client.Init(); err != nil {
    59  		return nil, fmt.Errorf("failed to initialize cloud provider client, error: %+v", err)
    60  	}
    61  	client.RetryClient = retry.NewRetryClient(updateUserMSIMaxRetry, updateUseMSIRetryInterval)
    62  	client.RetryClient.RegisterRetriableErrors(linkedAuthorizationFailed, failedIdentityOperation)
    63  	return client, nil
    64  }
    65  
    66  // Init initializes the cloud provider client based
    67  // on a config path or environment variables
    68  func (c *Client) Init() error {
    69  	c.Config = config.AzureConfig{}
    70  	if c.configFile != "" {
    71  		klog.V(6).Info("populating AzureConfig from azure.json")
    72  		bytes, err := os.ReadFile(c.configFile)
    73  		if err != nil {
    74  			return fmt.Errorf("failed to config file %s, error: %+v", c.configFile, err)
    75  		}
    76  		if err = yaml.Unmarshal(bytes, &c.Config); err != nil {
    77  			return fmt.Errorf("failed to unmarshal JSON, error: %+v", err)
    78  		}
    79  	} else {
    80  		klog.V(6).Info("populating AzureConfig from secret/environment variables")
    81  		c.Config.Cloud = os.Getenv("CLOUD")
    82  		c.Config.TenantID = os.Getenv("TENANT_ID")
    83  		c.Config.ClientID = os.Getenv("CLIENT_ID")
    84  		c.Config.ClientSecret = os.Getenv("CLIENT_SECRET")
    85  		c.Config.SubscriptionID = os.Getenv("SUBSCRIPTION_ID")
    86  		c.Config.ResourceGroupName = os.Getenv("RESOURCE_GROUP")
    87  		c.Config.VMType = os.Getenv("VM_TYPE")
    88  		c.Config.UseManagedIdentityExtension = strings.EqualFold(os.Getenv("USE_MSI"), "True")
    89  		c.Config.UserAssignedIdentityID = os.Getenv("USER_ASSIGNED_MSI_CLIENT_ID")
    90  	}
    91  
    92  	azureEnv, err := azure.EnvironmentFromName(c.Config.Cloud)
    93  	if err != nil {
    94  		return fmt.Errorf("failed to get cloud environment, error: %+v", err)
    95  	}
    96  
    97  	err = adal.AddToUserAgent(version.GetUserAgent("MIC", version.MICVersion))
    98  	if err != nil {
    99  		return fmt.Errorf("failed to add MIC to user agent, error: %+v", err)
   100  	}
   101  
   102  	oauthConfig, err := adal.NewOAuthConfig(azureEnv.ActiveDirectoryEndpoint, c.Config.TenantID)
   103  	if err != nil {
   104  		return fmt.Errorf("failed to create OAuth config, error: %+v", err)
   105  	}
   106  
   107  	var spt *adal.ServicePrincipalToken
   108  	if c.Config.UseManagedIdentityExtension {
   109  		// UserAssignedIdentityID is empty, so we are going to use system assigned MSI
   110  		if c.Config.UserAssignedIdentityID == "" {
   111  			klog.Infof("MIC using system assigned identity for authentication.")
   112  			spt, err = adal.NewServicePrincipalTokenFromMSI("", azureEnv.ResourceManagerEndpoint)
   113  			if err != nil {
   114  				return fmt.Errorf("failed to get token from system-assigned identity, error: %+v", err)
   115  			}
   116  		} else { // User assigned identity usage.
   117  			klog.Infof("MIC using user assigned identity: %s for authentication.", utils.RedactClientID(c.Config.UserAssignedIdentityID))
   118  			spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID("", azureEnv.ResourceManagerEndpoint, c.Config.UserAssignedIdentityID)
   119  			if err != nil {
   120  				return fmt.Errorf("failed to get token from user-assigned identity, error: %+v", err)
   121  			}
   122  		}
   123  	} else { // This is the default scenario - use service principal to get the token.
   124  		spt, err = adal.NewServicePrincipalToken(
   125  			*oauthConfig,
   126  			c.Config.ClientID,
   127  			c.Config.ClientSecret,
   128  			azureEnv.ResourceManagerEndpoint,
   129  		)
   130  		if err != nil {
   131  			return fmt.Errorf("failed to get service principal token, error: %+v", err)
   132  		}
   133  	}
   134  
   135  	extClient := compute.NewVirtualMachineExtensionsClient(c.Config.SubscriptionID)
   136  	extClient.BaseURI = azure.PublicCloud.ResourceManagerEndpoint
   137  	extClient.Authorizer = autorest.NewBearerAuthorizer(spt)
   138  	extClient.PollingDelay = 5 * time.Second
   139  
   140  	c.VMSSClient, err = NewVMSSClient(c.Config, spt)
   141  	if err != nil {
   142  		return fmt.Errorf("failed to create VMSS client, error: %+v", err)
   143  	}
   144  	c.VMClient, err = NewVirtualMachinesClient(c.Config, spt)
   145  	if err != nil {
   146  		return fmt.Errorf("failed to create VM client, error: %+v", err)
   147  	}
   148  
   149  	// We explicitly removes http.StatusTooManyRequests from autorest.StatusCodesForRetry.
   150  	// Refer https://github.com/Azure/go-autorest/issues/398.
   151  	statusCodesForRetry := make([]int, 0)
   152  	for _, code := range autorest.StatusCodesForRetry {
   153  		if code != http.StatusTooManyRequests {
   154  			statusCodesForRetry = append(statusCodesForRetry, code)
   155  		}
   156  	}
   157  	autorest.StatusCodesForRetry = statusCodesForRetry
   158  
   159  	return nil
   160  }
   161  
   162  // GetUserMSIs will return a list of all identities on the node or vmss based on value of isvmss
   163  func (c *Client) GetUserMSIs(name string, isvmss bool) ([]string, error) {
   164  	idH, _, err := c.getIdentityResource(name, isvmss)
   165  	if err != nil {
   166  		return nil, fmt.Errorf("failed to get identity resource, error: %v", err)
   167  	}
   168  	info := idH.IdentityInfo()
   169  	if info == nil {
   170  		return []string{}, nil
   171  	}
   172  	idList := info.GetUserIdentityList()
   173  	return idList, nil
   174  }
   175  
   176  // UpdateUserMSI will batch process the removal and addition of ids
   177  func (c *Client) UpdateUserMSI(addUserAssignedMSIIDs, removeUserAssignedMSIIDs []string, name string, isvmss bool) error {
   178  	// if there are no identities to be assigned and un-assigned, then we should not
   179  	// invoke an additional GET or PATCH request.
   180  	if len(addUserAssignedMSIIDs) == 0 && len(removeUserAssignedMSIIDs) == 0 {
   181  		klog.Infof("No identities to assign or un-assign")
   182  		return nil
   183  	}
   184  	idH, updateFunc, err := c.getIdentityResource(name, isvmss)
   185  	if err != nil {
   186  		return fmt.Errorf("failed to get identity resource, error: %v", err)
   187  	}
   188  
   189  	info := idH.IdentityInfo()
   190  	if info == nil {
   191  		info = idH.ResetIdentity()
   192  	}
   193  
   194  	ids := make(map[string]bool)
   195  	// remove msi ids from the list
   196  	for _, userAssignedMSIID := range removeUserAssignedMSIIDs {
   197  		ids[userAssignedMSIID] = false
   198  	}
   199  	// add new ids to the list
   200  	// add is done after setting del ids in the map to ensure an identity if in
   201  	// both add and del list is not deleted
   202  	for _, userAssignedMSIID := range addUserAssignedMSIIDs {
   203  		ids[userAssignedMSIID] = true
   204  	}
   205  
   206  	if requiresUpdate := info.SetUserIdentities(ids); !requiresUpdate {
   207  		return nil
   208  	}
   209  
   210  	klog.Infof("updating user-assigned identities on %s, assign [%d], unassign [%d]", name, len(addUserAssignedMSIIDs), len(removeUserAssignedMSIIDs))
   211  	timeStarted := time.Now()
   212  	shouldRetry := func(err error) bool {
   213  		if err == nil {
   214  			return false
   215  		}
   216  
   217  		// Filter previously-assigned IDs based on which identities
   218  		// are erroneous from the last occurred error
   219  		erroneousIDs := extractIdentitiesFromError(err)
   220  		removedAny := false
   221  		for _, erroneousID := range erroneousIDs {
   222  			if removed := info.RemoveUserIdentity(erroneousID); removed {
   223  				removedAny = true
   224  				klog.Infof("removing %s from ID list since it is erroneous", erroneousID)
   225  			}
   226  		}
   227  
   228  		// Only retry if there is at least one ID after deleting
   229  		remainingIDs := info.GetUserIdentityList()
   230  		if removedAny && len(remainingIDs) > 0 {
   231  			klog.Infof("attempting to retry with ID list %v", remainingIDs)
   232  			return true
   233  		}
   234  
   235  		return false
   236  	}
   237  	if err := c.RetryClient.Do(updateFunc, shouldRetry); err != nil {
   238  		return err
   239  	}
   240  
   241  	klog.V(6).Infof("UpdateUserMSI of %s completed in %s", name, time.Since(timeStarted))
   242  
   243  	return nil
   244  }
   245  
   246  func (c *Client) getIdentityResource(name string, isvmss bool) (IdentityHolder, func() error, error) {
   247  	rg := c.Config.ResourceGroupName
   248  
   249  	if isvmss {
   250  		vmss, err := c.VMSSClient.Get(rg, name)
   251  		if err != nil {
   252  			return nil, nil, fmt.Errorf("failed to get vmss %s in resource group %s, error: %+v", name, rg, err)
   253  		}
   254  
   255  		update := func() error {
   256  			return c.VMSSClient.UpdateIdentities(rg, name, vmss)
   257  		}
   258  		idH := &vmssIdentityHolder{&vmss}
   259  		return idH, update, nil
   260  	}
   261  
   262  	vm, err := c.VMClient.Get(rg, name)
   263  	if err != nil {
   264  		return nil, nil, fmt.Errorf("failed to get vm %s in resource group %s, error: %+v", name, rg, err)
   265  	}
   266  	update := func() error {
   267  		return c.VMClient.UpdateIdentities(rg, name, vm)
   268  	}
   269  	idH := &vmIdentityHolder{&vm}
   270  	return idH, update, nil
   271  }
   272  
   273  const nestedResourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+?)/(.+)`
   274  const resourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+)`
   275  
   276  var (
   277  	nestedResourceIDPattern = regexp.MustCompile(nestedResourceIDPatternText)
   278  	resourceIDPattern       = regexp.MustCompile(resourceIDPatternText)
   279  )
   280  
   281  const (
   282  	// VMResourceType virtual machine resource type
   283  	VMResourceType = "virtualMachines"
   284  	// VMSSResourceType virtual machine scale sets resource type
   285  	VMSSResourceType = "virtualMachineScaleSets"
   286  )
   287  
   288  // ParseResourceID is a slightly modified version of https://github.com/Azure/go-autorest/blob/528b76fd0ebec0682f3e3da7c808cd472b999615/autorest/azure/azure.go#L175
   289  // The modification here is to support a nested resource such as is the case for a node resource in a vmss.
   290  func ParseResourceID(resourceID string) (azure.Resource, error) {
   291  	match := nestedResourceIDPattern.FindStringSubmatch(resourceID)
   292  	if len(match) == 0 {
   293  		match = resourceIDPattern.FindStringSubmatch(resourceID)
   294  	}
   295  
   296  	if len(match) < 6 {
   297  		return azure.Resource{}, fmt.Errorf("failed to parse %s: invalid resource ID format", resourceID)
   298  	}
   299  
   300  	result := azure.Resource{
   301  		SubscriptionID: match[1],
   302  		ResourceGroup:  match[2],
   303  		Provider:       match[3],
   304  		ResourceType:   match[4],
   305  		ResourceName:   path.Base(match[5]),
   306  	}
   307  
   308  	return result, nil
   309  }
   310  
   311  const (
   312  	// This matches identity resource IDs on an error message from ARM
   313  	userAssignedIdentitiesPatternText = `'(,?(?i)/subscriptions/[a-zA-Z0-9-_]+/resourcegroups/[a-zA-Z0-9-_]+/providers/Microsoft.ManagedIdentity/userAssignedIdentities/[a-zA-Z0-9-_]+)+'`
   314  )
   315  
   316  var (
   317  	userAssignedIdentitiesPattern = regexp.MustCompile(userAssignedIdentitiesPatternText)
   318  )
   319  
   320  func extractIdentitiesFromError(err error) []string {
   321  	var extracted []string
   322  	if err == nil {
   323  		return extracted
   324  	}
   325  
   326  	matches := userAssignedIdentitiesPattern.FindStringSubmatch(err.Error())
   327  	if len(matches) == 0 {
   328  		return extracted
   329  	}
   330  
   331  	match := matches[0]
   332  	// Remove leading and trailing single quotes
   333  	match = match[1 : len(match)-1]
   334  
   335  	for _, id := range strings.Split(match, ",") {
   336  		// Sanity check
   337  		if err := utils.ValidateResourceID(id); err != nil {
   338  			klog.Errorf("failed to validate %s, error: %+v", id, err)
   339  			continue
   340  		}
   341  		extracted = append(extracted, id)
   342  	}
   343  
   344  	return extracted
   345  }
   346  
   347  // getRetryAfter gets the retryAfter from http response.
   348  // The value of Retry-After can be either the number of seconds or a date in RFC1123 format.
   349  func getRetryAfter(resp *http.Response) time.Duration {
   350  	if resp == nil {
   351  		return 0
   352  	}
   353  
   354  	ra := resp.Header.Get(retryAfterHeaderKey)
   355  	if ra == "" {
   356  		return 0
   357  	}
   358  
   359  	var dur time.Duration
   360  	if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
   361  		dur = time.Duration(retryAfter) * time.Second
   362  	} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
   363  		dur = time.Until(t)
   364  	}
   365  	return dur
   366  }
   367  
   368  // GetClusterIdentity returns the cluster identity that MIC will use for all
   369  // cloud provider operations. This is userAssignedIdentityID configured in the azure.json
   370  func (c *Client) GetClusterIdentity() string {
   371  	return c.Config.UserAssignedIdentityID
   372  }