github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/oauth/jwks.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/ecdsa"
    20  	"crypto/elliptic"
    21  	"crypto/rsa"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"encoding/binary"
    25  	"encoding/pem"
    26  	"fmt"
    27  	"github.com/greenpau/go-authcrunch/pkg/errors"
    28  	"github.com/greenpau/go-authcrunch/pkg/util"
    29  	"io/ioutil"
    30  	"math/big"
    31  	"strings"
    32  )
    33  
    34  // JwksKey is a JSON object that represents a cryptographic key.
    35  // See https://tools.ietf.org/html/rfc7517#section-4,
    36  // https://tools.ietf.org/html/rfc7518#section-6.3
    37  type JwksKey struct {
    38  	Algorithm    string `json:"alg,omitempty" xml:"alg,omitempty" yaml:"alg,omitempty"`
    39  	Exponent     string `json:"e,omitempty" xml:"e,omitempty" yaml:"e,omitempty"`
    40  	KeyID        string `json:"kid,omitempty" xml:"kid,omitempty" yaml:"kid,omitempty"`
    41  	KeyType      string `json:"kty,omitempty" xml:"kty,omitempty" yaml:"kty,omitempty"`
    42  	Modulus      string `json:"n,omitempty" xml:"n,omitempty" yaml:"n,omitempty"`
    43  	PublicKeyUse string `json:"use,omitempty" xml:"use,omitempty" yaml:"use,omitempty"`
    44  	NotBefore    string `json:"nbf,omitempty" xml:"nbf,omitempty" yaml:"nbf,omitempty"`
    45  
    46  	Curve  string `json:"crv,omitempty" xml:"crv,omitempty" yaml:"crv,omitempty"`
    47  	CoordX string `json:"x,omitempty" xml:"x,omitempty" yaml:"x,omitempty"`
    48  	CoordY string `json:"y,omitempty" xml:"y,omitempty" yaml:"y,omitempty"`
    49  
    50  	SharedSecret string `json:"k,omitempty" xml:"k,omitempty" yaml:"k,omitempty"`
    51  
    52  	publicKey interface{}
    53  }
    54  
    55  // Validate returns error if JwksKey does not contain relevant information.
    56  func (k *JwksKey) Validate() error {
    57  	if k.KeyID == "" {
    58  		return errors.ErrJwksKeyIDEmpty
    59  	}
    60  
    61  	switch k.KeyType {
    62  	case "RSA":
    63  		switch k.Algorithm {
    64  		case "RS256", "RS384", "RS512", "RSA-OAEP-256", "":
    65  		default:
    66  			return errors.ErrJwksKeyAlgoUnsupported.WithArgs(k.Algorithm, k.KeyID)
    67  		}
    68  	case "EC":
    69  		switch k.Curve {
    70  		case "P-256", "P-384", "P-521":
    71  		case "":
    72  			return errors.ErrJwksKeyCurveEmpty.WithArgs(k.KeyID)
    73  		default:
    74  			return errors.ErrJwksKeyCurveUnsupported.WithArgs(k.Curve, k.KeyID)
    75  		}
    76  		if k.CoordX == "" || k.CoordY == "" {
    77  			return errors.ErrJwksKeyCurveCoordNotFound.WithArgs(k.KeyID)
    78  		}
    79  	case "oct":
    80  		if k.SharedSecret == "" {
    81  			return errors.ErrJwksKeySharedSecretEmpty.WithArgs(k.KeyID)
    82  		}
    83  		switch k.Algorithm {
    84  		case "HS256", "HS384", "HS512", "":
    85  		default:
    86  			return errors.ErrJwksKeyAlgoUnsupported.WithArgs(k.Algorithm, k.KeyID)
    87  		}
    88  	case "":
    89  		return errors.ErrJwksKeyTypeEmpty.WithArgs(k.KeyID)
    90  	default:
    91  		return errors.ErrJwksKeyTypeUnsupported.WithArgs(k.KeyType, k.KeyID)
    92  	}
    93  
    94  	switch k.PublicKeyUse {
    95  	case "sig", "enc", "":
    96  	default:
    97  		return errors.ErrJwksKeyUsageUnsupported.WithArgs(k.PublicKeyUse, k.KeyID)
    98  	}
    99  
   100  	switch k.KeyType {
   101  	case "RSA":
   102  		if k.Exponent == "" {
   103  			return errors.ErrJwksKeyExponentEmpty.WithArgs(k.KeyID)
   104  		}
   105  
   106  		if k.Modulus == "" {
   107  			return errors.ErrJwksKeyModulusEmpty.WithArgs(k.KeyID)
   108  		}
   109  
   110  		// Add padding
   111  		if i := len(k.Modulus) % 4; i != 0 {
   112  			k.Modulus += strings.Repeat("=", 4-i)
   113  		}
   114  
   115  		var mod []byte
   116  		var err error
   117  		if strings.ContainsAny(k.Modulus, "/+") {
   118  			// This decoding works with + and / signs. (legacy)
   119  			mod, err = base64.StdEncoding.DecodeString(k.Modulus)
   120  		} else {
   121  			// This decoding works with - and _ signs.
   122  			mod, err = base64.URLEncoding.DecodeString(k.Modulus)
   123  		}
   124  
   125  		if err != nil {
   126  			return errors.ErrJwksKeyDecodeModulus.WithArgs(k.KeyID, k.Modulus, err)
   127  		}
   128  		n := big.NewInt(0)
   129  		n.SetBytes(mod)
   130  
   131  		exp, err := base64.StdEncoding.DecodeString(k.Exponent)
   132  		if err != nil {
   133  			return errors.ErrJwksKeyDecodeExponent.WithArgs(k.KeyID, err)
   134  		}
   135  		// The "e" (exponent) parameter contains the exponent value for the RSA
   136  		// public key.  It is represented as a Base64urlUInt-encoded value.
   137  		//
   138  		// For instance, when representing the value 65537, the octet sequence
   139  		// to be base64url-encoded MUST consist of the three octets [1, 0, 1];
   140  		// the resulting representation for this value is "AQAB".
   141  		var eb []byte
   142  		if len(exp) < 8 {
   143  			eb = make([]byte, 8-len(exp), 8)
   144  			eb = append(eb, exp...)
   145  		} else {
   146  			eb = exp
   147  		}
   148  		er := bytes.NewReader(eb)
   149  		var e uint64
   150  		if err := binary.Read(er, binary.BigEndian, &e); err != nil {
   151  			return errors.ErrJwksKeyConvExponent.WithArgs(k.KeyID, err)
   152  		}
   153  		k.publicKey = &rsa.PublicKey{N: n, E: int(e)}
   154  	case "EC":
   155  		var expByteCount int
   156  		pk := &ecdsa.PublicKey{}
   157  		switch k.Curve {
   158  		case "P-256":
   159  			pk.Curve = elliptic.P256()
   160  			expByteCount = 32
   161  		case "P-384":
   162  			pk.Curve = elliptic.P384()
   163  			expByteCount = 48
   164  		case "P-521":
   165  			pk.Curve = elliptic.P521()
   166  			expByteCount = 66
   167  		}
   168  
   169  		for i, c := range []string{k.CoordX, k.CoordY} {
   170  			ltr := "X"
   171  			if i > 0 {
   172  				ltr = "Y"
   173  			}
   174  			b, err := base64.RawURLEncoding.DecodeString(c)
   175  			if err != nil {
   176  				return errors.ErrJwksKeyDecodeCoord.WithArgs(k.KeyID, ltr, err)
   177  			}
   178  			if len(b) != expByteCount {
   179  				return errors.ErrJwksKeyCoordLength.WithArgs(k.KeyID, ltr, len(b), expByteCount)
   180  			}
   181  			bi := big.NewInt(0)
   182  			bi.SetBytes(b)
   183  			if i == 0 {
   184  				pk.X = bi
   185  				continue
   186  			}
   187  			pk.Y = bi
   188  		}
   189  		k.publicKey = pk
   190  	case "oct":
   191  		key, err := base64.RawURLEncoding.DecodeString(k.SharedSecret)
   192  		if err != nil {
   193  			return errors.ErrJwksKeyDecodeSharedSecret.WithArgs(k.KeyID, err)
   194  		}
   195  		k.publicKey = key
   196  	default:
   197  		return errors.ErrJwksKeyTypeNotImplemented.WithArgs(k.KeyID, k.KeyType, k)
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  // GetPublic returns pointer to public key.
   204  func (k *JwksKey) GetPublic() interface{} {
   205  	return k.publicKey
   206  }
   207  
   208  func createJwksKeyFromPubKey(pk *rsa.PublicKey) *JwksKey {
   209  	b := make([]byte, 8)
   210  	binary.BigEndian.PutUint64(b, uint64(pk.E))
   211  	i := 0
   212  	for ; i < len(b); i++ {
   213  		if b[i] != 0x0 {
   214  			break
   215  		}
   216  	}
   217  
   218  	return &JwksKey{
   219  		KeyType:      "RSA",
   220  		PublicKeyUse: "sig",
   221  		Exponent:     base64.RawURLEncoding.EncodeToString(b[i:]),
   222  		Modulus:      base64.RawURLEncoding.EncodeToString(pk.N.Bytes()),
   223  	}
   224  }
   225  
   226  // NewJwksKeyFromRSAPrivateKey returns an instance of Jwks from RSA private key.
   227  func NewJwksKeyFromRSAPrivateKey(privKey *rsa.PrivateKey) (*JwksKey, error) {
   228  	if len(privKey.Primes) != 2 {
   229  		return nil, fmt.Errorf("unexpected prime number count: %d", len(privKey.Primes))
   230  	}
   231  
   232  	jk := createJwksKeyFromPubKey(&privKey.PublicKey)
   233  	jk.KeyID = util.GetRandomStringFromRange(26, 32)
   234  	if err := jk.Validate(); err != nil {
   235  		return nil, fmt.Errorf("failed creating jwks key: %v", err)
   236  	}
   237  
   238  	return jk, nil
   239  }
   240  
   241  // NewJwksKeyFromRSAPublicKeyPEM returns an instance of Jwks from RSA public key in PEM format.
   242  func NewJwksKeyFromRSAPublicKeyPEM(kid, fp string) (*JwksKey, error) {
   243  	kb, err := ioutil.ReadFile(fp)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	var block *pem.Block
   249  	if block, _ = pem.Decode(kb); block == nil {
   250  		return nil, errors.ErrNotPEMEncodedKey
   251  	}
   252  
   253  	var pubKey *rsa.PublicKey
   254  
   255  	switch {
   256  	case bytes.Contains(kb, []byte("RSA PUBLIC KEY")):
   257  		pubKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
   258  		if err != nil {
   259  			return nil, err
   260  		}
   261  	case bytes.Contains(kb, []byte("PUBLIC KEY")):
   262  		pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   263  		if err != nil {
   264  			return nil, err
   265  		}
   266  		switch pub := pub.(type) {
   267  		case *rsa.PublicKey:
   268  			pubKey = pub
   269  		default:
   270  			return nil, fmt.Errorf("key payload is not supported public key")
   271  		}
   272  	default:
   273  		return nil, fmt.Errorf("key payload is not RSA public key")
   274  	}
   275  
   276  	jk := createJwksKeyFromPubKey(pubKey)
   277  	jk.KeyID = kid
   278  	if err := jk.Validate(); err != nil {
   279  		return nil, fmt.Errorf("failed creating jwks key: %v", err)
   280  	}
   281  
   282  	return jk, nil
   283  }