github.com/Azure/aad-pod-identity@v1.8.17/pkg/auth/auth.go (about)

     1  package auth
     2  
     3  import (
     4  	"context"
     5  	"crypto/rsa"
     6  	"time"
     7  
     8  	"github.com/Azure/aad-pod-identity/pkg/metrics"
     9  	"github.com/Azure/aad-pod-identity/version"
    10  
    11  	"github.com/Azure/go-autorest/autorest/adal"
    12  	"golang.org/x/crypto/pkcs12"
    13  	"k8s.io/klog/v2"
    14  )
    15  
    16  const (
    17  	defaultActiveDirectoryEndpoint = "https://login.microsoftonline.com/"
    18  )
    19  
    20  var reporter *metrics.Reporter
    21  
    22  // GetServicePrincipalTokenFromMSI return the token for the assigned user
    23  func GetServicePrincipalTokenFromMSI(resource string) (_ *adal.Token, err error) {
    24  	begin := time.Now()
    25  	defer func() {
    26  		if err != nil {
    27  			merr := reporter.ReportIMDSOperationError(metrics.AdalTokenFromMSIOperationName)
    28  			if merr != nil {
    29  				klog.Warningf("failed to report metrics, error: %+v", merr)
    30  			}
    31  			return
    32  		}
    33  		merr := reporter.ReportIMDSOperationDuration(metrics.AdalTokenFromMSIOperationName, time.Since(begin))
    34  		if merr != nil {
    35  			klog.Warningf("failed to report metrics, error: %+v", merr)
    36  		}
    37  	}()
    38  
    39  	spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	// obtain a fresh token
    44  	err = spt.Refresh()
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	token := spt.Token()
    49  	return &token, nil
    50  }
    51  
    52  // GetServicePrincipalTokenFromMSIWithUserAssignedID return the token for the assigned user
    53  func GetServicePrincipalTokenFromMSIWithUserAssignedID(clientID, resource string) (_ *adal.Token, err error) {
    54  	begin := time.Now()
    55  	defer func() {
    56  		if err != nil {
    57  			merr := reporter.ReportIMDSOperationError(metrics.AdalTokenFromMSIWithUserAssignedIDOperationName)
    58  			if merr != nil {
    59  				klog.Warningf("failed to report metrics, error: %+v", merr)
    60  			}
    61  			return
    62  		}
    63  		merr := reporter.ReportIMDSOperationDuration(metrics.AdalTokenFromMSIWithUserAssignedIDOperationName, time.Since(begin))
    64  		if merr != nil {
    65  			klog.Warningf("failed to report metrics, error: %+v", merr)
    66  		}
    67  	}()
    68  
    69  	managedIdentityOptions := &adal.ManagedIdentityOptions{ClientID: clientID}
    70  	spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, managedIdentityOptions)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	// obtain a fresh token
    76  	err = spt.Refresh()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	token := spt.Token()
    81  	return &token, nil
    82  }
    83  
    84  // GetServicePrincipalToken return the token for the assigned user with client secret
    85  func GetServicePrincipalToken(adEndpointFromSpec, tenantID, clientID, secret, resource string, auxiliaryTenantIDs []string) (_ []*adal.Token, err error) {
    86  	begin := time.Now()
    87  	defer func() {
    88  		if err != nil {
    89  			merr := reporter.ReportIMDSOperationError(metrics.AdalTokenOperationName)
    90  			if merr != nil {
    91  				klog.Warningf("failed to report metrics, error: %+v", merr)
    92  			}
    93  			return
    94  		}
    95  		merr := reporter.ReportIMDSOperationDuration(metrics.AdalTokenOperationName, time.Since(begin))
    96  		if merr != nil {
    97  			klog.Warningf("failed to report metrics, error: %+v", merr)
    98  		}
    99  	}()
   100  
   101  	activeDirectoryEndpoint := defaultActiveDirectoryEndpoint
   102  	if adEndpointFromSpec != "" {
   103  		activeDirectoryEndpoint = adEndpointFromSpec
   104  	}
   105  
   106  	if len(auxiliaryTenantIDs) != 0 {
   107  		return newMultiTenantServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secret, resource, auxiliaryTenantIDs)
   108  	}
   109  	return newServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secret, resource)
   110  }
   111  
   112  // newServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
   113  // credentials scoped to the named resource and tenant
   114  func newServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secret, resource string) ([]*adal.Token, error) {
   115  	oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	spt, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, secret, resource)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	// obtain a fresh token
   124  	err = spt.Refresh()
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	token := spt.Token()
   129  	return []*adal.Token{&token}, nil
   130  }
   131  
   132  // newMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
   133  // the first token in the array of tokens returned is the primaryToken
   134  // all tokens [1:] are the auxiliary tokens
   135  func newMultiTenantServicePrincipalToken(activeDirectoryEndpoint, primaryTenantID, clientID, secret, resource string, auxiliaryTenantIDs []string) ([]*adal.Token, error) {
   136  	oauthConfig, err := adal.NewMultiTenantOAuthConfig(activeDirectoryEndpoint, primaryTenantID, auxiliaryTenantIDs, adal.OAuthOptions{})
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	spt, err := adal.NewMultiTenantServicePrincipalToken(oauthConfig, clientID, secret, resource)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	err = spt.RefreshWithContext(context.TODO())
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	var tokens []*adal.Token
   150  	// add primary token as the first token
   151  	primaryToken := spt.PrimaryToken.Token()
   152  	tokens = append(tokens, &primaryToken)
   153  
   154  	// add the auxiliary tokens from [1:]
   155  	for idx := range spt.AuxiliaryTokens {
   156  		auxiliaryToken := spt.AuxiliaryTokens[idx].Token()
   157  		tokens = append(tokens, &auxiliaryToken)
   158  	}
   159  	return tokens, nil
   160  }
   161  
   162  // GetServicePrincipalTokenWithCertificate return the token for the assigned user with certificate
   163  func GetServicePrincipalTokenWithCertificate(adEndpointFromSpec, tenantID, clientID string, certificate []byte, password, resource string) (_ *adal.Token, err error) {
   164  	begin := time.Now()
   165  	defer func() {
   166  		if err != nil {
   167  			merr := reporter.ReportIMDSOperationError(metrics.AdalTokenOperationName)
   168  			if merr != nil {
   169  				klog.Warningf("failed to report metrics, error: %+v", merr)
   170  			}
   171  			return
   172  		}
   173  		merr := reporter.ReportIMDSOperationDuration(metrics.AdalTokenOperationName, time.Since(begin))
   174  		if merr != nil {
   175  			klog.Warningf("failed to report metrics, error: %+v", merr)
   176  		}
   177  	}()
   178  
   179  	activeDirectoryEndpoint := defaultActiveDirectoryEndpoint
   180  	if adEndpointFromSpec != "" {
   181  		activeDirectoryEndpoint = adEndpointFromSpec
   182  	}
   183  	oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	privateKey, cert, err := pkcs12.Decode(certificate, password)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	spt, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, cert, privateKey.(*rsa.PrivateKey), resource)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	// obtain a fresh token
   198  	err = spt.Refresh()
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	token := spt.Token()
   203  	return &token, nil
   204  }
   205  
   206  func init() {
   207  	err := adal.AddToUserAgent(version.GetUserAgent("NMI", version.NMIVersion))
   208  	if err != nil {
   209  		// shouldn't fail ever
   210  		panic(err)
   211  	}
   212  }
   213  
   214  // InitReporter initialize the reporter with given reporter
   215  func InitReporter(reporterInstance *metrics.Reporter) {
   216  	reporter = reporterInstance
   217  }