github.com/lestrrat-go/jwx/v2@v2.0.21/jws/ecdsa.go (about)

     1  package jws
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/ecdsa"
     6  	"crypto/rand"
     7  	"encoding/asn1"
     8  	"fmt"
     9  	"math/big"
    10  
    11  	"github.com/lestrrat-go/jwx/v2/internal/ecutil"
    12  	"github.com/lestrrat-go/jwx/v2/internal/keyconv"
    13  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    14  	"github.com/lestrrat-go/jwx/v2/jwa"
    15  )
    16  
    17  var ecdsaSigners map[jwa.SignatureAlgorithm]*ecdsaSigner
    18  var ecdsaVerifiers map[jwa.SignatureAlgorithm]*ecdsaVerifier
    19  
    20  func init() {
    21  	algs := map[jwa.SignatureAlgorithm]crypto.Hash{
    22  		jwa.ES256:  crypto.SHA256,
    23  		jwa.ES384:  crypto.SHA384,
    24  		jwa.ES512:  crypto.SHA512,
    25  		jwa.ES256K: crypto.SHA256,
    26  	}
    27  	ecdsaSigners = make(map[jwa.SignatureAlgorithm]*ecdsaSigner)
    28  	ecdsaVerifiers = make(map[jwa.SignatureAlgorithm]*ecdsaVerifier)
    29  
    30  	for alg, hash := range algs {
    31  		ecdsaSigners[alg] = &ecdsaSigner{
    32  			alg:  alg,
    33  			hash: hash,
    34  		}
    35  		ecdsaVerifiers[alg] = &ecdsaVerifier{
    36  			alg:  alg,
    37  			hash: hash,
    38  		}
    39  	}
    40  }
    41  
    42  func newECDSASigner(alg jwa.SignatureAlgorithm) Signer {
    43  	return ecdsaSigners[alg]
    44  }
    45  
    46  // ecdsaSigners are immutable.
    47  type ecdsaSigner struct {
    48  	alg  jwa.SignatureAlgorithm
    49  	hash crypto.Hash
    50  }
    51  
    52  func (es ecdsaSigner) Algorithm() jwa.SignatureAlgorithm {
    53  	return es.alg
    54  }
    55  
    56  func (es *ecdsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
    57  	if key == nil {
    58  		return nil, fmt.Errorf(`missing private key while signing payload`)
    59  	}
    60  
    61  	h := es.hash.New()
    62  	if _, err := h.Write(payload); err != nil {
    63  		return nil, fmt.Errorf(`failed to write payload using ecdsa: %w`, err)
    64  	}
    65  
    66  	signer, ok := key.(crypto.Signer)
    67  	if ok {
    68  		if !isValidECDSAKey(key) {
    69  			return nil, fmt.Errorf(`cannot use key of type %T to generate ECDSA based signatures`, key)
    70  		}
    71  		switch key.(type) {
    72  		case ecdsa.PrivateKey, *ecdsa.PrivateKey:
    73  			// if it's a ecdsa.PrivateKey, it's more efficient to
    74  			// go through the non-crypto.Signer route. Set ok to false
    75  			ok = false
    76  		}
    77  	}
    78  
    79  	var r, s *big.Int
    80  	var curveBits int
    81  	if ok {
    82  		signed, err := signer.Sign(rand.Reader, h.Sum(nil), es.hash)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  
    87  		var p struct {
    88  			R *big.Int
    89  			S *big.Int
    90  		}
    91  		if _, err := asn1.Unmarshal(signed, &p); err != nil {
    92  			return nil, fmt.Errorf(`failed to unmarshal ASN1 encoded signature: %w`, err)
    93  		}
    94  
    95  		// Okay, this is silly, but hear me out. When we use the
    96  		// crypto.Signer interface, the PrivateKey is hidden.
    97  		// But we need some information about the key (it's bit size).
    98  		//
    99  		// So while silly, we're going to have to make another call
   100  		// here and fetch the Public key.
   101  		// This probably means that this should be cached some where.
   102  		cpub := signer.Public()
   103  		pubkey, ok := cpub.(*ecdsa.PublicKey)
   104  		if !ok {
   105  			return nil, fmt.Errorf(`expected *ecdsa.PublicKey, got %T`, pubkey)
   106  		}
   107  		curveBits = pubkey.Curve.Params().BitSize
   108  
   109  		r = p.R
   110  		s = p.S
   111  	} else {
   112  		var privkey ecdsa.PrivateKey
   113  		if err := keyconv.ECDSAPrivateKey(&privkey, key); err != nil {
   114  			return nil, fmt.Errorf(`failed to retrieve ecdsa.PrivateKey out of %T: %w`, key, err)
   115  		}
   116  		curveBits = privkey.Curve.Params().BitSize
   117  		rtmp, stmp, err := ecdsa.Sign(rand.Reader, &privkey, h.Sum(nil))
   118  		if err != nil {
   119  			return nil, fmt.Errorf(`failed to sign payload using ecdsa: %w`, err)
   120  		}
   121  		r = rtmp
   122  		s = stmp
   123  	}
   124  
   125  	keyBytes := curveBits / 8
   126  	// Curve bits do not need to be a multiple of 8.
   127  	if curveBits%8 > 0 {
   128  		keyBytes++
   129  	}
   130  
   131  	rBytes := r.Bytes()
   132  	rBytesPadded := make([]byte, keyBytes)
   133  	copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
   134  
   135  	sBytes := s.Bytes()
   136  	sBytesPadded := make([]byte, keyBytes)
   137  	copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
   138  
   139  	out := append(rBytesPadded, sBytesPadded...)
   140  
   141  	return out, nil
   142  }
   143  
   144  // ecdsaVerifiers are immutable.
   145  type ecdsaVerifier struct {
   146  	alg  jwa.SignatureAlgorithm
   147  	hash crypto.Hash
   148  }
   149  
   150  func newECDSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
   151  	return ecdsaVerifiers[alg]
   152  }
   153  
   154  func (v ecdsaVerifier) Algorithm() jwa.SignatureAlgorithm {
   155  	return v.alg
   156  }
   157  
   158  func (v *ecdsaVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
   159  	if key == nil {
   160  		return fmt.Errorf(`missing public key while verifying payload`)
   161  	}
   162  
   163  	var pubkey ecdsa.PublicKey
   164  	if cs, ok := key.(crypto.Signer); ok {
   165  		cpub := cs.Public()
   166  		switch cpub := cpub.(type) {
   167  		case ecdsa.PublicKey:
   168  			pubkey = cpub
   169  		case *ecdsa.PublicKey:
   170  			pubkey = *cpub
   171  		default:
   172  			return fmt.Errorf(`failed to retrieve ecdsa.PublicKey out of crypto.Signer %T`, key)
   173  		}
   174  	} else {
   175  		if err := keyconv.ECDSAPublicKey(&pubkey, key); err != nil {
   176  			return fmt.Errorf(`failed to retrieve ecdsa.PublicKey out of %T: %w`, key, err)
   177  		}
   178  	}
   179  
   180  	if !pubkey.Curve.IsOnCurve(pubkey.X, pubkey.Y) {
   181  		return fmt.Errorf(`public key used does not contain a point (X,Y) on the curve`)
   182  	}
   183  
   184  	r := pool.GetBigInt()
   185  	s := pool.GetBigInt()
   186  	defer pool.ReleaseBigInt(r)
   187  	defer pool.ReleaseBigInt(s)
   188  
   189  	keySize := ecutil.CalculateKeySize(pubkey.Curve)
   190  	if len(signature) != keySize*2 {
   191  		return fmt.Errorf(`invalid signature length for curve %q`, pubkey.Curve.Params().Name)
   192  	}
   193  
   194  	r.SetBytes(signature[:keySize])
   195  	s.SetBytes(signature[keySize:])
   196  
   197  	h := v.hash.New()
   198  	if _, err := h.Write(payload); err != nil {
   199  		return fmt.Errorf(`failed to write payload using ecdsa: %w`, err)
   200  	}
   201  
   202  	if !ecdsa.Verify(&pubkey, h.Sum(nil), r, s) {
   203  		return fmt.Errorf(`failed to verify signature using ecdsa`)
   204  	}
   205  	return nil
   206  }