github.com/xmidt-org/webpa-common@v1.11.9/secure/tools/cmd/keyserver/keyStore.go (about)

     1  package main
     2  
import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"log"
	"sort"

	"github.com/xmidt-org/webpa-common/secure/key"
)
    14  
    15  // KeyStore provides a single access point for a set of keys, keyed by their key identifiers
    16  // or kid values in JWTs.
    17  type KeyStore struct {
    18  	keyIDs      []string
    19  	privateKeys map[string]*rsa.PrivateKey
    20  	publicKeys  map[string][]byte
    21  }
    22  
    23  func (ks *KeyStore) Len() int {
    24  	return len(ks.keyIDs)
    25  }
    26  
    27  func (ks *KeyStore) KeyIDs() []string {
    28  	return ks.keyIDs
    29  }
    30  
    31  func (ks *KeyStore) PrivateKey(keyID string) (privateKey *rsa.PrivateKey, ok bool) {
    32  	privateKey, ok = ks.privateKeys[keyID]
    33  	return
    34  }
    35  
    36  func (ks *KeyStore) PublicKey(keyID string) (data []byte, ok bool) {
    37  	data, ok = ks.publicKeys[keyID]
    38  	return
    39  }
    40  
    41  // NewKeyStore exchanges a Configuration for a KeyStore.
    42  func NewKeyStore(infoLogger *log.Logger, c *Configuration) (*KeyStore, error) {
    43  	if err := c.Validate(); err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	privateKeys := make(map[string]*rsa.PrivateKey, len(c.Keys)+len(c.Generate))
    48  	if err := resolveKeys(infoLogger, c, privateKeys); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	if err := generateKeys(infoLogger, c, privateKeys); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	publicKeys := make(map[string][]byte, len(privateKeys))
    57  	if err := marshalPublicKeys(publicKeys, privateKeys); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	keyIDs := make([]string, 0, len(privateKeys))
    62  	for keyID := range privateKeys {
    63  		keyIDs = append(keyIDs, keyID)
    64  	}
    65  
    66  	return &KeyStore{
    67  		keyIDs:      keyIDs,
    68  		privateKeys: privateKeys,
    69  		publicKeys:  publicKeys,
    70  	}, nil
    71  }
    72  
    73  func resolveKeys(infoLogger *log.Logger, c *Configuration, privateKeys map[string]*rsa.PrivateKey) error {
    74  	for keyID, resourceFactory := range c.Keys {
    75  		infoLogger.Printf("Key [%s]: loading from resource %#v\n", keyID, resourceFactory)
    76  
    77  		keyResolver, err := (&key.ResolverFactory{
    78  			Factory: *resourceFactory,
    79  			Purpose: key.PurposeSign,
    80  		}).NewResolver()
    81  
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		resolvedPair, err := keyResolver.ResolveKey(keyID)
    87  		if err != nil {
    88  			return err
    89  		}
    90  
    91  		if resolvedPair.HasPrivate() {
    92  			privateKeys[keyID] = resolvedPair.Private().(*rsa.PrivateKey)
    93  		} else {
    94  			return fmt.Errorf("The key %s did not resolve to an RSA private key", keyID)
    95  		}
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func generateKeys(infoLogger *log.Logger, c *Configuration, privateKeys map[string]*rsa.PrivateKey) error {
   102  	bits := c.Bits
   103  	if bits < 1 {
   104  		bits = DefaultBits
   105  	}
   106  
   107  	for _, keyID := range c.Generate {
   108  		infoLogger.Printf("Key [%s]: generating ...", keyID)
   109  
   110  		if generatedKey, err := rsa.GenerateKey(rand.Reader, bits); err == nil {
   111  			privateKeys[keyID] = generatedKey
   112  		} else {
   113  			return err
   114  		}
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func marshalPublicKeys(publicKeys map[string][]byte, privateKeys map[string]*rsa.PrivateKey) error {
   121  	for keyID, privateKey := range privateKeys {
   122  		derBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public())
   123  		if err != nil {
   124  			return err
   125  		}
   126  
   127  		block := pem.Block{
   128  			Type:  "PUBLIC KEY",
   129  			Bytes: derBytes,
   130  		}
   131  
   132  		var buffer bytes.Buffer
   133  		err = pem.Encode(&buffer, &block)
   134  		if err != nil {
   135  			return err
   136  		}
   137  
   138  		publicKeys[keyID] = buffer.Bytes()
   139  	}
   140  
   141  	return nil
   142  }