github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/testing/fakejwt.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package testing
     5  
     6  import (
     7  	"crypto/rand"
     8  	"crypto/rsa"
     9  	"crypto/x509"
    10  	"encoding/pem"
    11  	"time"
    12  
    13  	"github.com/google/uuid"
    14  	"github.com/juju/errors"
    15  	"github.com/lestrrat-go/jwx/v2/jwa"
    16  	"github.com/lestrrat-go/jwx/v2/jwk"
    17  	"github.com/lestrrat-go/jwx/v2/jwt"
    18  )
    19  
    20  // JWTParams are the necessary params to issue a ready-to-go JWT.
    21  type JWTParams struct {
    22  	Controller string
    23  	User       string
    24  	Access     map[string]string
    25  }
    26  
    27  func getJWKS() (jwk.Set, []byte, error) {
    28  	keySet, err := rsa.GenerateKey(rand.Reader, 4096)
    29  	if err != nil {
    30  		return nil, nil, errors.Trace(err)
    31  	}
    32  
    33  	privateKeyPEM := pem.EncodeToMemory(
    34  		&pem.Block{
    35  			Type:  "RSA PRIVATE KEY",
    36  			Bytes: x509.MarshalPKCS1PrivateKey(keySet),
    37  		},
    38  	)
    39  
    40  	kid, err := uuid.NewRandom()
    41  	if err != nil {
    42  		return nil, nil, errors.Trace(err)
    43  	}
    44  
    45  	jwks, err := jwk.FromRaw(keySet.PublicKey)
    46  	if err != nil {
    47  		return nil, nil, errors.Trace(err)
    48  	}
    49  	err = jwks.Set(jwk.KeyIDKey, kid.String())
    50  	if err != nil {
    51  		return nil, nil, errors.Trace(err)
    52  	}
    53  
    54  	err = jwks.Set(jwk.KeyUsageKey, "sig")
    55  	if err != nil {
    56  		return nil, nil, errors.Trace(err)
    57  	}
    58  
    59  	err = jwks.Set(jwk.AlgorithmKey, jwa.RS256)
    60  	if err != nil {
    61  		return nil, nil, errors.Trace(err)
    62  	}
    63  
    64  	ks := jwk.NewSet()
    65  	err = ks.AddKey(jwks)
    66  	if err != nil {
    67  		return nil, nil, errors.Trace(err)
    68  	}
    69  
    70  	return ks, privateKeyPEM, nil
    71  }
    72  
    73  func generateJTI() (string, error) {
    74  	id, err := uuid.NewRandom()
    75  	if err != nil {
    76  		return "", err
    77  	}
    78  	return id.String(), nil
    79  }
    80  
    81  // NewJWKSet returns a new key set and signing key.
    82  func NewJWKSet() (jwk.Set, jwk.Key, error) {
    83  	jwkSet, pkeyPem, err := getJWKS()
    84  	if err != nil {
    85  		return nil, nil, errors.Trace(err)
    86  	}
    87  
    88  	block, _ := pem.Decode(pkeyPem)
    89  
    90  	pkeyDecoded, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    91  	if err != nil {
    92  		return nil, nil, errors.Trace(err)
    93  	}
    94  
    95  	signingKey, err := jwk.FromRaw(pkeyDecoded)
    96  	if err != nil {
    97  		return nil, nil, errors.Trace(err)
    98  	}
    99  	return jwkSet, signingKey, nil
   100  }
   101  
   102  // NewJWT returns a parsed jwt.
   103  func NewJWT(params JWTParams) (jwt.Token, error) {
   104  	jwkSet, signingKey, err := NewJWKSet()
   105  	if err != nil {
   106  		return nil, errors.Trace(err)
   107  	}
   108  	tok, err := EncodedJWT(params, jwkSet, signingKey)
   109  	if err != nil {
   110  		return nil, errors.Trace(err)
   111  	}
   112  	return jwt.Parse(
   113  		tok,
   114  		jwt.WithKeySet(jwkSet),
   115  	)
   116  }
   117  
   118  // EncodedJWT returns jwt as bytes signed by the specified key.
   119  func EncodedJWT(params JWTParams, jwkSet jwk.Set, signingKey jwk.Key) ([]byte, error) {
   120  	jti, err := generateJTI()
   121  	if err != nil {
   122  		return nil, errors.Trace(err)
   123  	}
   124  	pubKey, ok := jwkSet.Key(jwkSet.Len() - 1)
   125  	if !ok {
   126  		return nil, errors.Errorf("no jwk found")
   127  	}
   128  
   129  	err = signingKey.Set(jwk.AlgorithmKey, jwa.RS256)
   130  	if err != nil {
   131  		return nil, errors.Trace(err)
   132  	}
   133  	err = signingKey.Set(jwk.KeyIDKey, pubKey.KeyID())
   134  	if err != nil {
   135  		return nil, errors.Trace(err)
   136  	}
   137  
   138  	token, err := jwt.NewBuilder().
   139  		Audience([]string{params.Controller}).
   140  		Subject(params.User).
   141  		Issuer("test").
   142  		JwtID(jti).
   143  		Claim("access", params.Access).
   144  		Expiration(time.Now().Add(time.Hour)).
   145  		Build()
   146  	if err != nil {
   147  		return nil, errors.Trace(err)
   148  	}
   149  
   150  	freshToken, err := jwt.Sign(
   151  		token,
   152  		jwt.WithKey(
   153  			jwa.RS256,
   154  			signingKey,
   155  		),
   156  	)
   157  	return freshToken, errors.Trace(err)
   158  }