github.com/trustbloc/kms-go@v1.1.2/doc/jose/jwk/jwk.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jwk
     8  
     9  import (
    10  	"crypto/ecdsa"
    11  	"crypto/elliptic"
    12  	"crypto/rsa"
    13  	"crypto/x509"
    14  	"encoding/base64"
    15  	"encoding/json"
    16  	"errors"
    17  	"fmt"
    18  	"math/big"
    19  	"strings"
    20  
    21  	"github.com/btcsuite/btcd/btcec/v2"
    22  	"github.com/go-jose/go-jose/v3"
    23  	"github.com/trustbloc/bbs-signature-go/bbs12381g2pub"
    24  	"golang.org/x/crypto/ed25519"
    25  
    26  	"github.com/trustbloc/kms-go/spi/kms"
    27  	"github.com/trustbloc/kms-go/util/cryptoutil"
    28  )
    29  
    30  const (
    31  	secp256k1Alg   = "ES256K"
    32  	secp256k1Crv   = "secp256k1"
    33  	secp256k1Size  = 32
    34  	bitsPerByte    = 8
    35  	ecKty          = "EC"
    36  	okpKty         = "OKP"
    37  	x25519Crv      = "X25519"
    38  	ed25519Crv     = "Ed25519"
    39  	bls12381G2Crv  = "BLS12381_G2"
    40  	bls12381G2Size = 96
    41  	blsComprPrivSz = 32
    42  )
    43  
    44  // JWK (JSON Web Key) is a JSON data structure that represents a cryptographic key.
    45  type JWK struct {
    46  	jose.JSONWebKey
    47  
    48  	Kty string
    49  	Crv string
    50  }
    51  
    52  // PublicKeyBytes converts a public key to bytes.
    53  // Note: the Public() member function is in go-jose, this means keys not supported by go-jose are not supported using
    54  // j.Public(). Instead use this function to get the public raw bytes.
    55  func (j *JWK) PublicKeyBytes() ([]byte, error) { //nolint:gocyclo
    56  	if j.isBLS12381G2() {
    57  		switch bbsKey := j.Key.(type) {
    58  		case *bbs12381g2pub.PrivateKey:
    59  			return bbsKey.PublicKey().Marshal()
    60  		case *bbs12381g2pub.PublicKey:
    61  			return bbsKey.Marshal()
    62  		}
    63  	}
    64  
    65  	if j.isX25519() {
    66  		x25519Key, ok := j.Key.([]byte)
    67  		if !ok {
    68  			return nil, fmt.Errorf("invalid public key in kid '%s'", j.KeyID)
    69  		}
    70  
    71  		return x25519Key, nil
    72  	}
    73  
    74  	if j.isSecp256k1() {
    75  		var ecPubKey *ecdsa.PublicKey
    76  
    77  		ecPubKey, ok := j.Key.(*ecdsa.PublicKey)
    78  		if !ok {
    79  			ecPubKey = &j.Key.(*ecdsa.PrivateKey).PublicKey
    80  		}
    81  
    82  		x := &btcec.FieldVal{}
    83  		x.SetByteSlice(ecPubKey.X.Bytes())
    84  
    85  		y := &btcec.FieldVal{}
    86  		y.SetByteSlice(ecPubKey.Y.Bytes())
    87  
    88  		pubKey := btcec.NewPublicKey(x, y)
    89  
    90  		return pubKey.SerializeCompressed(), nil
    91  	}
    92  
    93  	switch pubKey := j.Public().Key.(type) {
    94  	case ed25519.PublicKey:
    95  		return pubKey, nil
    96  	case *ecdsa.PublicKey:
    97  		return elliptic.Marshal(pubKey, pubKey.X, pubKey.Y), nil
    98  	case *rsa.PublicKey:
    99  		return x509.MarshalPKCS1PublicKey(pubKey), nil
   100  	default:
   101  		return nil, fmt.Errorf("unsupported public key type in kid '%s'", j.KeyID)
   102  	}
   103  }
   104  
   105  // UnmarshalJSON reads a key from its JSON representation.
   106  func (j *JWK) UnmarshalJSON(jwkBytes []byte) error {
   107  	var key jsonWebKey
   108  
   109  	marshalErr := json.Unmarshal(jwkBytes, &key)
   110  	if marshalErr != nil {
   111  		return fmt.Errorf("unable to read JWK: %w", marshalErr)
   112  	}
   113  
   114  	// nolint: gocritic, nestif
   115  	if isSecp256k1(key.Alg, key.Kty, key.Crv) {
   116  		jwk, err := unmarshalSecp256k1(&key)
   117  		if err != nil {
   118  			return fmt.Errorf("unable to read JWK: %w", err)
   119  		}
   120  
   121  		*j = *jwk
   122  	} else if isBLS12381G2(key.Kty, key.Crv) {
   123  		jwk, err := unmarshalBLS12381G2(&key)
   124  		if err != nil {
   125  			return fmt.Errorf("unable to read BBS+ JWE: %w", err)
   126  		}
   127  
   128  		*j = *jwk
   129  	} else if isX25519(key.Kty, key.Crv) {
   130  		jwk, err := unmarshalX25519(&key)
   131  		if err != nil {
   132  			return fmt.Errorf("unable to read X25519 JWE: %w", err)
   133  		}
   134  
   135  		*j = *jwk
   136  	} else {
   137  		var joseJWK jose.JSONWebKey
   138  
   139  		err := json.Unmarshal(jwkBytes, &joseJWK)
   140  		if err != nil {
   141  			return fmt.Errorf("unable to read jose JWK, %w", err)
   142  		}
   143  
   144  		j.JSONWebKey = joseJWK
   145  	}
   146  
   147  	j.Kty = key.Kty
   148  	j.Crv = key.Crv
   149  
   150  	return nil
   151  }
   152  
   153  // MarshalJSON serializes the given key to its JSON representation.
   154  func (j *JWK) MarshalJSON() ([]byte, error) {
   155  	if j.isSecp256k1() {
   156  		return marshalSecp256k1(j)
   157  	}
   158  
   159  	if j.isX25519() {
   160  		return marshalX25519(j)
   161  	}
   162  
   163  	if j.isBLS12381G2() {
   164  		return marshalBLS12381G2(j)
   165  	}
   166  
   167  	return (&j.JSONWebKey).MarshalJSON()
   168  }
   169  
   170  // KeyType returns the kms KeyType of the JWK, or an error if the JWK is of an unrecognized type.
   171  func (j *JWK) KeyType() (kms.KeyType, error) {
   172  	switch key := j.Key.(type) {
   173  	case ed25519.PublicKey, ed25519.PrivateKey:
   174  		return kms.ED25519Type, nil
   175  	case *bbs12381g2pub.PublicKey, *bbs12381g2pub.PrivateKey:
   176  		return kms.BLS12381G2Type, nil
   177  	case *ecdsa.PublicKey:
   178  		return ecdsaPubKeyType(key)
   179  	case *ecdsa.PrivateKey:
   180  		return ecdsaPubKeyType(&(key.PublicKey))
   181  	case *rsa.PublicKey, *rsa.PrivateKey:
   182  		return kms.RSAPS256Type, nil
   183  	}
   184  
   185  	switch {
   186  	case isX25519(j.Kty, j.Crv):
   187  		return kms.X25519ECDHKWType, nil
   188  	case isEd25519(j.Kty, j.Crv):
   189  		return kms.ED25519Type, nil
   190  	case isSecp256k1(j.Algorithm, j.Kty, j.Crv):
   191  		return kms.ECDSASecp256k1TypeIEEEP1363, nil
   192  	default:
   193  		return "", fmt.Errorf("no keytype recognized for jwk")
   194  	}
   195  }
   196  
   197  func ecdsaPubKeyType(pub *ecdsa.PublicKey) (kms.KeyType, error) {
   198  	switch pub.Curve {
   199  	case btcec.S256():
   200  		return kms.ECDSASecp256k1TypeIEEEP1363, nil
   201  	case elliptic.P256():
   202  		return kms.ECDSAP256TypeIEEEP1363, nil
   203  	case elliptic.P384():
   204  		return kms.ECDSAP384TypeIEEEP1363, nil
   205  	case elliptic.P521():
   206  		return kms.ECDSAP521TypeIEEEP1363, nil
   207  	}
   208  
   209  	return "", fmt.Errorf("no keytype recognized for ecdsa jwk")
   210  }
   211  
   212  func (j *JWK) isX25519() bool {
   213  	switch j.Key.(type) {
   214  	case []byte:
   215  		return isX25519(j.Kty, j.Crv)
   216  	default:
   217  		return false
   218  	}
   219  }
   220  
   221  func (j *JWK) isBLS12381G2() bool {
   222  	switch j.Key.(type) {
   223  	case *bbs12381g2pub.PublicKey, *bbs12381g2pub.PrivateKey:
   224  		return true
   225  	default:
   226  		return false
   227  	}
   228  }
   229  
   230  func (j *JWK) isSecp256k1() bool {
   231  	return isSecp256k1Key(j.Key) || isSecp256k1(j.Algorithm, j.Kty, j.Crv)
   232  }
   233  
   234  func isSecp256k1Key(pubKey interface{}) bool {
   235  	switch key := pubKey.(type) {
   236  	case *ecdsa.PublicKey:
   237  		return key.Curve == btcec.S256()
   238  	case *ecdsa.PrivateKey:
   239  		return key.Curve == btcec.S256()
   240  	default:
   241  		return false
   242  	}
   243  }
   244  
   245  func isX25519(kty, crv string) bool {
   246  	return strings.EqualFold(kty, okpKty) && strings.EqualFold(crv, x25519Crv)
   247  }
   248  
   249  func isEd25519(kty, crv string) bool {
   250  	return strings.EqualFold(kty, okpKty) && strings.EqualFold(crv, ed25519Crv)
   251  }
   252  
   253  func isBLS12381G2(kty, crv string) bool {
   254  	return strings.EqualFold(kty, ecKty) && strings.EqualFold(crv, bls12381G2Crv)
   255  }
   256  
   257  func isSecp256k1(alg, kty, crv string) bool {
   258  	return strings.EqualFold(alg, secp256k1Alg) ||
   259  		(strings.EqualFold(kty, ecKty) && strings.EqualFold(crv, secp256k1Crv))
   260  }
   261  
   262  func unmarshalSecp256k1(jwk *jsonWebKey) (*JWK, error) {
   263  	if jwk.X == nil {
   264  		return nil, ErrInvalidKey
   265  	}
   266  
   267  	if jwk.Y == nil {
   268  		return nil, ErrInvalidKey
   269  	}
   270  
   271  	curve := btcec.S256()
   272  
   273  	if curveSize(curve) != len(jwk.X.data) {
   274  		return nil, ErrInvalidKey
   275  	}
   276  
   277  	if curveSize(curve) != len(jwk.Y.data) {
   278  		return nil, ErrInvalidKey
   279  	}
   280  
   281  	if jwk.D != nil && dSize(curve) != len(jwk.D.data) {
   282  		return nil, ErrInvalidKey
   283  	}
   284  
   285  	x := jwk.X.bigInt()
   286  	y := jwk.Y.bigInt()
   287  
   288  	if !curve.IsOnCurve(x, y) {
   289  		return nil, ErrInvalidKey
   290  	}
   291  
   292  	var key interface{}
   293  
   294  	if jwk.D != nil {
   295  		key = &ecdsa.PrivateKey{
   296  			PublicKey: ecdsa.PublicKey{
   297  				Curve: curve,
   298  				X:     x,
   299  				Y:     y,
   300  			},
   301  			D: jwk.D.bigInt(),
   302  		}
   303  	} else {
   304  		key = &ecdsa.PublicKey{
   305  			Curve: curve,
   306  			X:     x,
   307  			Y:     y,
   308  		}
   309  	}
   310  
   311  	return &JWK{
   312  		JSONWebKey: jose.JSONWebKey{
   313  			Key: key, KeyID: jwk.Kid, Algorithm: jwk.Alg, Use: jwk.Use,
   314  		},
   315  	}, nil
   316  }
   317  
   318  func unmarshalX25519(jwk *jsonWebKey) (*JWK, error) {
   319  	if jwk.X == nil {
   320  		return nil, ErrInvalidKey
   321  	}
   322  
   323  	if len(jwk.X.data) != cryptoutil.Curve25519KeySize {
   324  		return nil, ErrInvalidKey
   325  	}
   326  
   327  	return &JWK{
   328  		JSONWebKey: jose.JSONWebKey{
   329  			Key: jwk.X.data, KeyID: jwk.Kid, Algorithm: jwk.Alg, Use: jwk.Use,
   330  		},
   331  		Crv: jwk.Crv,
   332  		Kty: jwk.Kty,
   333  	}, nil
   334  }
   335  
   336  func marshalX25519(jwk *JWK) ([]byte, error) {
   337  	var raw jsonWebKey
   338  
   339  	key, ok := jwk.Key.([]byte)
   340  	if !ok {
   341  		return nil, errors.New("marshalX25519: invalid key")
   342  	}
   343  
   344  	if len(key) != cryptoutil.Curve25519KeySize {
   345  		return nil, errors.New("marshalX25519: invalid key")
   346  	}
   347  
   348  	raw = jsonWebKey{
   349  		Kty: okpKty,
   350  		Crv: x25519Crv,
   351  		X:   newFixedSizeBuffer(key, cryptoutil.Curve25519KeySize),
   352  	}
   353  
   354  	raw.Kid = jwk.KeyID
   355  	raw.Alg = jwk.Algorithm
   356  	raw.Use = jwk.Use
   357  
   358  	return json.Marshal(raw)
   359  }
   360  
   361  func unmarshalBLS12381G2(jwk *jsonWebKey) (*JWK, error) {
   362  	if jwk.X == nil {
   363  		return nil, ErrInvalidKey
   364  	}
   365  
   366  	if len(jwk.X.data) != bls12381G2Size {
   367  		return nil, ErrInvalidKey
   368  	}
   369  
   370  	if jwk.D != nil && blsComprPrivSz != len(jwk.D.data) {
   371  		return nil, ErrInvalidKey
   372  	}
   373  
   374  	var (
   375  		key interface{}
   376  		err error
   377  	)
   378  
   379  	if jwk.D != nil {
   380  		key, err = bbs12381g2pub.UnmarshalPrivateKey(jwk.D.data)
   381  		if err != nil {
   382  			return nil, fmt.Errorf("jwk invalid private key unmarshal: %w", err)
   383  		}
   384  	} else {
   385  		key, err = bbs12381g2pub.UnmarshalPublicKey(jwk.X.data)
   386  		if err != nil {
   387  			return nil, fmt.Errorf("jwk invalid public key unmarshal: %w", err)
   388  		}
   389  	}
   390  
   391  	return &JWK{
   392  		JSONWebKey: jose.JSONWebKey{
   393  			Key: key, KeyID: jwk.Kid, Algorithm: jwk.Alg, Use: jwk.Use,
   394  		},
   395  		Crv: jwk.Crv,
   396  		Kty: jwk.Kty,
   397  	}, nil
   398  }
   399  
   400  func marshalBLS12381G2(jwk *JWK) ([]byte, error) {
   401  	var raw jsonWebKey
   402  
   403  	switch key := jwk.Key.(type) {
   404  	case *bbs12381g2pub.PublicKey:
   405  		mKey, err := key.Marshal()
   406  		if err != nil {
   407  			return nil, err
   408  		}
   409  
   410  		if len(mKey) != bls12381G2Size {
   411  			return nil, errors.New("marshal BBS public key: invalid key")
   412  		}
   413  
   414  		raw = jsonWebKey{
   415  			Kty: ecKty,
   416  			Crv: bls12381G2Crv,
   417  			X:   newFixedSizeBuffer(mKey, bls12381G2Size),
   418  		}
   419  	case *bbs12381g2pub.PrivateKey:
   420  		mPubKey, err := key.PublicKey().Marshal()
   421  		if err != nil {
   422  			return nil, err
   423  		}
   424  
   425  		if len(mPubKey) != bls12381G2Size {
   426  			return nil, errors.New("marshal BBS public key: invalid key")
   427  		}
   428  
   429  		mPrivKey, err := key.Marshal()
   430  		if err != nil {
   431  			return nil, err
   432  		}
   433  
   434  		if len(mPrivKey) != blsComprPrivSz {
   435  			return nil, errors.New("marshal BBS private key: invalid key")
   436  		}
   437  
   438  		raw = jsonWebKey{
   439  			Kty: ecKty,
   440  			Crv: bls12381G2Crv,
   441  			X:   newFixedSizeBuffer(mPubKey, bls12381G2Size),
   442  			D:   newFixedSizeBuffer(mPrivKey, blsComprPrivSz),
   443  		}
   444  	default:
   445  		return nil, errors.New("marshalBLS12381G2: invalid key")
   446  	}
   447  
   448  	raw.Kid = jwk.KeyID
   449  	raw.Alg = jwk.Algorithm
   450  	raw.Use = jwk.Use
   451  
   452  	return json.Marshal(raw)
   453  }
   454  
   455  func marshalSecp256k1(jwk *JWK) ([]byte, error) {
   456  	var raw jsonWebKey
   457  
   458  	switch ecdsaKey := jwk.Key.(type) {
   459  	case *ecdsa.PublicKey:
   460  		raw = jsonWebKey{
   461  			Kty: ecKty,
   462  			Crv: secp256k1Crv,
   463  			X:   newFixedSizeBuffer(ecdsaKey.X.Bytes(), secp256k1Size),
   464  			Y:   newFixedSizeBuffer(ecdsaKey.Y.Bytes(), secp256k1Size),
   465  		}
   466  
   467  	case *ecdsa.PrivateKey:
   468  		raw = jsonWebKey{
   469  			Kty: ecKty,
   470  			Crv: secp256k1Crv,
   471  			X:   newFixedSizeBuffer(ecdsaKey.X.Bytes(), secp256k1Size),
   472  			Y:   newFixedSizeBuffer(ecdsaKey.Y.Bytes(), secp256k1Size),
   473  			D:   newFixedSizeBuffer(ecdsaKey.D.Bytes(), dSize(ecdsaKey.Curve)),
   474  		}
   475  	}
   476  
   477  	raw.Kid = jwk.KeyID
   478  	raw.Alg = jwk.Algorithm
   479  	raw.Use = jwk.Use
   480  
   481  	return json.Marshal(raw)
   482  }
   483  
   484  // jsonWebKey contains subset of json web key json properties.
   485  type jsonWebKey struct {
   486  	Use string `json:"use,omitempty"`
   487  	Kty string `json:"kty,omitempty"`
   488  	Kid string `json:"kid,omitempty"`
   489  	Crv string `json:"crv,omitempty"`
   490  	Alg string `json:"alg,omitempty"`
   491  
   492  	X *byteBuffer `json:"x,omitempty"`
   493  	Y *byteBuffer `json:"y,omitempty"`
   494  
   495  	D *byteBuffer `json:"d,omitempty"`
   496  }
   497  
   498  // Get size of curve in bytes.
   499  func curveSize(crv elliptic.Curve) int {
   500  	bits := crv.Params().BitSize
   501  
   502  	div := bits / bitsPerByte
   503  	mod := bits % bitsPerByte
   504  
   505  	if mod == 0 {
   506  		return div
   507  	}
   508  
   509  	return div + 1
   510  }
   511  
   512  func dSize(curve elliptic.Curve) int {
   513  	order := curve.Params().P
   514  	bitLen := order.BitLen()
   515  	size := bitLen / bitsPerByte
   516  
   517  	if bitLen%bitsPerByte != 0 {
   518  		size++
   519  	}
   520  
   521  	return size
   522  }
   523  
   524  // byteBuffer represents a slice of bytes that can be serialized to url-safe base64.
   525  type byteBuffer struct {
   526  	data []byte
   527  }
   528  
   529  func (b *byteBuffer) UnmarshalJSON(data []byte) error {
   530  	var encoded string
   531  
   532  	err := json.Unmarshal(data, &encoded)
   533  	if err != nil {
   534  		return err
   535  	}
   536  
   537  	if encoded == "" {
   538  		return nil
   539  	}
   540  
   541  	decoded, err := base64.RawURLEncoding.DecodeString(encoded)
   542  	if err != nil {
   543  		return err
   544  	}
   545  
   546  	*b = byteBuffer{
   547  		data: decoded,
   548  	}
   549  
   550  	return nil
   551  }
   552  
   553  func (b *byteBuffer) MarshalJSON() ([]byte, error) {
   554  	return json.Marshal(b.base64())
   555  }
   556  
   557  func (b *byteBuffer) base64() string {
   558  	return base64.RawURLEncoding.EncodeToString(b.data)
   559  }
   560  
   561  func (b byteBuffer) bigInt() *big.Int {
   562  	return new(big.Int).SetBytes(b.data)
   563  }
   564  
   565  func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
   566  	paddedData := make([]byte, length-len(data))
   567  
   568  	return &byteBuffer{
   569  		data: append(paddedData, data...),
   570  	}
   571  }
   572  
   573  // ErrInvalidKey is returned when passed JWK is invalid.
   574  var ErrInvalidKey = errors.New("invalid JWK")