github.com/blend/go-sdk@v1.20220411.3/oauth/public_key_cache_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package oauth 9 10 import ( 11 "context" 12 "crypto/rsa" 13 "encoding/json" 14 "net/http" 15 "net/http/httptest" 16 "testing" 17 "time" 18 19 "github.com/golang-jwt/jwt" 20 21 "github.com/blend/go-sdk/assert" 22 "github.com/blend/go-sdk/jwk" 23 "github.com/blend/go-sdk/r2" 24 "github.com/blend/go-sdk/uuid" 25 ) 26 27 func Test_PublicKeyCache_Keyfunc(t *testing.T) { 28 it := assert.New(t) 29 30 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 31 it.Nil(err) 32 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 33 it.Nil(err) 34 keys := []jwk.JWK{ 35 createJWK(pk0), 36 createJWK(pk1), 37 } 38 39 var didCallResponder bool 40 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 41 didCallResponder = true 42 })) 43 defer keysResponder.Close() 44 45 cache := new(PublicKeyCache) 46 cache.FetchPublicKeysDefaults = []r2.Option{ 47 r2.OptURL(keysResponder.URL), 48 } 49 cache.current = &PublicKeysResponse{ 50 CacheControl: "public, max-age=23196, must-revalidate, no-transform", 51 Expires: time.Now().UTC().AddDate(0, 0, 1), 52 Keys: jwkLookup(keys), 53 } 54 55 keyfunc := cache.Keyfunc(context.TODO()) 56 57 pub, err := keyfunc(&jwt.Token{ 58 Header: map[string]interface{}{ 59 "kid": keys[0].KID, 60 }, 61 }) 62 63 it.Nil(err) 64 65 typedPub, ok := pub.(*rsa.PublicKey) 66 it.True(ok) 67 it.Equal(*pk0.PublicKey.N, *typedPub.N) 68 it.False(didCallResponder) 69 } 70 71 func Test_PublicKeyCache_Keyfunc_MissingKIDHeader(t *testing.T) { 72 it := assert.New(t) 73 74 cache := new(PublicKeyCache) 75 keyfunc := cache.Keyfunc(context.TODO()) 76 pub, err := keyfunc(&jwt.Token{}) 77 it.NotNil(err) 78 it.Nil(pub) 79 } 80 81 func Test_PublicKeyCache_Keyfunc_InvalidKID(t *testing.T) { 82 it := assert.New(t) 83 84 cache := new(PublicKeyCache) 85 keyfunc := cache.Keyfunc(context.TODO()) 86 pub, err := keyfunc(&jwt.Token{ 87 Header: map[string]interface{}{ 88 "kid": 1234, 89 }, 90 }) 91 it.NotNil(err) 92 it.Nil(pub) 93 } 94 95 func Test_PublicKeyCache_Get(t *testing.T) { 96 it := assert.New(t) 97 98 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 99 it.Nil(err) 100 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 101 it.Nil(err) 102 keys := []jwk.JWK{ 103 createJWK(pk0), 104 createJWK(pk1), 105 } 106 var didCallResponder bool 107 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 108 didCallResponder = true 109 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 110 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 111 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 112 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 113 rw.WriteHeader(http.StatusOK) 114 _ = json.NewEncoder(rw).Encode(struct { 115 Keys []jwk.JWK `json:"keys"` 116 }{ 117 Keys: keys, 118 }) 119 })) 120 defer keysResponder.Close() 121 122 cache := new(PublicKeyCache) 123 cache.FetchPublicKeysDefaults = []r2.Option{ 124 r2.OptURL(keysResponder.URL), 125 } 126 127 pub, err := cache.Get(context.TODO(), keys[0].KID) 128 it.Nil(err) 129 it.NotNil(pub) 130 it.Equal(*pk0.PublicKey.N, *pub.N) 131 it.True(didCallResponder) 132 133 didCallResponder = false 134 135 pub, err = cache.Get(context.TODO(), keys[1].KID) 136 it.Nil(err) 137 it.NotNil(pub) 138 it.Equal(*pk1.PublicKey.N, *pub.N) 139 it.False(didCallResponder, "we should have cached the results") 140 } 141 142 func Test_PublicKeyCache_Get_NoRefresh(t *testing.T) { 143 it := assert.New(t) 144 145 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 146 it.Nil(err) 147 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 148 it.Nil(err) 149 keys := []jwk.JWK{ 150 createJWK(pk0), 151 createJWK(pk1), 152 } 153 var didCallResponder bool 154 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 155 didCallResponder = true 156 })) 157 defer keysResponder.Close() 158 159 cache := new(PublicKeyCache) 160 cache.current = &PublicKeysResponse{ 161 CacheControl: "public, max-age=23196, must-revalidate, no-transform", 162 Expires: time.Now().UTC().AddDate(0, 0, 1), 163 Keys: jwkLookup(keys), 164 } 165 cache.FetchPublicKeysDefaults = []r2.Option{ 166 r2.OptURL(keysResponder.URL), 167 } 168 169 pub, err := cache.Get(context.TODO(), keys[0].KID) 170 it.Nil(err) 171 it.NotNil(pub) 172 it.Equal(*pk0.PublicKey.N, *pub.N) 173 it.False(didCallResponder) 174 } 175 176 func Test_PublicKeyCache_Get_Refresh(t *testing.T) { 177 it := assert.New(t) 178 179 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 180 it.Nil(err) 181 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 182 it.Nil(err) 183 keys := []jwk.JWK{ 184 createJWK(pk0), 185 createJWK(pk1), 186 } 187 var didCallResponder bool 188 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 189 didCallResponder = true 190 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 191 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 192 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 193 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 194 rw.WriteHeader(200) 195 _ = json.NewEncoder(rw).Encode(struct { 196 Keys []jwk.JWK `json:"keys"` 197 }{ 198 Keys: keys, 199 }) 200 })) 201 defer keysResponder.Close() 202 203 cache := new(PublicKeyCache) 204 cache.current = &PublicKeysResponse{ 205 CacheControl: "public, max-age=23196, must-revalidate, no-transform", 206 Expires: time.Now().UTC().AddDate(0, 0, -1), 207 Keys: jwkLookup(keys), 208 } 209 cache.FetchPublicKeysDefaults = []r2.Option{ 210 r2.OptURL(keysResponder.URL), 211 } 212 213 pub, err := cache.Get(context.TODO(), keys[0].KID) 214 it.Nil(err) 215 it.NotNil(pub) 216 it.Equal(*pk0.PublicKey.N, *pub.N) 217 it.True(didCallResponder) 218 } 219 220 func Test_PublicKeyCache_Get_RefreshOnMiss(t *testing.T) { 221 it := assert.New(t) 222 223 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 224 it.Nil(err) 225 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 226 it.Nil(err) 227 keys := []jwk.JWK{ 228 createJWK(pk0), 229 createJWK(pk1), 230 } 231 var didCallResponder bool 232 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 233 didCallResponder = true 234 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 235 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 236 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 237 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 238 rw.WriteHeader(200) 239 _ = json.NewEncoder(rw).Encode(struct { 240 Keys []jwk.JWK `json:"keys"` 241 }{ 242 Keys: keys, 243 }) 244 })) 245 defer keysResponder.Close() 246 247 cache := new(PublicKeyCache) 248 cache.current = &PublicKeysResponse{ 249 CacheControl: "public, max-age=23196, must-revalidate, no-transform", 250 Expires: time.Now().UTC().AddDate(0, 0, -1), 251 Keys: jwkLookup(keys), 252 } 253 cache.FetchPublicKeysDefaults = []r2.Option{ 254 r2.OptURL(keysResponder.URL), 255 } 256 257 pub, err := cache.Get(context.TODO(), uuid.V4().String()) 258 it.NotNil(err) 259 it.Nil(pub) 260 it.True(didCallResponder) 261 }