github.com/avenga/couper@v1.12.2/accesscontrol/jwk/jwks.go (about) 1 package jwk 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "strings" 9 "time" 10 11 "github.com/golang-jwt/jwt/v4" 12 13 "github.com/avenga/couper/config" 14 jsn "github.com/avenga/couper/json" 15 ) 16 17 var alg2kty = map[string]string{ 18 "RS256": "RSA", 19 "RS384": "RSA", 20 "RS512": "RSA", 21 "ES256": "EC", 22 "ES384": "EC", 23 "ES512": "EC", 24 } 25 26 type JWKSData struct { 27 Keys []*JWK `json:"keys"` 28 } 29 30 type JWKS struct { 31 syncedJSON *jsn.SyncedJSON 32 } 33 34 func NewJWKS(ctx context.Context, uri string, ttl string, maxStale string, transport http.RoundTripper) (*JWKS, error) { 35 timetolive, err := config.ParseDuration("jwks_ttl", ttl, time.Hour) 36 if err != nil { 37 return nil, err 38 } 39 maxStaleTime, err := config.ParseDuration("jwks_max_stale", maxStale, time.Hour) 40 if err != nil { 41 return nil, err 42 } 43 var file string 44 if strings.HasPrefix(uri, "file:") { 45 file = uri[5:] 46 } else if !strings.HasPrefix(uri, "http:") && !strings.HasPrefix(uri, "https:") { 47 return nil, fmt.Errorf("unsupported JWKS URI scheme: %q", uri) 48 } 49 50 jwks := &JWKS{} 51 jwks.syncedJSON, err = jsn.NewSyncedJSON(ctx, file, "jwks_url", uri, transport, "jwks", timetolive, maxStaleTime, jwks) 52 return jwks, err 53 } 54 55 func (j *JWKS) GetSigKeyForToken(token *jwt.Token) (interface{}, error) { 56 algorithm := token.Header["alg"] 57 if algorithm == nil { 58 return nil, fmt.Errorf("missing \"alg\" in JOSE header") 59 } 60 id := token.Header["kid"] 61 if id == nil { 62 id = "" 63 } 64 jwk, err := j.GetKey(id.(string), algorithm.(string), "sig") 65 if err != nil { 66 return nil, err 67 } 68 69 if jwk == nil { 70 return nil, fmt.Errorf("no matching %s JWK for kid %q", algorithm, id) 71 } 72 73 return jwk.Key, nil 74 } 75 76 func (j *JWKS) GetKey(kid string, alg string, use string) (*JWK, error) { 77 keys, err := j.getKeys(kid) 78 if err != nil { 79 return nil, err 80 } 81 82 for _, key := range keys { 83 if key.Use == use { 84 if key.Algorithm == alg { 85 return key, nil 86 } else if key.Algorithm == "" { 87 if kty, exists := alg2kty[alg]; exists && key.KeyType == kty { 88 return key, nil 89 } 90 } 91 } 92 } 93 return nil, nil 94 } 95 96 func (j *JWKS) getKeys(kid string) ([]*JWK, error) { 97 var keys []*JWK 98 99 jwksData, err := j.Data() 100 if err != nil { 101 return nil, err 102 } 103 104 if len(jwksData.Keys) == 0 { 105 return nil, fmt.Errorf("missing jwks key-data") 106 } 107 108 for _, key := range jwksData.Keys { 109 if key.KeyID == kid { 110 keys = append(keys, key) 111 } 112 } 113 114 return keys, nil 115 } 116 117 func (j *JWKS) Data() (*JWKSData, error) { 118 data, err := j.syncedJSON.Data() 119 // Ignore backend errors as long as we still get cached (stale) data. 120 jwksData, ok := data.(*JWKSData) 121 if !ok { 122 return nil, fmt.Errorf("received no valid JWKs data: %#v, %w", data, err) 123 } 124 125 return jwksData, nil 126 } 127 128 func (j *JWKS) Unmarshal(rawJSON []byte) (interface{}, error) { 129 jsonData := &JWKSData{} 130 err := json.Unmarshal(rawJSON, jsonData) 131 return jsonData, err 132 }