github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/actas/actas.go (about) 1 // Package actas provides a TokenSource that returns access tokens that impersonate 2 // a different service account other than the default app credentials. 3 package actas 4 5 import ( 6 "context" 7 "encoding/json" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "strings" 13 "time" 14 15 log "github.com/golang/glog" 16 "golang.org/x/oauth2" 17 "google.golang.org/grpc/credentials" 18 ) 19 20 const ( 21 // See https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts/signJwt 22 // for details on signJWT call. 23 // See https://cloud.google.com/endpoints/docs/openapi/service-account-authentication 24 // for details on authenticating as a service account, including the grantType for 25 // authenticating using a signed JWT. 26 signJWTURLTemplate = "https://iam.googleapis.com/v1/projects/-/serviceAccounts/%s:signJwt" 27 audience = "https://www.googleapis.com/oauth2/v4/token" 28 grantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" 29 30 // expiryWiggleRoom is the number of seconds to subtract from the expiry time to give a bit of 31 // wiggle room for token refreshing. 32 expiryWiggleRoom = 120 * time.Second 33 ) 34 35 // signJwtURL is the url string for the signJwt call. The expected payload is the JWT to sign. 36 var newSignJWTURL = func(account string) string { 37 return fmt.Sprintf(signJWTURLTemplate, account) 38 } 39 40 var audienceURL = audience 41 42 // TokenSource is an oauth2.TokenSource implementation that provides impersonated credentials. 43 // The current implementation uses the application default credentials to 44 // sign for the impersonated credentials. This means whatever service account 45 // is running the code needs to be a member in the ServiceAccountTokenCreator 46 // role for the impersonated service account. 47 type TokenSource struct { 48 // ctx is the context used for obtaining tokens. 49 ctx context.Context 50 51 // actAsAccount is the account that tokens are obtained for. 52 actAsAccount string 53 54 // cred is the credentials used. 55 cred credentials.PerRPCCredentials 56 57 // scopes is the list of scopes for the tokens. 58 scopes []string 59 60 // httpClient is the http client used to obtain tokens. 61 httpClient *http.Client 62 } 63 64 // NewTokenSource returns a impersonated credentials token source. 65 func NewTokenSource(ctx context.Context, cred credentials.PerRPCCredentials, client *http.Client, actAsAccount string, scopes []string) *TokenSource { 66 return &TokenSource{ 67 ctx: ctx, 68 actAsAccount: actAsAccount, 69 cred: cred, 70 scopes: scopes, 71 httpClient: client, 72 } 73 } 74 75 // Token returns an authorization token for the impersonated service account. 76 func (s *TokenSource) Token() (*oauth2.Token, error) { 77 log.Infof("Generating new act-as token.") 78 79 authHeaders, err := s.cred.GetRequestMetadata(s.ctx) 80 if err != nil { 81 return nil, err 82 } 83 84 log.V(1).Infof("Obtained credentials request metadata: %+v", authHeaders) 85 86 signature, err := s.getSignedJWT(authHeaders) 87 if err != nil { 88 return nil, err 89 } 90 91 // Next do the access token request, using the JWT signed as the impersonated SA 92 token, err := s.getToken(authHeaders, signature) 93 if err != nil { 94 return nil, err 95 } 96 97 return token, nil 98 } 99 100 // claims contains the set of JWT claims needed to for the signJWT call. 101 type claims struct { 102 Scope string `json:"scope"` 103 Iss string `json:"iss"` 104 Aud string `json:"aud"` 105 Iat int64 `json:"iat"` 106 } 107 108 // newClaims constructs and encode the jwt claims. 109 func (s *TokenSource) newClaims() (string, error) { 110 c := claims{ 111 Scope: strings.Join(s.scopes, " "), 112 Aud: audienceURL, 113 Iss: s.actAsAccount, 114 Iat: time.Now().Unix(), 115 } 116 j, err := json.Marshal(c) 117 if err != nil { 118 return "", err 119 } 120 claimsEncoded := strings.Replace(string(j), "\"", "\\\"", -1) 121 claimsPayload := "{\"payload\": \"" + claimsEncoded + "\"}" 122 return claimsPayload, nil 123 } 124 125 // signaturePayload is the structure of the returned payload of a successful signJWT call. 126 type signaturePayload struct { 127 KeyID string `json:"keyId"` 128 SignedJwt string `json:"signedJwt"` 129 } 130 131 // signatureError is the structure of the returned payload of a failed signJWT call. 132 type signatureError struct { 133 Error struct { 134 Code int64 `json:"code"` 135 Message string `json:"message"` 136 Status string `json:"status"` 137 } `json:"error"` 138 } 139 140 func (s *TokenSource) getSignedJWT(headers map[string]string) (*signaturePayload, error) { 141 claims, err := s.newClaims() 142 if err != nil { 143 return nil, err 144 } 145 146 // Construct the signJWT request and send it. 147 req, err := http.NewRequest("POST", newSignJWTURL(s.actAsAccount), strings.NewReader(claims)) 148 if err != nil { 149 return nil, err 150 } 151 // Copy the authHeaders from the default credentials into the request headers. 152 for k, v := range headers { 153 req.Header.Add(k, v) 154 } 155 156 log.V(1).Infof("HTTP request to signJWT: %+v", req) 157 158 // Execute the call to signJWT 159 resp, err := s.httpClient.Do(req) 160 if err != nil { 161 return nil, err 162 } 163 defer resp.Body.Close() 164 165 log.V(1).Infof("HTTP response from signJWT: %+v", resp) 166 167 // Extract the signedJWT from the response body. 168 signatureBody, err := io.ReadAll(resp.Body) 169 if err != nil { 170 return nil, err 171 } 172 if !isOK(resp.StatusCode) { 173 var errResp signatureError 174 err = json.Unmarshal(signatureBody, &errResp) 175 if err != nil { 176 return nil, fmt.Errorf("signJWT call failed with http code %v, unable to parse error payload", resp.StatusCode) 177 } 178 return nil, fmt.Errorf("signJWT call failed with http code %v, error status %v, message %v", resp.StatusCode, errResp.Error.Status, errResp.Error.Message) 179 } 180 181 payload := &signaturePayload{} 182 if err := json.Unmarshal(signatureBody, payload); err != nil { 183 return nil, fmt.Errorf("failed to parse sign jwt payload %q: %v", string(signatureBody), err) 184 } 185 186 log.V(1).Infof("Payload: %+v", payload) 187 188 return payload, nil 189 } 190 191 // tokenPayload is the structure of the returned payload of the access token call when it 192 // returns successfully. 193 type tokenPayload struct { 194 AccessToken string `json:"access_token"` 195 TokenType string `json:"token_type"` 196 ExpiresIn int64 `json:"expires_in"` 197 } 198 199 // tokenError is the structure of the returned payload of the access token call when there is 200 // an error. 201 type tokenError struct { 202 Error string `json:"error"` 203 ErrorDescription string `json:"error_description"` 204 } 205 206 func (s TokenSource) getToken(headers map[string]string, signature *signaturePayload) (*oauth2.Token, error) { 207 accessForm := url.Values{} 208 accessForm.Add("grant_type", grantType) 209 accessForm.Add("assertion", signature.SignedJwt) 210 req, err := http.NewRequest("POST", audienceURL, strings.NewReader(accessForm.Encode())) 211 if err != nil { 212 return nil, err 213 } 214 // Copy the authHeaders into the new request. 215 for k, v := range headers { 216 req.Header.Add(k, v) 217 } 218 req.Header.Add("Content-Type", "application/x-www-form-urlencoded") 219 220 log.V(1).Infof("HTTP request: %+v", req) 221 222 resp, err := s.httpClient.Do(req) 223 if err != nil { 224 return nil, err 225 } 226 defer resp.Body.Close() 227 228 log.V(1).Infof("HTTP response: %+v", resp) 229 230 tokenBody, err := io.ReadAll(resp.Body) 231 if err != nil { 232 return nil, err 233 } 234 if !isOK(resp.StatusCode) { 235 var errResp tokenError 236 err = json.Unmarshal(tokenBody, &errResp) 237 if err != nil { 238 return nil, fmt.Errorf("failed to unmarshal access token error response: %v", err) 239 } 240 return nil, fmt.Errorf("access token call failed with status code %d and error %s, %s", resp.StatusCode, errResp.Error, errResp.ErrorDescription) 241 } 242 243 payload := &tokenPayload{} 244 err = json.Unmarshal(tokenBody, payload) 245 if err != nil { 246 return nil, fmt.Errorf("failed to parse access token payload %q: %v", string(tokenBody), err) 247 } 248 249 log.V(1).Infof("Payload: %+v", payload) 250 251 token := &oauth2.Token{ 252 AccessToken: payload.AccessToken, 253 TokenType: payload.TokenType, 254 Expiry: time.Now().Add(time.Duration(payload.ExpiresIn)*time.Second - expiryWiggleRoom), 255 } 256 return token, nil 257 } 258 259 func isOK(httpStatusCode int) bool { 260 return 200 <= httpStatusCode && httpStatusCode < 300 261 }