go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/gce/vmtoken/vmtoken.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package vmtoken implements parsing and verification of signed GCE VM metadata
    16  // tokens.
    17  //
    18  // See https://cloud.google.com/compute/docs/instances/verifying-instance-identity
    19  //
    20  // Intended to be used from a server environment (e.g. from a GAE), since it
    21  // depends on a bunch of luci/server packages that require a properly configured
    22  // context.
    23  package vmtoken
    24  
    25  import (
    26  	"context"
    27  	"encoding/base64"
    28  	"encoding/json"
    29  	"fmt"
    30  	"net/http"
    31  	"strings"
    32  
    33  	"go.chromium.org/luci/common/clock"
    34  	"go.chromium.org/luci/common/errors"
    35  	"go.chromium.org/luci/common/logging"
    36  	"go.chromium.org/luci/common/retry/transient"
    37  	"go.chromium.org/luci/gae/service/info"
    38  	"go.chromium.org/luci/server/auth/signing"
    39  	"go.chromium.org/luci/server/router"
    40  	"go.chromium.org/luci/server/warmup"
    41  )
    42  
// Header is the name of the HTTP request header in which a GCE VM metadata
// token is expected.
//
// Middleware reads the token from this header; requests that leave it empty or
// unset are passed through without any token verification.
const Header = "X-Luci-Gce-Vm-Token"
    46  
    47  // Payload is extracted from a verified GCE VM metadata token.
    48  //
    49  // It identifies a VM that produced the token and the target audience for
    50  // the token (as it was supplied to the GCE metadata endpoint via 'audience'
    51  // request parameter when generating the token).
    52  type Payload struct {
    53  	Project  string // GCE project name, e.g. "my-bots" or "domain.com:my-bots"
    54  	Zone     string // GCE zone name where the VM is, e.g. "us-central1-b"
    55  	Instance string // VM instance name, e.g. "my-instance-1"
    56  	Audience string // 'aud' field inside the token, usually the server URL
    57  }
    58  
    59  // Verify parses a GCE VM metadata token, verifies its signature and expiration
    60  // time, and extracts interesting parts of it into Payload struct.
    61  //
    62  // Does NOT verify the audience field. This is responsibility of the caller.
    63  //
    64  // The token is in JWT form (three dot-separated base64-encoded strings). It is
    65  // expected to be signed by Google OAuth2 backends using RS256 algo.
    66  func Verify(c context.Context, jwt string) (*Payload, error) {
    67  	// Grab root Google OAuth2 keys to verify JWT signature. They are most likely
    68  	// already cached in the process memory.
    69  	certs, err := signing.FetchGoogleOAuth2Certificates(c)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	return verifyImpl(c, jwt, certs)
    74  }
    75  
    76  // signatureChecker is used to mock signing.PublicCertificates in tests.
    77  type signatureChecker interface {
    78  	CheckSignature(key string, signed, signature []byte) error
    79  }
    80  
    81  func verifyImpl(c context.Context, jwt string, certs signatureChecker) (*Payload, error) {
    82  	chunks := strings.Split(jwt, ".")
    83  	if len(chunks) != 3 {
    84  		return nil, errors.Reason("bad JWT: expected 3 components separated by '.'").Err()
    85  	}
    86  
    87  	// Check the header, grab the key ID from it.
    88  	var hdr struct {
    89  		Alg string `json:"alg"`
    90  		Kid string `json:"kid"`
    91  	}
    92  	if err := unmarshalB64JSON(chunks[0], &hdr); err != nil {
    93  		return nil, errors.Annotate(err, "bad JWT header").Err()
    94  	}
    95  	if hdr.Alg != "RS256" {
    96  		return nil, errors.Reason("bad JWT: only RS256 alg is supported, not %q", hdr.Alg).Err()
    97  	}
    98  	if hdr.Kid == "" {
    99  		return nil, errors.Reason("bad JWT: missing the signing key ID in the header").Err()
   100  	}
   101  
   102  	// Need a raw binary blob with the signature to verify it.
   103  	sig, err := base64.RawURLEncoding.DecodeString(chunks[2])
   104  	if err != nil {
   105  		return nil, errors.Annotate(err, "bad JWT: can't base64 decode the signature").Err()
   106  	}
   107  
   108  	// Check that "b64(hdr).b64(payload)" part of the token matches the signature.
   109  	// If it does, we know the token was created by Google.
   110  	if err = certs.CheckSignature(hdr.Kid, []byte(chunks[0]+"."+chunks[1]), sig); err != nil {
   111  		return nil, errors.Annotate(err, "bad JWT: bad signature").Err()
   112  	}
   113  
   114  	// Decode the payload. There should be no errors here generally, the encoded
   115  	// payload is signed and the signature was already verified. Note that for the
   116  	// sake of completeness and documentation we decode all fields usually present
   117  	// in the token, even though we use only subset of them below.
   118  	var payload struct {
   119  		Aud           string `json:"aud"`            // audience
   120  		Azp           string `json:"azp"`            // authorized party (GCE VM service account ID)
   121  		Email         string `json:"email"`          // GCE VM service account email
   122  		EmailVerified bool   `json:"email_verified"` // always true
   123  		Exp           int64  `json:"exp"`            // "expiry", as unix timestamp
   124  		Iat           int64  `json:"iat"`            // "issued at", as unix timestamp
   125  		Iss           string `json:"iss"`            // issuer name
   126  		Sub           string `json:"sub"`            // subject (GCE VM service account ID)
   127  		Google        struct {
   128  			ComputeEngine struct {
   129  				InstanceCreationTimestamp int64  `json:"instance_creation_timestamp"`
   130  				InstanceID                string `json:"instance_id"`
   131  				InstanceName              string `json:"instance_name"`
   132  				ProjectID                 string `json:"project_id"`
   133  				ProjectNumber             int64  `json:"project_number"`
   134  				Zone                      string `json:"zone"`
   135  			} `json:"compute_engine"`
   136  		} `json:"google"`
   137  	}
   138  	if err = unmarshalB64JSON(chunks[1], &payload); err != nil {
   139  		return nil, errors.Annotate(err, "bad JWT payload").Err()
   140  	}
   141  
   142  	// Tokens can either be in "full" or "standard" format. We want "full", since
   143  	// "standard" doesn't have details about the VM.
   144  	if payload.Google.ComputeEngine.ProjectID == "" {
   145  		return nil, errors.Reason("no google.compute_engine in the GCE VM token, use 'full' format").Err()
   146  	}
   147  
   148  	// Check token's "issued at" and "expiry" claims. Allow some leeway for clock
   149  	// discrepancy between us and Google OAuth2 backend.
   150  	const allowedDriftSec = 30
   151  	switch now := clock.Now(c).Unix(); {
   152  	case now < payload.Iat-allowedDriftSec:
   153  		return nil, errors.Reason("bad JWT: too early (now %d < iat %d)", now, payload.Iat).Err()
   154  	case now > payload.Exp+allowedDriftSec:
   155  		return nil, errors.Reason("bad JWT: expired (now %d > exp %d)", now, payload.Exp).Err()
   156  	}
   157  
   158  	// The caller is supposed to check 'aud' claim to finish the verification.
   159  	return &Payload{
   160  		Project:  payload.Google.ComputeEngine.ProjectID,
   161  		Zone:     payload.Google.ComputeEngine.Zone,
   162  		Instance: payload.Google.ComputeEngine.InstanceName,
   163  		Audience: payload.Aud,
   164  	}, nil
   165  }
   166  
   167  func unmarshalB64JSON(blob string, out any) error {
   168  	raw, err := base64.RawURLEncoding.DecodeString(blob)
   169  	if err != nil {
   170  		return errors.Annotate(err, "not base64").Err()
   171  	}
   172  	if err := json.Unmarshal(raw, out); err != nil {
   173  		return errors.Annotate(err, "not JSON").Err()
   174  	}
   175  	return nil
   176  }
   177  
// pldKey is the key to a *Payload in the context.
//
// The address of this variable (&pldKey), not its string value, is what gets
// passed to context.WithValue and Context.Value as the key.
var pldKey = "pld"
   180  
   181  // withPayload returns a new context with the given *Payload installed.
   182  func withPayload(c context.Context, p *Payload) context.Context {
   183  	return context.WithValue(c, &pldKey, p)
   184  }
   185  
   186  // getPayload returns the *Payload installed in the current context. May be nil.
   187  func getPayload(c context.Context) *Payload {
   188  	p, _ := c.Value(&pldKey).(*Payload)
   189  	return p
   190  }
   191  
   192  // Clear returns a new context without a GCE VM metadata token installed.
   193  func Clear(c context.Context) context.Context {
   194  	return context.WithValue(c, &pldKey, nil)
   195  }
   196  
   197  // Has returns whether the current context contains a valid GCE VM metadata
   198  // token.
   199  func Has(c context.Context) bool {
   200  	return getPayload(c) != nil
   201  }
   202  
   203  // Hostname returns the hostname of the VM stored in the current context.
   204  func Hostname(c context.Context) string {
   205  	p := getPayload(c)
   206  	if p == nil {
   207  		return ""
   208  	}
   209  	return p.Instance
   210  }
   211  
   212  // CurrentIdentity returns the identity of the VM stored in the current context.
   213  func CurrentIdentity(c context.Context) string {
   214  	p := getPayload(c)
   215  	if p == nil {
   216  		return "gce:anonymous"
   217  	}
   218  	// GCE hostnames must be unique per project, so <instance, project> suffices.
   219  	return fmt.Sprintf("gce:%s:%s", p.Instance, p.Project)
   220  }
   221  
   222  // Matches returns whether the current context contains a GCE VM metadata
   223  // token matching the given identity.
   224  func Matches(c context.Context, host, zone, proj string) bool {
   225  	p := getPayload(c)
   226  	if p == nil {
   227  		return false
   228  	}
   229  	logging.Debugf(c, "expecting VM token from %q in %q in %q", host, zone, proj)
   230  	return p.Instance == host && p.Zone == zone && p.Project == proj
   231  }
   232  
   233  // Middleware embeds a Payload in the context if the request contains a GCE VM
   234  // metadata token.
   235  func Middleware(c *router.Context, next router.Handler) {
   236  	if tok := c.Request.Header.Get(Header); tok != "" {
   237  		// TODO(smut): Support requests to other modules, versions.
   238  		aud := "https://" + info.DefaultVersionHostname(c.Request.Context())
   239  		logging.Debugf(c.Request.Context(), "expecting VM token for: %s", aud)
   240  		switch p, err := Verify(c.Request.Context(), tok); {
   241  		case transient.Tag.In(err):
   242  			logging.WithError(err).Errorf(c.Request.Context(), "transient error verifying VM token")
   243  			http.Error(c.Writer, "error: failed to verify VM token", http.StatusInternalServerError)
   244  			return
   245  		case err != nil:
   246  			logging.WithError(err).Errorf(c.Request.Context(), "invalid VM token")
   247  			http.Error(c.Writer, "error: invalid VM token", http.StatusUnauthorized)
   248  			return
   249  		case p.Audience != aud:
   250  			logging.Errorf(c.Request.Context(), "received VM token intended for: %s", p.Audience)
   251  			http.Error(c.Writer, "error: VM token audience mismatch", http.StatusUnauthorized)
   252  			return
   253  		default:
   254  			logging.Debugf(c.Request.Context(), "received VM token from %q in %q in %q for: %s", p.Instance, p.Zone, p.Project, p.Audience)
   255  			c.Request = c.Request.WithContext(withPayload(c.Request.Context(), p))
   256  		}
   257  	}
   258  	next(c)
   259  }
   260  
   261  func init() {
   262  	warmup.Register("gce/vmtoken", func(c context.Context) error {
   263  		_, err := signing.FetchGoogleOAuth2Certificates(c)
   264  		return err
   265  	})
   266  }