github.com/blend/go-sdk@v1.20220411.3/oauth/public_key_cache_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package oauth
     9  
    10  import (
    11  	"context"
    12  	"crypto/rsa"
    13  	"encoding/json"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/golang-jwt/jwt"
    20  
    21  	"github.com/blend/go-sdk/assert"
    22  	"github.com/blend/go-sdk/jwk"
    23  	"github.com/blend/go-sdk/r2"
    24  	"github.com/blend/go-sdk/uuid"
    25  )
    26  
    27  func Test_PublicKeyCache_Keyfunc(t *testing.T) {
    28  	it := assert.New(t)
    29  
    30  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
    31  	it.Nil(err)
    32  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
    33  	it.Nil(err)
    34  	keys := []jwk.JWK{
    35  		createJWK(pk0),
    36  		createJWK(pk1),
    37  	}
    38  
    39  	var didCallResponder bool
    40  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    41  		didCallResponder = true
    42  	}))
    43  	defer keysResponder.Close()
    44  
    45  	cache := new(PublicKeyCache)
    46  	cache.FetchPublicKeysDefaults = []r2.Option{
    47  		r2.OptURL(keysResponder.URL),
    48  	}
    49  	cache.current = &PublicKeysResponse{
    50  		CacheControl: "public, max-age=23196, must-revalidate, no-transform",
    51  		Expires:      time.Now().UTC().AddDate(0, 0, 1),
    52  		Keys:         jwkLookup(keys),
    53  	}
    54  
    55  	keyfunc := cache.Keyfunc(context.TODO())
    56  
    57  	pub, err := keyfunc(&jwt.Token{
    58  		Header: map[string]interface{}{
    59  			"kid": keys[0].KID,
    60  		},
    61  	})
    62  
    63  	it.Nil(err)
    64  
    65  	typedPub, ok := pub.(*rsa.PublicKey)
    66  	it.True(ok)
    67  	it.Equal(*pk0.PublicKey.N, *typedPub.N)
    68  	it.False(didCallResponder)
    69  }
    70  
    71  func Test_PublicKeyCache_Keyfunc_MissingKIDHeader(t *testing.T) {
    72  	it := assert.New(t)
    73  
    74  	cache := new(PublicKeyCache)
    75  	keyfunc := cache.Keyfunc(context.TODO())
    76  	pub, err := keyfunc(&jwt.Token{})
    77  	it.NotNil(err)
    78  	it.Nil(pub)
    79  }
    80  
    81  func Test_PublicKeyCache_Keyfunc_InvalidKID(t *testing.T) {
    82  	it := assert.New(t)
    83  
    84  	cache := new(PublicKeyCache)
    85  	keyfunc := cache.Keyfunc(context.TODO())
    86  	pub, err := keyfunc(&jwt.Token{
    87  		Header: map[string]interface{}{
    88  			"kid": 1234,
    89  		},
    90  	})
    91  	it.NotNil(err)
    92  	it.Nil(pub)
    93  }
    94  
    95  func Test_PublicKeyCache_Get(t *testing.T) {
    96  	it := assert.New(t)
    97  
    98  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
    99  	it.Nil(err)
   100  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   101  	it.Nil(err)
   102  	keys := []jwk.JWK{
   103  		createJWK(pk0),
   104  		createJWK(pk1),
   105  	}
   106  	var didCallResponder bool
   107  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   108  		didCallResponder = true
   109  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   110  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   111  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   112  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   113  		rw.WriteHeader(http.StatusOK)
   114  		_ = json.NewEncoder(rw).Encode(struct {
   115  			Keys []jwk.JWK `json:"keys"`
   116  		}{
   117  			Keys: keys,
   118  		})
   119  	}))
   120  	defer keysResponder.Close()
   121  
   122  	cache := new(PublicKeyCache)
   123  	cache.FetchPublicKeysDefaults = []r2.Option{
   124  		r2.OptURL(keysResponder.URL),
   125  	}
   126  
   127  	pub, err := cache.Get(context.TODO(), keys[0].KID)
   128  	it.Nil(err)
   129  	it.NotNil(pub)
   130  	it.Equal(*pk0.PublicKey.N, *pub.N)
   131  	it.True(didCallResponder)
   132  
   133  	didCallResponder = false
   134  
   135  	pub, err = cache.Get(context.TODO(), keys[1].KID)
   136  	it.Nil(err)
   137  	it.NotNil(pub)
   138  	it.Equal(*pk1.PublicKey.N, *pub.N)
   139  	it.False(didCallResponder, "we should have cached the results")
   140  }
   141  
   142  func Test_PublicKeyCache_Get_NoRefresh(t *testing.T) {
   143  	it := assert.New(t)
   144  
   145  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   146  	it.Nil(err)
   147  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   148  	it.Nil(err)
   149  	keys := []jwk.JWK{
   150  		createJWK(pk0),
   151  		createJWK(pk1),
   152  	}
   153  	var didCallResponder bool
   154  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   155  		didCallResponder = true
   156  	}))
   157  	defer keysResponder.Close()
   158  
   159  	cache := new(PublicKeyCache)
   160  	cache.current = &PublicKeysResponse{
   161  		CacheControl: "public, max-age=23196, must-revalidate, no-transform",
   162  		Expires:      time.Now().UTC().AddDate(0, 0, 1),
   163  		Keys:         jwkLookup(keys),
   164  	}
   165  	cache.FetchPublicKeysDefaults = []r2.Option{
   166  		r2.OptURL(keysResponder.URL),
   167  	}
   168  
   169  	pub, err := cache.Get(context.TODO(), keys[0].KID)
   170  	it.Nil(err)
   171  	it.NotNil(pub)
   172  	it.Equal(*pk0.PublicKey.N, *pub.N)
   173  	it.False(didCallResponder)
   174  }
   175  
   176  func Test_PublicKeyCache_Get_Refresh(t *testing.T) {
   177  	it := assert.New(t)
   178  
   179  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   180  	it.Nil(err)
   181  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   182  	it.Nil(err)
   183  	keys := []jwk.JWK{
   184  		createJWK(pk0),
   185  		createJWK(pk1),
   186  	}
   187  	var didCallResponder bool
   188  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   189  		didCallResponder = true
   190  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   191  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   192  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   193  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   194  		rw.WriteHeader(200)
   195  		_ = json.NewEncoder(rw).Encode(struct {
   196  			Keys []jwk.JWK `json:"keys"`
   197  		}{
   198  			Keys: keys,
   199  		})
   200  	}))
   201  	defer keysResponder.Close()
   202  
   203  	cache := new(PublicKeyCache)
   204  	cache.current = &PublicKeysResponse{
   205  		CacheControl: "public, max-age=23196, must-revalidate, no-transform",
   206  		Expires:      time.Now().UTC().AddDate(0, 0, -1),
   207  		Keys:         jwkLookup(keys),
   208  	}
   209  	cache.FetchPublicKeysDefaults = []r2.Option{
   210  		r2.OptURL(keysResponder.URL),
   211  	}
   212  
   213  	pub, err := cache.Get(context.TODO(), keys[0].KID)
   214  	it.Nil(err)
   215  	it.NotNil(pub)
   216  	it.Equal(*pk0.PublicKey.N, *pub.N)
   217  	it.True(didCallResponder)
   218  }
   219  
   220  func Test_PublicKeyCache_Get_RefreshOnMiss(t *testing.T) {
   221  	it := assert.New(t)
   222  
   223  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   224  	it.Nil(err)
   225  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   226  	it.Nil(err)
   227  	keys := []jwk.JWK{
   228  		createJWK(pk0),
   229  		createJWK(pk1),
   230  	}
   231  	var didCallResponder bool
   232  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   233  		didCallResponder = true
   234  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   235  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   236  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   237  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   238  		rw.WriteHeader(200)
   239  		_ = json.NewEncoder(rw).Encode(struct {
   240  			Keys []jwk.JWK `json:"keys"`
   241  		}{
   242  			Keys: keys,
   243  		})
   244  	}))
   245  	defer keysResponder.Close()
   246  
   247  	cache := new(PublicKeyCache)
   248  	cache.current = &PublicKeysResponse{
   249  		CacheControl: "public, max-age=23196, must-revalidate, no-transform",
   250  		Expires:      time.Now().UTC().AddDate(0, 0, -1),
   251  		Keys:         jwkLookup(keys),
   252  	}
   253  	cache.FetchPublicKeysDefaults = []r2.Option{
   254  		r2.OptURL(keysResponder.URL),
   255  	}
   256  
   257  	pub, err := cache.Get(context.TODO(), uuid.V4().String())
   258  	it.NotNil(err)
   259  	it.Nil(pub)
   260  	it.True(didCallResponder)
   261  }