github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/internal/runner/secret_store.go (about)

     1  // Copyright 2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     2  
     3  package runner
     4  
     5  import (
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/sha256"
     9  	"crypto/x509"
    10  	"encoding/base64"
    11  	"encoding/pem"
    12  	"io/ioutil"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  
    17  	"github.com/go-stack/stack"
    18  	"github.com/jjeffery/kv"
    19  
    20  	"github.com/awnumar/memguard"
    21  )
    22  
    23  // This file contains the implementation of a credentials structure that has been
    24  // encrypted while in memory.
    25  
    26  var (
    27  	serverSecret = &memguard.Enclave{}
    28  )
    29  
    30  func init() {
    31  	// Safely terminate in case of an interrupt signal
    32  	memguard.CatchInterrupt()
    33  
    34  	// Generate a key sealed inside an encrypted container
    35  	serverSecret = memguard.NewEnclaveRandom(32)
    36  }
    37  
    38  func StopSecret() {
    39  	// Purge the session when we return
    40  	defer memguard.Purge()
    41  }
    42  
    43  type Wrapper struct {
    44  	publicPEM  []byte
    45  	privateKey *rsa.PrivateKey
    46  }
    47  
    48  // KubertesWrapper is used to obtain, if available, the Kubernetes stored encryption
    49  // parameters for the server
    50  func KubernetesWrapper(mountDir string) (w *Wrapper, err kv.Error) {
    51  
    52  	publicPEM, privatePEM, passphrase, err := SSHKeys(
    53  		filepath.Join(mountDir, "encryption"),
    54  		filepath.Join(mountDir, "passphrase"))
    55  
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	return NewWrapper(publicPEM, privatePEM, passphrase)
    61  }
    62  
    63  func SSHKeys(cryptoDir string, passphraseDir string) (publicPEM []byte, privatePEM []byte, passphrase []byte, err kv.Error) {
    64  
    65  	if err = IsAliveK8s(); err != nil {
    66  		return nil, nil, nil, nil
    67  	}
    68  
    69  	// First make sure all the appropriate mounts exist
    70  	info, errGo := os.Stat(cryptoDir)
    71  	if errGo == nil {
    72  		if !info.IsDir() {
    73  			return nil, nil, nil, kv.NewError("not a directory").With("dir", cryptoDir).With("stack", stack.Trace().TrimRuntime())
    74  		}
    75  	} else {
    76  		return nil, nil, nil, kv.Wrap(errGo).With("dir", cryptoDir).With("stack", stack.Trace().TrimRuntime())
    77  	}
    78  	if info, errGo := os.Stat(passphraseDir); errGo == nil {
    79  		if !info.IsDir() {
    80  			return nil, nil, nil, kv.NewError("not a directory").With("dir", passphraseDir).With("stack", stack.Trace().TrimRuntime())
    81  		}
    82  	} else {
    83  		return nil, nil, nil, kv.Wrap(errGo).With("dir", passphraseDir).With("stack", stack.Trace().TrimRuntime())
    84  	}
    85  
    86  	// We have ether directories at least needed to create our secrets, read in the PEMs and passphrase
    87  
    88  	if publicPEM, errGo = ioutil.ReadFile(filepath.Join(cryptoDir, "ssh-publickey")); errGo != nil {
    89  		return nil, nil, nil, kv.Wrap(errGo).With("dir", passphraseDir).With("stack", stack.Trace().TrimRuntime())
    90  	}
    91  	if privatePEM, errGo = ioutil.ReadFile(filepath.Join(cryptoDir, "ssh-privatekey")); errGo != nil {
    92  		return nil, nil, nil, kv.Wrap(errGo).With("dir", passphraseDir).With("stack", stack.Trace().TrimRuntime())
    93  	}
    94  	if passphrase, errGo = ioutil.ReadFile(filepath.Join(passphraseDir, "ssh-passphrase")); errGo != nil {
    95  		return nil, nil, nil, kv.Wrap(errGo).With("dir", passphraseDir).With("stack", stack.Trace().TrimRuntime())
    96  	}
    97  	return publicPEM, privatePEM, passphrase, nil
    98  }
    99  
   100  func NewWrapper(publicPEM []byte, privatePEM []byte, passphrase []byte) (w *Wrapper, err kv.Error) {
   101  
   102  	if len(publicPEM) == 0 {
   103  		return nil, kv.NewError("public PEM not supplied").With("stack", stack.Trace().TrimRuntime())
   104  	}
   105  
   106  	if len(privatePEM) == 0 {
   107  		return nil, kv.NewError("private PEM not supplied").With("stack", stack.Trace().TrimRuntime())
   108  	}
   109  
   110  	if len(passphrase) == 0 {
   111  		return nil, kv.NewError("passphrase not supplied").With("stack", stack.Trace().TrimRuntime())
   112  	}
   113  
   114  	w = &Wrapper{
   115  		publicPEM: publicPEM,
   116  	}
   117  	// Decrypt the RSA encrypted asymmetric key
   118  	prvBlock, _ := pem.Decode(privatePEM)
   119  	if prvBlock == nil {
   120  		return nil, kv.NewError("private PEM not decoded").With("stack", stack.Trace().TrimRuntime())
   121  	}
   122  	if got, want := prvBlock.Type, "RSA PRIVATE KEY"; got != want {
   123  		return nil, kv.NewError("unknown block type").With("got", got, "want", want).With("stack", stack.Trace().TrimRuntime())
   124  	}
   125  
   126  	// TODO Place the enclave handling here
   127  	decryptedBlock, errGo := x509.DecryptPEMBlock(prvBlock, passphrase)
   128  	if errGo != nil {
   129  		return nil, kv.Wrap(errGo).With("phrase", passphrase).With("stack", stack.Trace().TrimRuntime())
   130  	}
   131  
   132  	// TODO Place the enclave handling here
   133  	w.privateKey, errGo = x509.ParsePKCS1PrivateKey(decryptedBlock)
   134  	if errGo != nil {
   135  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   136  	}
   137  	w.privateKey.Precompute()
   138  
   139  	return w, nil
   140  }
   141  
   142  func (w *Wrapper) getPrivateKey() (privateKey *rsa.PrivateKey, err kv.Error) {
   143  	return w.privateKey, nil
   144  }
   145  
   146  func (w *Wrapper) WrapRequest(r *Request) (encrypted string, err kv.Error) {
   147  
   148  	if w == nil {
   149  		return "", kv.NewError("wrapper missing").With("stack", stack.Trace().TrimRuntime())
   150  	}
   151  
   152  	// Check to see if we have a public key
   153  	if len(w.publicPEM) == 0 {
   154  		return "", kv.NewError("public key missing").With("stack", stack.Trace().TrimRuntime())
   155  	}
   156  
   157  	// Serialize the request
   158  	buffer, err := r.Marshal()
   159  	if err != nil {
   160  		return "", err
   161  	}
   162  	pubBlock, _ := pem.Decode(w.publicPEM)
   163  	pub, errGo := x509.ParsePKCS1PublicKey(pubBlock.Bytes)
   164  	if errGo != nil {
   165  		return "", kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   166  	}
   167  
   168  	// encrypt the data and retrieve a symmetric key
   169  	asymKey, asymData, err := EncryptBlock(buffer)
   170  	if err != nil {
   171  		return "", err
   172  	}
   173  	asymDataB64 := base64.StdEncoding.EncodeToString(asymData)
   174  
   175  	// encrypt the symmetric key using the public RSA PEM
   176  	asymEncKey, errGo := rsa.EncryptOAEP(sha256.New(), rand.Reader, pub, asymKey[:], nil)
   177  	if errGo != nil {
   178  		return "", kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   179  	}
   180  	asymKeyB64 := base64.StdEncoding.EncodeToString(asymEncKey)
   181  
   182  	// append the encrypted semtric key, and the symmetrically encrypted data into a BASE64 result
   183  	return asymKeyB64 + "," + asymDataB64, nil
   184  }
   185  
   186  func (w *Wrapper) unwrapRaw(encrypted string) (decrypted []byte, err kv.Error) {
   187  	// Check we have a private key and a passphrase
   188  	if w == nil {
   189  		return nil, kv.NewError("wrapper missing").With("stack", stack.Trace().TrimRuntime())
   190  	}
   191  
   192  	// break off the fixed length symetric but RSA encrypted key using the comma delimiter
   193  	items := strings.Split(encrypted, ",")
   194  	if len(items) > 2 {
   195  		return nil, kv.NewError("too many values in encrypted data").With("stack", stack.Trace().TrimRuntime())
   196  	}
   197  	if len(items) < 2 {
   198  		return nil, kv.NewError("missing values in encrypted data").With("items", items, "stack", stack.Trace().TrimRuntime())
   199  	}
   200  
   201  	asymKeyDecoded, errGo := base64.StdEncoding.DecodeString(items[0])
   202  	if errGo != nil {
   203  		return nil, kv.Wrap(errGo, "asymmetric key bad").With("stack", stack.Trace().TrimRuntime())
   204  	}
   205  	asymBodyDecoded, errGo := base64.StdEncoding.DecodeString(items[1])
   206  	if errGo != nil {
   207  		return nil, kv.Wrap(errGo, "asymmetric encrypted data bad").With("stack", stack.Trace().TrimRuntime())
   208  	}
   209  
   210  	// Decrypt the RSA encrypted asymmetric key
   211  	prvKey, err := w.getPrivateKey()
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	asymSliceKey, errGo := rsa.DecryptOAEP(sha256.New(), rand.Reader, prvKey, asymKeyDecoded, nil)
   216  	if errGo != nil {
   217  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   218  	}
   219  	asymKey := [32]byte{}
   220  	copy(asymKey[:], asymSliceKey[:32])
   221  
   222  	// Decrypt the data using the decrypted asymmetric key
   223  	return DecryptBlock(asymKey, asymBodyDecoded)
   224  }
   225  
   226  func (w *Wrapper) UnwrapRequest(encrypted string) (r *Request, err kv.Error) {
   227  	decryptedBody, err := w.unwrapRaw(encrypted)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	r, errGo := UnmarshalRequest(decryptedBody)
   233  	if errGo != nil {
   234  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   235  	}
   236  
   237  	return r, nil
   238  }
   239  
   240  func (w *Wrapper) Envelope(r *Request) (e *Envelope, err kv.Error) {
   241  	e = &Envelope{
   242  		Message: Message{
   243  			Experiment: OpenExperiment{
   244  				Status:    r.Experiment.Status,
   245  				PythonVer: r.Experiment.PythonVer,
   246  			},
   247  			TimeAdded:          r.Experiment.TimeAdded,
   248  			ExperimentLifetime: r.Config.Lifetime,
   249  			Resource:           r.Experiment.Resource,
   250  		},
   251  	}
   252  
   253  	e.Message.Payload, err = w.WrapRequest(r)
   254  	return e, err
   255  }
   256  
   257  func (w *Wrapper) Request(e *Envelope) (r *Request, err kv.Error) {
   258  	return w.UnwrapRequest(e.Message.Payload)
   259  }