github.com/code-to-go/safepool.lib@v0.0.0-20221205180519-ee25e63c226e/pool/keystore.go (about)

     1  package pool
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/code-to-go/safepool.lib/core"
     9  	"github.com/code-to-go/safepool.lib/security"
    10  
    11  	"github.com/patrickmn/go-cache"
    12  )
    13  
    14  type Keystore map[uint64][]byte
    15  
    16  var cachedEncKeys = cache.New(time.Hour, 10*time.Hour)
    17  
    18  func (p *Pool) importKeystore(a AccessFile) (Keystore, error) {
    19  	masterKey := p.keyFunc(p.masterKeyId)
    20  	if masterKey == nil {
    21  		return nil, ErrNotAuthorized
    22  	}
    23  
    24  	ks, err := p.unmarshalKeystore(masterKey, a.Nonce, a.Keystore)
    25  	if core.IsErr(err, "cannot unmarshal keystore for pool '%s': %v", p.Name) {
    26  		return nil, err
    27  	}
    28  
    29  	for id, val := range ks {
    30  		err = p.sqlSetKey(id, val)
    31  		if core.IsErr(err, "cannot set key %d to DB for pool '%s': %v", id, p.Name) {
    32  			return nil, err
    33  		}
    34  	}
    35  	return ks, nil
    36  }
    37  
    38  func (p *Pool) marshalKeystore(masterKey []byte, nonce []byte, ks Keystore) ([]byte, error) {
    39  	data, err := json.Marshal(ks)
    40  	if core.IsErr(err, "cannot marshal keystore: %v") {
    41  		return nil, err
    42  	}
    43  	return security.EncryptBlock(masterKey, nonce, data)
    44  }
    45  
    46  func (p *Pool) unmarshalKeystore(masterKey []byte, nonce []byte, cipherdata []byte) (Keystore, error) {
    47  	data, err := security.DecryptBlock(masterKey, nonce, cipherdata)
    48  	if core.IsErr(err, "invalid key or corrupted keystore: %v") {
    49  		return nil, err
    50  	}
    51  
    52  	var ks Keystore
    53  	err = json.Unmarshal(data, &ks)
    54  	return ks, err
    55  }
    56  
    57  func (p *Pool) keyFunc(id uint64) []byte {
    58  	if id == 0 {
    59  		return p.masterKey
    60  	}
    61  
    62  	k := fmt.Sprintf("%s-%d", p.Name, id)
    63  	if v, found := cachedEncKeys.Get(k); found {
    64  		return v.([]byte)
    65  	}
    66  
    67  	v := p.sqlGetKey(id)
    68  	if v != nil {
    69  		cachedEncKeys.Set(k, v, cache.DefaultExpiration)
    70  	}
    71  	return v
    72  }