// Copyright 2019 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package vmtoken implements parsing and verification of signed GCE VM metadata
// tokens.
//
// See https://cloud.google.com/compute/docs/instances/verifying-instance-identity
//
// Intended to be used from a server environment (e.g. from a GAE), since it
// depends on a bunch of luci/server packages that require a properly configured
// context.
package vmtoken

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/common/retry/transient"
	"go.chromium.org/luci/gae/service/info"
	"go.chromium.org/luci/server/auth/signing"
	"go.chromium.org/luci/server/router"
	"go.chromium.org/luci/server/warmup"
)

// Header is the name of the HTTP header where the GCE VM metadata token is
// expected.
const Header = "X-Luci-Gce-Vm-Token"

// Payload is extracted from a verified GCE VM metadata token.
//
// It identifies a VM that produced the token and the target audience for
// the token (as it was supplied to the GCE metadata endpoint via 'audience'
// request parameter when generating the token).
52 type Payload struct { 53 Project string // GCE project name, e.g. "my-bots" or "domain.com:my-bots" 54 Zone string // GCE zone name where the VM is, e.g. "us-central1-b" 55 Instance string // VM instance name, e.g. "my-instance-1" 56 Audience string // 'aud' field inside the token, usually the server URL 57 } 58 59 // Verify parses a GCE VM metadata token, verifies its signature and expiration 60 // time, and extracts interesting parts of it into Payload struct. 61 // 62 // Does NOT verify the audience field. This is responsibility of the caller. 63 // 64 // The token is in JWT form (three dot-separated base64-encoded strings). It is 65 // expected to be signed by Google OAuth2 backends using RS256 algo. 66 func Verify(c context.Context, jwt string) (*Payload, error) { 67 // Grab root Google OAuth2 keys to verify JWT signature. They are most likely 68 // already cached in the process memory. 69 certs, err := signing.FetchGoogleOAuth2Certificates(c) 70 if err != nil { 71 return nil, err 72 } 73 return verifyImpl(c, jwt, certs) 74 } 75 76 // signatureChecker is used to mock signing.PublicCertificates in tests. 77 type signatureChecker interface { 78 CheckSignature(key string, signed, signature []byte) error 79 } 80 81 func verifyImpl(c context.Context, jwt string, certs signatureChecker) (*Payload, error) { 82 chunks := strings.Split(jwt, ".") 83 if len(chunks) != 3 { 84 return nil, errors.Reason("bad JWT: expected 3 components separated by '.'").Err() 85 } 86 87 // Check the header, grab the key ID from it. 
88 var hdr struct { 89 Alg string `json:"alg"` 90 Kid string `json:"kid"` 91 } 92 if err := unmarshalB64JSON(chunks[0], &hdr); err != nil { 93 return nil, errors.Annotate(err, "bad JWT header").Err() 94 } 95 if hdr.Alg != "RS256" { 96 return nil, errors.Reason("bad JWT: only RS256 alg is supported, not %q", hdr.Alg).Err() 97 } 98 if hdr.Kid == "" { 99 return nil, errors.Reason("bad JWT: missing the signing key ID in the header").Err() 100 } 101 102 // Need a raw binary blob with the signature to verify it. 103 sig, err := base64.RawURLEncoding.DecodeString(chunks[2]) 104 if err != nil { 105 return nil, errors.Annotate(err, "bad JWT: can't base64 decode the signature").Err() 106 } 107 108 // Check that "b64(hdr).b64(payload)" part of the token matches the signature. 109 // If it does, we know the token was created by Google. 110 if err = certs.CheckSignature(hdr.Kid, []byte(chunks[0]+"."+chunks[1]), sig); err != nil { 111 return nil, errors.Annotate(err, "bad JWT: bad signature").Err() 112 } 113 114 // Decode the payload. There should be no errors here generally, the encoded 115 // payload is signed and the signature was already verified. Note that for the 116 // sake of completeness and documentation we decode all fields usually present 117 // in the token, even though we use only subset of them below. 
118 var payload struct { 119 Aud string `json:"aud"` // audience 120 Azp string `json:"azp"` // authorized party (GCE VM service account ID) 121 Email string `json:"email"` // GCE VM service account email 122 EmailVerified bool `json:"email_verified"` // always true 123 Exp int64 `json:"exp"` // "expiry", as unix timestamp 124 Iat int64 `json:"iat"` // "issued at", as unix timestamp 125 Iss string `json:"iss"` // issuer name 126 Sub string `json:"sub"` // subject (GCE VM service account ID) 127 Google struct { 128 ComputeEngine struct { 129 InstanceCreationTimestamp int64 `json:"instance_creation_timestamp"` 130 InstanceID string `json:"instance_id"` 131 InstanceName string `json:"instance_name"` 132 ProjectID string `json:"project_id"` 133 ProjectNumber int64 `json:"project_number"` 134 Zone string `json:"zone"` 135 } `json:"compute_engine"` 136 } `json:"google"` 137 } 138 if err = unmarshalB64JSON(chunks[1], &payload); err != nil { 139 return nil, errors.Annotate(err, "bad JWT payload").Err() 140 } 141 142 // Tokens can either be in "full" or "standard" format. We want "full", since 143 // "standard" doesn't have details about the VM. 144 if payload.Google.ComputeEngine.ProjectID == "" { 145 return nil, errors.Reason("no google.compute_engine in the GCE VM token, use 'full' format").Err() 146 } 147 148 // Check token's "issued at" and "expiry" claims. Allow some leeway for clock 149 // discrepancy between us and Google OAuth2 backend. 150 const allowedDriftSec = 30 151 switch now := clock.Now(c).Unix(); { 152 case now < payload.Iat-allowedDriftSec: 153 return nil, errors.Reason("bad JWT: too early (now %d < iat %d)", now, payload.Iat).Err() 154 case now > payload.Exp+allowedDriftSec: 155 return nil, errors.Reason("bad JWT: expired (now %d > exp %d)", now, payload.Exp).Err() 156 } 157 158 // The caller is supposed to check 'aud' claim to finish the verification. 
159 return &Payload{ 160 Project: payload.Google.ComputeEngine.ProjectID, 161 Zone: payload.Google.ComputeEngine.Zone, 162 Instance: payload.Google.ComputeEngine.InstanceName, 163 Audience: payload.Aud, 164 }, nil 165 } 166 167 func unmarshalB64JSON(blob string, out any) error { 168 raw, err := base64.RawURLEncoding.DecodeString(blob) 169 if err != nil { 170 return errors.Annotate(err, "not base64").Err() 171 } 172 if err := json.Unmarshal(raw, out); err != nil { 173 return errors.Annotate(err, "not JSON").Err() 174 } 175 return nil 176 } 177 178 // pldKey is the key to a *Payload in the context. 179 var pldKey = "pld" 180 181 // withPayload returns a new context with the given *Payload installed. 182 func withPayload(c context.Context, p *Payload) context.Context { 183 return context.WithValue(c, &pldKey, p) 184 } 185 186 // getPayload returns the *Payload installed in the current context. May be nil. 187 func getPayload(c context.Context) *Payload { 188 p, _ := c.Value(&pldKey).(*Payload) 189 return p 190 } 191 192 // Clear returns a new context without a GCE VM metadata token installed. 193 func Clear(c context.Context) context.Context { 194 return context.WithValue(c, &pldKey, nil) 195 } 196 197 // Has returns whether the current context contains a valid GCE VM metadata 198 // token. 199 func Has(c context.Context) bool { 200 return getPayload(c) != nil 201 } 202 203 // Hostname returns the hostname of the VM stored in the current context. 204 func Hostname(c context.Context) string { 205 p := getPayload(c) 206 if p == nil { 207 return "" 208 } 209 return p.Instance 210 } 211 212 // CurrentIdentity returns the identity of the VM stored in the current context. 213 func CurrentIdentity(c context.Context) string { 214 p := getPayload(c) 215 if p == nil { 216 return "gce:anonymous" 217 } 218 // GCE hostnames must be unique per project, so <instance, project> suffices. 
219 return fmt.Sprintf("gce:%s:%s", p.Instance, p.Project) 220 } 221 222 // Matches returns whether the current context contains a GCE VM metadata 223 // token matching the given identity. 224 func Matches(c context.Context, host, zone, proj string) bool { 225 p := getPayload(c) 226 if p == nil { 227 return false 228 } 229 logging.Debugf(c, "expecting VM token from %q in %q in %q", host, zone, proj) 230 return p.Instance == host && p.Zone == zone && p.Project == proj 231 } 232 233 // Middleware embeds a Payload in the context if the request contains a GCE VM 234 // metadata token. 235 func Middleware(c *router.Context, next router.Handler) { 236 if tok := c.Request.Header.Get(Header); tok != "" { 237 // TODO(smut): Support requests to other modules, versions. 238 aud := "https://" + info.DefaultVersionHostname(c.Request.Context()) 239 logging.Debugf(c.Request.Context(), "expecting VM token for: %s", aud) 240 switch p, err := Verify(c.Request.Context(), tok); { 241 case transient.Tag.In(err): 242 logging.WithError(err).Errorf(c.Request.Context(), "transient error verifying VM token") 243 http.Error(c.Writer, "error: failed to verify VM token", http.StatusInternalServerError) 244 return 245 case err != nil: 246 logging.WithError(err).Errorf(c.Request.Context(), "invalid VM token") 247 http.Error(c.Writer, "error: invalid VM token", http.StatusUnauthorized) 248 return 249 case p.Audience != aud: 250 logging.Errorf(c.Request.Context(), "received VM token intended for: %s", p.Audience) 251 http.Error(c.Writer, "error: VM token audience mismatch", http.StatusUnauthorized) 252 return 253 default: 254 logging.Debugf(c.Request.Context(), "received VM token from %q in %q in %q for: %s", p.Instance, p.Zone, p.Project, p.Audience) 255 c.Request = c.Request.WithContext(withPayload(c.Request.Context(), p)) 256 } 257 } 258 next(c) 259 } 260 261 func init() { 262 warmup.Register("gce/vmtoken", func(c context.Context) error { 263 _, err := signing.FetchGoogleOAuth2Certificates(c) 
264 return err 265 }) 266 }