github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/crypto/keys/testonly/keys.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testonly
    16  
    17  import (
    18  	"crypto"
    19  	"crypto/ecdsa"
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"crypto/sha256"
    23  	"crypto/x509"
    24  	"encoding/asn1"
    25  	"errors"
    26  	"fmt"
    27  	"math/big"
    28  
    29  	"github.com/google/trillian/crypto/keys"
    30  	"github.com/google/trillian/crypto/keys/der"
    31  	"github.com/google/trillian/crypto/keys/pem"
    32  	"github.com/google/trillian/crypto/keyspb"
    33  )
    34  
    35  // MustMarshalPublicPEMToDER reads a PEM-encoded public key and returns it in DER encoding.
    36  // If an error occurs, it panics.
    37  func MustMarshalPublicPEMToDER(keyPEM string) []byte {
    38  	key, err := pem.UnmarshalPublicKey(keyPEM)
    39  	if err != nil {
    40  		panic(err)
    41  	}
    42  
    43  	keyDER, err := x509.MarshalPKIXPublicKey(key)
    44  	if err != nil {
    45  		panic(err)
    46  	}
    47  	return keyDER
    48  }
    49  
    50  // MustMarshalPrivatePEMToDER decrypts a PEM-encoded private key and returns it in DER encoding.
    51  // If an error occurs, it panics.
    52  func MustMarshalPrivatePEMToDER(keyPEM, password string) []byte {
    53  	key, err := pem.UnmarshalPrivateKey(keyPEM, password)
    54  	if err != nil {
    55  		panic(err)
    56  	}
    57  
    58  	keyDER, err := der.MarshalPrivateKey(key)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	return keyDER
    63  }
    64  
    65  // SignAndVerify exercises a signer by using it to generate a signature, and
    66  // then verifies that this signature is correct.
    67  func SignAndVerify(signer crypto.Signer, pubKey crypto.PublicKey) error {
    68  	hasher := crypto.SHA256
    69  	digest := sha256.Sum256([]byte("test"))
    70  	signature, err := signer.Sign(rand.Reader, digest[:], hasher)
    71  	if err != nil {
    72  		return err
    73  	}
    74  
    75  	switch pubKey := pubKey.(type) {
    76  	case *ecdsa.PublicKey:
    77  		return verifyECDSA(pubKey, digest[:], signature)
    78  	case *rsa.PublicKey:
    79  		return verifyRSA(pubKey, digest[:], signature, hasher, hasher)
    80  	default:
    81  		return fmt.Errorf("unknown public key type: %T", pubKey)
    82  	}
    83  }
    84  
    85  func verifyECDSA(pubKey *ecdsa.PublicKey, digest, sig []byte) error {
    86  	var ecdsaSig struct {
    87  		R, S *big.Int
    88  	}
    89  
    90  	rest, err := asn1.Unmarshal(sig, &ecdsaSig)
    91  	if err != nil {
    92  		return err
    93  	}
    94  	if len(rest) != 0 {
    95  		return fmt.Errorf("ECDSA signature %v bytes longer than expected", len(rest))
    96  	}
    97  
    98  	if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
    99  		return errors.New("ECDSA signature failed verification")
   100  	}
   101  	return nil
   102  }
   103  
   104  func verifyRSA(pubKey *rsa.PublicKey, digest, sig []byte, hasher crypto.Hash, opts crypto.SignerOpts) error {
   105  	if pssOpts, ok := opts.(*rsa.PSSOptions); ok {
   106  		return rsa.VerifyPSS(pubKey, hasher, digest, sig, pssOpts)
   107  	}
   108  	return rsa.VerifyPKCS1v15(pubKey, hasher, digest, sig)
   109  }
   110  
   111  // CheckKeyMatchesSpec verifies that the key conforms to the specification.
   112  // If it does not, an error is returned.
   113  func CheckKeyMatchesSpec(key crypto.PrivateKey, spec *keyspb.Specification) error {
   114  	switch params := spec.Params.(type) {
   115  	case *keyspb.Specification_EcdsaParams:
   116  		if key, ok := key.(*ecdsa.PrivateKey); ok {
   117  			return checkEcdsaKeyMatchesParams(key, params.EcdsaParams)
   118  		}
   119  		return fmt.Errorf("%T, want *ecdsa.PrivateKey", key)
   120  	case *keyspb.Specification_RsaParams:
   121  		if key, ok := key.(*rsa.PrivateKey); ok {
   122  			return checkRsaKeyMatchesParams(key, params.RsaParams)
   123  		}
   124  		return fmt.Errorf("%T, want *rsa.PrivateKey", key)
   125  	}
   126  
   127  	return fmt.Errorf("%T is not a supported keyspb.Specification.Params type", spec.Params)
   128  }
   129  
   130  func checkEcdsaKeyMatchesParams(key *ecdsa.PrivateKey, params *keyspb.Specification_ECDSA) error {
   131  	wantCurve := keys.ECDSACurveFromParams(params)
   132  	if wantCurve.Params().Name != key.Params().Name {
   133  		return fmt.Errorf("ECDSA key on %v curve, want %v curve", key.Params().Name, wantCurve.Params().Name)
   134  	}
   135  
   136  	return nil
   137  }
   138  
   139  func checkRsaKeyMatchesParams(key *rsa.PrivateKey, params *keyspb.Specification_RSA) error {
   140  	wantBits := keys.DefaultRsaKeySizeInBits
   141  	if params.GetBits() != 0 {
   142  		wantBits = int(params.GetBits())
   143  	}
   144  
   145  	if got, want := key.N.BitLen(), wantBits; got != want {
   146  		return fmt.Errorf("%v-bit RSA key, want %v-bit", got, want)
   147  	}
   148  
   149  	return nil
   150  }