github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/actas/actas.go (about)

     1  // Package actas provides a TokenSource that returns access tokens that impersonate
     2  // a different service account other than the default app credentials.
     3  package actas
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"strings"
    13  	"time"
    14  
    15  	log "github.com/golang/glog"
    16  	"golang.org/x/oauth2"
    17  	"google.golang.org/grpc/credentials"
    18  )
    19  
    20  const (
    21  	// See https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts/signJwt
    22  	// for details on signJWT call.
    23  	// See https://cloud.google.com/endpoints/docs/openapi/service-account-authentication
    24  	// for details on authenticating as a service account, including the grantType for
    25  	// authenticating using a signed JWT.
    26  	signJWTURLTemplate = "https://iam.googleapis.com/v1/projects/-/serviceAccounts/%s:signJwt"
    27  	audience           = "https://www.googleapis.com/oauth2/v4/token"
    28  	grantType          = "urn:ietf:params:oauth:grant-type:jwt-bearer"
    29  
    30  	// expiryWiggleRoom is the number of seconds to subtract from the expiry time to give a bit of
    31  	// wiggle room for token refreshing.
    32  	expiryWiggleRoom = 120 * time.Second
    33  )
    34  
    35  // signJwtURL is the url string for the signJwt call.  The expected payload is the JWT to sign.
    36  var newSignJWTURL = func(account string) string {
    37  	return fmt.Sprintf(signJWTURLTemplate, account)
    38  }
    39  
    40  var audienceURL = audience
    41  
// TokenSource is an oauth2.TokenSource implementation that provides impersonated credentials.
//
// Each Token call authenticates to the IAM signJwt endpoint with the request
// metadata obtained from cred (for example, application default credentials),
// has that endpoint sign a JWT asserting actAsAccount, and then exchanges the
// signed JWT for an access token. This means whatever account cred represents
// needs to be a member in the ServiceAccountTokenCreator role for the
// impersonated service account.
type TokenSource struct {
	// ctx is the context used for obtaining tokens. It is stored because
	// Token takes no context argument.
	ctx context.Context

	// actAsAccount is the account that tokens are obtained for. It is used as
	// the "iss" claim and to build the signJwt URL.
	actAsAccount string

	// cred supplies the request metadata (headers) that are attached to the
	// signJwt and token-exchange HTTP requests.
	cred credentials.PerRPCCredentials

	// scopes is the list of scopes for the tokens; they are space-joined into
	// the "scope" claim.
	scopes []string

	// httpClient is the http client used to obtain tokens.
	httpClient *http.Client
}
    63  
    64  // NewTokenSource returns a impersonated credentials token source.
    65  func NewTokenSource(ctx context.Context, cred credentials.PerRPCCredentials, client *http.Client, actAsAccount string, scopes []string) *TokenSource {
    66  	return &TokenSource{
    67  		ctx:          ctx,
    68  		actAsAccount: actAsAccount,
    69  		cred:         cred,
    70  		scopes:       scopes,
    71  		httpClient:   client,
    72  	}
    73  }
    74  
    75  // Token returns an authorization token for the impersonated service account.
    76  func (s *TokenSource) Token() (*oauth2.Token, error) {
    77  	log.Infof("Generating new act-as token.")
    78  
    79  	authHeaders, err := s.cred.GetRequestMetadata(s.ctx)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	log.V(1).Infof("Obtained credentials request metadata: %+v", authHeaders)
    85  
    86  	signature, err := s.getSignedJWT(authHeaders)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	// Next do the access token request, using the JWT signed as the impersonated SA
    92  	token, err := s.getToken(authHeaders, signature)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return token, nil
    98  }
    99  
// claims contains the set of JWT claims needed for the signJWT call.
// Field order determines the key order of the marshaled JSON.
type claims struct {
	Scope string `json:"scope"` // space-separated scopes
	Iss   string `json:"iss"`   // the impersonated service account
	Aud   string `json:"aud"`   // the token endpoint (audienceURL)
	Iat   int64  `json:"iat"`   // issued-at time in Unix seconds
}
   107  
   108  // newClaims constructs and encode the jwt claims.
   109  func (s *TokenSource) newClaims() (string, error) {
   110  	c := claims{
   111  		Scope: strings.Join(s.scopes, " "),
   112  		Aud:   audienceURL,
   113  		Iss:   s.actAsAccount,
   114  		Iat:   time.Now().Unix(),
   115  	}
   116  	j, err := json.Marshal(c)
   117  	if err != nil {
   118  		return "", err
   119  	}
   120  	claimsEncoded := strings.Replace(string(j), "\"", "\\\"", -1)
   121  	claimsPayload := "{\"payload\": \"" + claimsEncoded + "\"}"
   122  	return claimsPayload, nil
   123  }
   124  
// signaturePayload is the structure of the returned payload of a successful signJWT call.
// SignedJwt is sent as the "assertion" of the access token request, so it is
// a credential in its own right. KeyID is only used for logging here
// (NOTE(review): presumably it identifies the signing key — confirm).
type signaturePayload struct {
	KeyID     string `json:"keyId"`
	SignedJwt string `json:"signedJwt"`
}
   130  
// signatureError is the structure of the returned payload of a failed signJWT call.
// Error.Status and Error.Message are reported in the error returned by
// getSignedJWT; Error.Code is decoded but not read by the code shown here.
type signatureError struct {
	Error struct {
		Code    int64  `json:"code"`
		Message string `json:"message"`
		Status  string `json:"status"`
	} `json:"error"`
}
   139  
   140  func (s *TokenSource) getSignedJWT(headers map[string]string) (*signaturePayload, error) {
   141  	claims, err := s.newClaims()
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	// Construct the signJWT request and send it.
   147  	req, err := http.NewRequest("POST", newSignJWTURL(s.actAsAccount), strings.NewReader(claims))
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	// Copy the authHeaders from the default credentials into the request headers.
   152  	for k, v := range headers {
   153  		req.Header.Add(k, v)
   154  	}
   155  
   156  	log.V(1).Infof("HTTP request to signJWT: %+v", req)
   157  
   158  	// Execute the call to signJWT
   159  	resp, err := s.httpClient.Do(req)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	defer resp.Body.Close()
   164  
   165  	log.V(1).Infof("HTTP response from signJWT: %+v", resp)
   166  
   167  	// Extract the signedJWT from the response body.
   168  	signatureBody, err := io.ReadAll(resp.Body)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	if !isOK(resp.StatusCode) {
   173  		var errResp signatureError
   174  		err = json.Unmarshal(signatureBody, &errResp)
   175  		if err != nil {
   176  			return nil, fmt.Errorf("signJWT call failed with http code %v, unable to parse error payload", resp.StatusCode)
   177  		}
   178  		return nil, fmt.Errorf("signJWT call failed with http code %v, error status %v, message %v", resp.StatusCode, errResp.Error.Status, errResp.Error.Message)
   179  	}
   180  
   181  	payload := &signaturePayload{}
   182  	if err := json.Unmarshal(signatureBody, payload); err != nil {
   183  		return nil, fmt.Errorf("failed to parse sign jwt payload %q: %v", string(signatureBody), err)
   184  	}
   185  
   186  	log.V(1).Infof("Payload: %+v", payload)
   187  
   188  	return payload, nil
   189  }
   190  
// tokenPayload is the structure of the returned payload of the access token call when it
// returns successfully. AccessToken and TokenType are copied into the
// resulting oauth2.Token; ExpiresIn is the token lifetime in seconds, from
// which the token's Expiry is computed.
type tokenPayload struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int64  `json:"expires_in"`
}
   198  
// tokenError is the structure of the returned payload of the access token call when there is
// an error. Both fields are included in the error returned by getToken.
type tokenError struct {
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description"`
}
   205  
   206  func (s TokenSource) getToken(headers map[string]string, signature *signaturePayload) (*oauth2.Token, error) {
   207  	accessForm := url.Values{}
   208  	accessForm.Add("grant_type", grantType)
   209  	accessForm.Add("assertion", signature.SignedJwt)
   210  	req, err := http.NewRequest("POST", audienceURL, strings.NewReader(accessForm.Encode()))
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	// Copy the authHeaders into the new request.
   215  	for k, v := range headers {
   216  		req.Header.Add(k, v)
   217  	}
   218  	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   219  
   220  	log.V(1).Infof("HTTP request: %+v", req)
   221  
   222  	resp, err := s.httpClient.Do(req)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	defer resp.Body.Close()
   227  
   228  	log.V(1).Infof("HTTP response: %+v", resp)
   229  
   230  	tokenBody, err := io.ReadAll(resp.Body)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	if !isOK(resp.StatusCode) {
   235  		var errResp tokenError
   236  		err = json.Unmarshal(tokenBody, &errResp)
   237  		if err != nil {
   238  			return nil, fmt.Errorf("failed to unmarshal access token error response: %v", err)
   239  		}
   240  		return nil, fmt.Errorf("access token call failed with status code %d and error %s, %s", resp.StatusCode, errResp.Error, errResp.ErrorDescription)
   241  	}
   242  
   243  	payload := &tokenPayload{}
   244  	err = json.Unmarshal(tokenBody, payload)
   245  	if err != nil {
   246  		return nil, fmt.Errorf("failed to parse access token payload %q: %v", string(tokenBody), err)
   247  	}
   248  
   249  	log.V(1).Infof("Payload: %+v", payload)
   250  
   251  	token := &oauth2.Token{
   252  		AccessToken: payload.AccessToken,
   253  		TokenType:   payload.TokenType,
   254  		Expiry:      time.Now().Add(time.Duration(payload.ExpiresIn)*time.Second - expiryWiggleRoom),
   255  	}
   256  	return token, nil
   257  }
   258  
   259  func isOK(httpStatusCode int) bool {
   260  	return 200 <= httpStatusCode && httpStatusCode < 300
   261  }