github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/testing/fakejwt.go (about) 1 // Copyright 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package testing 5 6 import ( 7 "crypto/rand" 8 "crypto/rsa" 9 "crypto/x509" 10 "encoding/pem" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/juju/errors" 15 "github.com/lestrrat-go/jwx/v2/jwa" 16 "github.com/lestrrat-go/jwx/v2/jwk" 17 "github.com/lestrrat-go/jwx/v2/jwt" 18 ) 19 20 // JWTParams are the necessary params to issue a ready-to-go JWT. 21 type JWTParams struct { 22 Controller string 23 User string 24 Access map[string]string 25 } 26 27 func getJWKS() (jwk.Set, []byte, error) { 28 keySet, err := rsa.GenerateKey(rand.Reader, 4096) 29 if err != nil { 30 return nil, nil, errors.Trace(err) 31 } 32 33 privateKeyPEM := pem.EncodeToMemory( 34 &pem.Block{ 35 Type: "RSA PRIVATE KEY", 36 Bytes: x509.MarshalPKCS1PrivateKey(keySet), 37 }, 38 ) 39 40 kid, err := uuid.NewRandom() 41 if err != nil { 42 return nil, nil, errors.Trace(err) 43 } 44 45 jwks, err := jwk.FromRaw(keySet.PublicKey) 46 if err != nil { 47 return nil, nil, errors.Trace(err) 48 } 49 err = jwks.Set(jwk.KeyIDKey, kid.String()) 50 if err != nil { 51 return nil, nil, errors.Trace(err) 52 } 53 54 err = jwks.Set(jwk.KeyUsageKey, "sig") 55 if err != nil { 56 return nil, nil, errors.Trace(err) 57 } 58 59 err = jwks.Set(jwk.AlgorithmKey, jwa.RS256) 60 if err != nil { 61 return nil, nil, errors.Trace(err) 62 } 63 64 ks := jwk.NewSet() 65 err = ks.AddKey(jwks) 66 if err != nil { 67 return nil, nil, errors.Trace(err) 68 } 69 70 return ks, privateKeyPEM, nil 71 } 72 73 func generateJTI() (string, error) { 74 id, err := uuid.NewRandom() 75 if err != nil { 76 return "", err 77 } 78 return id.String(), nil 79 } 80 81 // NewJWKSet returns a new key set and signing key. 82 func NewJWKSet() (jwk.Set, jwk.Key, error) { 83 jwkSet, pkeyPem, err := getJWKS() 84 if err != nil { 85 return nil, nil, errors.Trace(err) 86 } 87 88 block, _ := pem.Decode(pkeyPem) 89 90 pkeyDecoded, err := x509.ParsePKCS1PrivateKey(block.Bytes) 91 if err != nil { 92 return nil, nil, errors.Trace(err) 93 } 94 95 signingKey, err := jwk.FromRaw(pkeyDecoded) 96 if err != nil { 97 return nil, nil, errors.Trace(err) 98 } 99 return jwkSet, signingKey, nil 100 } 101 102 // NewJWT returns a parsed jwt. 103 func NewJWT(params JWTParams) (jwt.Token, error) { 104 jwkSet, signingKey, err := NewJWKSet() 105 if err != nil { 106 return nil, errors.Trace(err) 107 } 108 tok, err := EncodedJWT(params, jwkSet, signingKey) 109 if err != nil { 110 return nil, errors.Trace(err) 111 } 112 return jwt.Parse( 113 tok, 114 jwt.WithKeySet(jwkSet), 115 ) 116 } 117 118 // EncodedJWT returns jwt as bytes signed by the specified key. 119 func EncodedJWT(params JWTParams, jwkSet jwk.Set, signingKey jwk.Key) ([]byte, error) { 120 jti, err := generateJTI() 121 if err != nil { 122 return nil, errors.Trace(err) 123 } 124 pubKey, ok := jwkSet.Key(jwkSet.Len() - 1) 125 if !ok { 126 return nil, errors.Errorf("no jwk found") 127 } 128 129 err = signingKey.Set(jwk.AlgorithmKey, jwa.RS256) 130 if err != nil { 131 return nil, errors.Trace(err) 132 } 133 err = signingKey.Set(jwk.KeyIDKey, pubKey.KeyID()) 134 if err != nil { 135 return nil, errors.Trace(err) 136 } 137 138 token, err := jwt.NewBuilder(). 139 Audience([]string{params.Controller}). 140 Subject(params.User). 141 Issuer("test"). 142 JwtID(jti). 143 Claim("access", params.Access). 144 Expiration(time.Now().Add(time.Hour)). 145 Build() 146 if err != nil { 147 return nil, errors.Trace(err) 148 } 149 150 freshToken, err := jwt.Sign( 151 token, 152 jwt.WithKey( 153 jwa.RS256, 154 signingKey, 155 ), 156 ) 157 return freshToken, errors.Trace(err) 158 }