github.com/avenga/couper@v1.12.2/accesscontrol/jwk/jwks.go (about)

     1  package jwk
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/golang-jwt/jwt/v4"
    12  
    13  	"github.com/avenga/couper/config"
    14  	jsn "github.com/avenga/couper/json"
    15  )
    16  
    17  var alg2kty = map[string]string{
    18  	"RS256": "RSA",
    19  	"RS384": "RSA",
    20  	"RS512": "RSA",
    21  	"ES256": "EC",
    22  	"ES384": "EC",
    23  	"ES512": "EC",
    24  }
    25  
    26  type JWKSData struct {
    27  	Keys []*JWK `json:"keys"`
    28  }
    29  
    30  type JWKS struct {
    31  	syncedJSON *jsn.SyncedJSON
    32  }
    33  
    34  func NewJWKS(ctx context.Context, uri string, ttl string, maxStale string, transport http.RoundTripper) (*JWKS, error) {
    35  	timetolive, err := config.ParseDuration("jwks_ttl", ttl, time.Hour)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	maxStaleTime, err := config.ParseDuration("jwks_max_stale", maxStale, time.Hour)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	var file string
    44  	if strings.HasPrefix(uri, "file:") {
    45  		file = uri[5:]
    46  	} else if !strings.HasPrefix(uri, "http:") && !strings.HasPrefix(uri, "https:") {
    47  		return nil, fmt.Errorf("unsupported JWKS URI scheme: %q", uri)
    48  	}
    49  
    50  	jwks := &JWKS{}
    51  	jwks.syncedJSON, err = jsn.NewSyncedJSON(ctx, file, "jwks_url", uri, transport, "jwks", timetolive, maxStaleTime, jwks)
    52  	return jwks, err
    53  }
    54  
    55  func (j *JWKS) GetSigKeyForToken(token *jwt.Token) (interface{}, error) {
    56  	algorithm := token.Header["alg"]
    57  	if algorithm == nil {
    58  		return nil, fmt.Errorf("missing \"alg\" in JOSE header")
    59  	}
    60  	id := token.Header["kid"]
    61  	if id == nil {
    62  		id = ""
    63  	}
    64  	jwk, err := j.GetKey(id.(string), algorithm.(string), "sig")
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	if jwk == nil {
    70  		return nil, fmt.Errorf("no matching %s JWK for kid %q", algorithm, id)
    71  	}
    72  
    73  	return jwk.Key, nil
    74  }
    75  
    76  func (j *JWKS) GetKey(kid string, alg string, use string) (*JWK, error) {
    77  	keys, err := j.getKeys(kid)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	for _, key := range keys {
    83  		if key.Use == use {
    84  			if key.Algorithm == alg {
    85  				return key, nil
    86  			} else if key.Algorithm == "" {
    87  				if kty, exists := alg2kty[alg]; exists && key.KeyType == kty {
    88  					return key, nil
    89  				}
    90  			}
    91  		}
    92  	}
    93  	return nil, nil
    94  }
    95  
    96  func (j *JWKS) getKeys(kid string) ([]*JWK, error) {
    97  	var keys []*JWK
    98  
    99  	jwksData, err := j.Data()
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	if len(jwksData.Keys) == 0 {
   105  		return nil, fmt.Errorf("missing jwks key-data")
   106  	}
   107  
   108  	for _, key := range jwksData.Keys {
   109  		if key.KeyID == kid {
   110  			keys = append(keys, key)
   111  		}
   112  	}
   113  
   114  	return keys, nil
   115  }
   116  
   117  func (j *JWKS) Data() (*JWKSData, error) {
   118  	data, err := j.syncedJSON.Data()
   119  	// Ignore backend errors as long as we still get cached (stale) data.
   120  	jwksData, ok := data.(*JWKSData)
   121  	if !ok {
   122  		return nil, fmt.Errorf("received no valid JWKs data: %#v, %w", data, err)
   123  	}
   124  
   125  	return jwksData, nil
   126  }
   127  
   128  func (j *JWKS) Unmarshal(rawJSON []byte) (interface{}, error) {
   129  	jsonData := &JWKSData{}
   130  	err := json.Unmarshal(rawJSON, jsonData)
   131  	return jsonData, err
   132  }