github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/internal/runner/signing_store.go (about)

     1  // Copyright 2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     2  
     3  package runner
     4  
// This file contains the implementation of a store for the public keys of the
// clients of the system that sign the messages they send across the queue
// infrastructure.
//
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"fmt"
    13  	"os"
    14  	"path/filepath"
    15  	"sort"
    16  	"strings"
    17  	"time"
    18  
    19  	"io/ioutil"
    20  	"sync"
    21  
    22  	"github.com/go-stack/stack"
    23  	"github.com/jjeffery/kv"
    24  	"golang.org/x/crypto/ssh"
    25  )
    26  
    27  type Signatures struct {
    28  	sigs map[string]ssh.PublicKey // The known public keys retrieved from the secrets mount directory
    29  	dir  string                   // Secrets mount directory
    30  	sync.Mutex
    31  }
    32  
    33  type RefreshContext struct {
    34  	ctx    context.Context
    35  	cancel context.CancelFunc
    36  	sync.Mutex
    37  }
    38  
    39  var (
    40  	// signatures contains a map with the index being the prefix of queue names and their public keys
    41  	signatures = Signatures{
    42  		sigs: map[string]ssh.PublicKey{},
    43  	}
    44  
    45  	refreshContext = RefreshContext{}
    46  )
    47  
    48  func init() {
    49  	ctx, cancel := context.WithCancel(context.Background())
    50  	refreshContext = RefreshContext{
    51  		ctx:    ctx,
    52  		cancel: cancel,
    53  	}
    54  }
    55  
    56  func (refresh *RefreshContext) Reset() {
    57  	refresh.Lock()
    58  	defer refresh.Unlock()
    59  
    60  	refresh.cancel()
    61  	refresh.ctx, refresh.cancel = context.WithCancel(context.Background())
    62  }
    63  
    64  func extractPubKey(data []byte) (key ssh.PublicKey, err kv.Error) {
    65  	if !bytes.HasPrefix(data, []byte("ssh-ed25519 ")) {
    66  		return key, kv.NewError("no ssh-ed25519 prefix").With("stack", stack.Trace().TrimRuntime())
    67  	}
    68  
    69  	pub, _, _, _, errGo := ssh.ParseAuthorizedKey(data)
    70  	if errGo != nil {
    71  		return key, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
    72  	}
    73  	if pub.Type() != ssh.KeyAlgoED25519 {
    74  		return key, kv.NewError("not ssh-ed25519").With("stack", stack.Trace().TrimRuntime())
    75  	}
    76  	return pub, nil
    77  }
    78  
    79  func (s *Signatures) update(fn string) (err kv.Error) {
    80  	data, errGo := ioutil.ReadFile(fn)
    81  	if errGo != nil {
    82  		if os.IsNotExist(errGo) {
    83  			s.Lock()
    84  			delete(s.sigs, filepath.Base(fn))
    85  			s.Unlock()
    86  			return nil
    87  		}
    88  		return kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
    89  	}
    90  
    91  	pub, err := extractPubKey(data)
    92  	if err != nil {
    93  		return err.With("filename", fn)
    94  	}
    95  
    96  	s.Lock()
    97  	s.sigs[filepath.Base(fn)] = pub
    98  	s.Unlock()
    99  
   100  	return nil
   101  }
   102  
   103  // getFingerprint can be used to have the fingerprint of a file containing a pem formatted rsa public key.
   104  // A base64 string of the binary finger print will be returned.
   105  //
   106  func getFingerprint(fn string) (fingerprint string, err kv.Error) {
   107  	data, errGo := ioutil.ReadFile(fn)
   108  	if errGo != nil {
   109  		return "", kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
   110  	}
   111  
   112  	key, err := extractPubKey(data)
   113  	if err != nil {
   114  		return "", err.With("filename", fn)
   115  	}
   116  
   117  	return ssh.FingerprintSHA256(key), nil
   118  }
   119  
   120  // GetSignatures returns the signing public key struct for accessing
   121  // methods related to signature selection etc.
   122  //
   123  func GetSignatures() (s *Signatures) {
   124  	return &signatures
   125  }
   126  
   127  // GetSignaturesRefresh will return a context that will be cancelled on
   128  // the next refresh of signatures completing.  This us principally for testing
   129  // at this time
   130  //
   131  func GetSignaturesRefresh() (doneCtx context.Context) {
   132  	refreshContext.Lock()
   133  	defer refreshContext.Unlock()
   134  
   135  	return refreshContext.ctx
   136  }
   137  
   138  // Dir returns the absolute directory path from which signature files are being
   139  // retrieved and used
   140  func (s *Signatures) Dir() (dir string) {
   141  	signatures.Lock()
   142  	defer signatures.Unlock()
   143  
   144  	return signatures.dir
   145  }
   146  
   147  // Get retrieves a signature that has a queue name supplied by the caller
   148  // as an exact match
   149  //
   150  func (s *Signatures) Get(q string) (key ssh.PublicKey, fingerprint string, err kv.Error) {
   151  	s.Lock()
   152  	key, isPresent := s.sigs[q]
   153  	s.Unlock()
   154  
   155  	if !isPresent {
   156  		return nil, "", kv.NewError("not found").With("queue", q).With("stack", stack.Trace().TrimRuntime())
   157  	}
   158  	return key, ssh.FingerprintSHA256(key), nil
   159  }
   160  
   161  // Get retrieves a signature that has a queue name supplied by the caller
   162  // using the longest prefix matched queue name for the supplied queue name
   163  // that can be found.
   164  //
   165  func (s *Signatures) Select(q string) (key ssh.PublicKey, fingerprint string, err kv.Error) {
   166  	// The lock is kept until we are done to ensure once a prefix is matched to its longest length
   167  	// that we still have the public key for it
   168  	s.Lock()
   169  	defer s.Unlock()
   170  	prefixes := make([]string, 0, len(s.sigs))
   171  	for k := range s.sigs {
   172  		prefixes = append(prefixes, k)
   173  	}
   174  	sort.Strings(prefixes)
   175  
   176  	// Start with no valid match as a prefix
   177  	bestMatch := ""
   178  	wouldBeAt := 0
   179  
   180  	// Roll through the sorted prefixes while there is a still a valid signature name prefix of the q (queue)
   181  	// names, stop when the q supplied no longer satisfies the prefix and the one prior would be
   182  	// the shortest signature prefix of the q name.
   183  	for {
   184  		if prefixes[wouldBeAt] == q {
   185  			bestMatch = prefixes[wouldBeAt]
   186  			break
   187  		}
   188  		if strings.HasPrefix(q, prefixes[wouldBeAt]) {
   189  			if len(bestMatch) == 0 || len(bestMatch) < len(prefixes[wouldBeAt]) {
   190  				bestMatch = prefixes[wouldBeAt]
   191  			}
   192  		}
   193  		if wouldBeAt += 1; wouldBeAt >= len(prefixes) {
   194  			break
   195  		}
   196  	}
   197  
   198  	if len(bestMatch) == 0 {
   199  		return nil, "", kv.NewError("not found").With("queue", q).With("stack", stack.Trace().TrimRuntime())
   200  	}
   201  	key = s.sigs[bestMatch]
   202  	return key, ssh.FingerprintSHA256(key), nil
   203  }
   204  
   205  func reportErr(err kv.Error, errorC chan<- kv.Error) {
   206  	if err == nil {
   207  		return
   208  	}
   209  
   210  	// Remove the entry for this function from the stack
   211  	stk := stack.Trace().TrimRuntime()[1:]
   212  
   213  	defer func() {
   214  		_ = recover()
   215  		if err != nil {
   216  			fmt.Println(err.With("stack", stk).Error())
   217  		}
   218  	}()
   219  
   220  	// Try to send the error and backoff to simply printing it if
   221  	// we could not send it to the reporting module
   222  	select {
   223  	case errorC <- err.With("stack", stk):
   224  	case <-time.After(time.Second):
   225  		fmt.Println(err.With("stack", stk).Error())
   226  	}
   227  }
   228  
   229  // InitSignatures is used to initialize a watch for signatures
   230  func InitSignatures(ctx context.Context, configuredDir string, errorC chan<- kv.Error) {
   231  
   232  	dir, errGo := filepath.Abs(configuredDir)
   233  	if errGo != nil {
   234  		reportErr(kv.Wrap(errGo).With("dir", dir), errorC)
   235  	}
   236  
   237  	// Wait until the directory exists and accessed at least once
   238  	updatedEntries, errGo := ioutil.ReadDir(dir)
   239  	// Record the last modified time for the file representing a signature key
   240  	entries := make(map[string]time.Time, len(updatedEntries))
   241  
   242  	// Set the last time an error was reported to more then 15 minutes ago so
   243  	// that the first error is displayed immediately
   244  	lastErrNotify := time.Now().Add(-1 * time.Hour)
   245  
   246  	// Wait until we get at least one good read from the
   247  	// directory being monitored for signatures
   248  	for {
   249  		if errGo == nil {
   250  			break
   251  		}
   252  
   253  		// Only display this particular error
   254  		if time.Since(lastErrNotify) > time.Duration(15*time.Minute) {
   255  			if errGo != nil {
   256  				reportErr(kv.Wrap(errGo).With("dir", dir), errorC)
   257  			}
   258  			lastErrNotify = time.Now()
   259  		}
   260  
   261  		select {
   262  		case <-time.After(10 * time.Second):
   263  			_, errGo = ioutil.ReadDir(dir)
   264  		case <-ctx.Done():
   265  			return
   266  		}
   267  	}
   268  
   269  	// Once we know we have a working signatures storage directory save its location
   270  	// so that test software can inject certificates of their own when running
   271  	// with a production server under test
   272  	signatures.Lock()
   273  	signatures.dir = dir
   274  	signatures.Unlock()
   275  
   276  	// Event loop for the watcher until the server shuts down
   277  	for {
   278  		select {
   279  
   280  		case <-time.After(10 * time.Second):
   281  
   282  			// It is possible that the signatures store is changed during runtime
   283  			// so refresh the location
   284  			signatures.Lock()
   285  			dir = signatures.dir
   286  			signatures.Unlock()
   287  
   288  			// A lookaside collection for checking the presence of directory entries
   289  			// that are no longer found on the disk
   290  			deletionCheck := make(map[string]time.Time, len(entries))
   291  
   292  			if updatedEntries, errGo = ioutil.ReadDir(dir); errGo != nil {
   293  				reportErr(kv.Wrap(errGo).With("dir", dir), errorC)
   294  				continue
   295  			}
   296  
   297  			for _, entry := range updatedEntries {
   298  
   299  				if entry.IsDir() {
   300  					continue
   301  				}
   302  
   303  				if entry.Name()[0] == '.' {
   304  					continue
   305  				}
   306  
   307  				// Symbolic link checking
   308  				if entry.Mode()&os.ModeSymlink != 0 {
   309  					target, errGo := filepath.EvalSymlinks(filepath.Join(dir, entry.Name()))
   310  					if errGo != nil {
   311  						reportErr(kv.Wrap(errGo).With("dir", dir, "target", entry.Name()), errorC)
   312  						continue
   313  					}
   314  					if entry, errGo = os.Stat(target); errGo != nil {
   315  						reportErr(kv.Wrap(errGo).With("dir", dir, "target", entry.Name()), errorC)
   316  						continue
   317  					}
   318  				}
   319  
   320  				curEntry, isPresent := entries[entry.Name()]
   321  				if !isPresent || curEntry.Round(time.Second) != entry.ModTime().Round(time.Second) {
   322  					entries[entry.Name()] = entry.ModTime().Round(time.Second)
   323  					if err := signatures.update(filepath.Join(dir, entry.Name())); err != nil {
   324  						// info is a special file that is used to prevent the secret from not
   325  						// being created by Kubernetes when there are no secrets to be mounted
   326  						if entry.Name() != "info" {
   327  							reportErr(err, errorC)
   328  						}
   329  					}
   330  				}
   331  
   332  				deletionCheck[entry.Name()] = curEntry
   333  			}
   334  			for name := range entries {
   335  				if _, isPresent := deletionCheck[name]; !isPresent {
   336  					// Have the update method check for the presence of the file,
   337  					// it will cleanup if the file is not found
   338  					signatures.update(filepath.Join(dir, name))
   339  					// Now remove the missing from our small lookaside collection
   340  					delete(entries, name)
   341  				}
   342  			}
   343  
   344  			// Signal any waiters that the refresh has been processed and replace the context
   345  			// used for this with a new one that can be waited on by observers
   346  			refreshContext.Reset()
   347  
   348  		case <-ctx.Done():
   349  			return
   350  		}
   351  	}
   352  }