github.com/axw/juju@v0.0.0-20161005053422-4bd6544d08d4/worker/authenticationworker/worker.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package authenticationworker
     5  
     6  import (
     7  	"strings"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/loggo"
    11  	"github.com/juju/utils/os"
    12  	"github.com/juju/utils/set"
    13  	"github.com/juju/utils/ssh"
    14  	"gopkg.in/juju/names.v2"
    15  	"gopkg.in/tomb.v1"
    16  
    17  	"github.com/juju/juju/agent"
    18  	"github.com/juju/juju/api/keyupdater"
    19  	"github.com/juju/juju/watcher"
    20  	"github.com/juju/juju/worker"
    21  )
    22  
    23  // The user name used to ssh into Juju nodes.
    24  // Override for testing.
    25  var SSHUser = "ubuntu"
    26  
    27  var logger = loggo.GetLogger("juju.worker.authenticationworker")
    28  
// keyupdaterWorker is the handler that NewWorker passes to
// watcher.NewNotifyWorker. It keeps SSHUser's authorized_keys file in sync
// with the keys Juju records for the machine identified by tag, while
// preserving entries that were added outside of Juju.
type keyupdaterWorker struct {
	// st is the API facade used to read and watch the machine's authorised
	// keys; tag identifies that machine.
	// NOTE(review): tomb is not referenced anywhere in this file. The
	// NotifyWorker built in NewWorker appears to own the lifecycle; confirm
	// before removing it (and its import).
	st   *keyupdater.State
	tomb tomb.Tomb
	tag  names.MachineTag
	// jujuKeys are the most recently retrieved keys from state, exactly as
	// returned by AuthorisedKeys (before ssh.EnsureJujuComment is applied).
	jujuKeys set.Strings
	// nonJujuKeys are the entries read from the authorized_keys file in SetUp
	// whose comment lacks ssh.JujuCommentPrefix, or which ssh.KeyFingerprint
	// could not parse. They are written back on every rewrite of the file.
	nonJujuKeys []string
}
    39  
    40  // NewWorker returns a worker that keeps track of
    41  // the machine's authorised ssh keys and ensures the
    42  // ~/.ssh/authorized_keys file is up to date.
    43  func NewWorker(st *keyupdater.State, agentConfig agent.Config) (worker.Worker, error) {
    44  	machineTag, ok := agentConfig.Tag().(names.MachineTag)
    45  	if !ok {
    46  		return nil, errors.NotValidf("machine tag %v", agentConfig.Tag())
    47  	}
    48  	if os.HostOS() == os.Windows {
    49  		return worker.NewNoOpWorker(), nil
    50  	}
    51  	w, err := watcher.NewNotifyWorker(watcher.NotifyConfig{
    52  		Handler: &keyupdaterWorker{
    53  			st:  st,
    54  			tag: machineTag,
    55  		},
    56  	})
    57  	if err != nil {
    58  		return nil, errors.Trace(err)
    59  	}
    60  	return w, nil
    61  }
    62  
    63  // SetUp is defined on the worker.NotifyWatchHandler interface.
    64  func (kw *keyupdaterWorker) SetUp() (watcher.NotifyWatcher, error) {
    65  	// Record the keys Juju knows about.
    66  	jujuKeys, err := kw.st.AuthorisedKeys(kw.tag)
    67  	if err != nil {
    68  		err = errors.Annotatef(err, "reading Juju ssh keys for %q", kw.tag)
    69  		logger.Infof(err.Error())
    70  		return nil, err
    71  	}
    72  	kw.jujuKeys = set.NewStrings(jujuKeys...)
    73  
    74  	// Read the keys currently in ~/.ssh/authorised_keys.
    75  	sshKeys, err := ssh.ListKeys(SSHUser, ssh.FullKeys)
    76  	if err != nil {
    77  		err = errors.Annotatef(err, "reading ssh authorized keys for %q", kw.tag)
    78  		logger.Infof(err.Error())
    79  		return nil, err
    80  	}
    81  	// Record any keys not added by Juju.
    82  	for _, key := range sshKeys {
    83  		_, comment, err := ssh.KeyFingerprint(key)
    84  		// Also record keys which we cannot parse.
    85  		if err != nil || !strings.HasPrefix(comment, ssh.JujuCommentPrefix) {
    86  			kw.nonJujuKeys = append(kw.nonJujuKeys, key)
    87  		}
    88  	}
    89  	// Write out the ssh authorised keys file to match the current state of the world.
    90  	if err := kw.writeSSHKeys(jujuKeys); err != nil {
    91  		err = errors.Annotate(err, "adding current Juju keys to ssh authorised keys")
    92  		logger.Infof(err.Error())
    93  		return nil, err
    94  	}
    95  
    96  	w, err := kw.st.WatchAuthorisedKeys(kw.tag)
    97  	if err != nil {
    98  		err = errors.Annotate(err, "starting key updater worker")
    99  		logger.Infof(err.Error())
   100  		return nil, err
   101  	}
   102  	logger.Infof("%q key updater worker started", kw.tag)
   103  	return w, nil
   104  }
   105  
   106  // writeSSHKeys writes out a new ~/.ssh/authorised_keys file, retaining any non Juju keys
   107  // and adding the specified set of Juju keys.
   108  func (kw *keyupdaterWorker) writeSSHKeys(jujuKeys []string) error {
   109  	allKeys := kw.nonJujuKeys
   110  	// Ensure any Juju keys have the required prefix in their comment.
   111  	for i, key := range jujuKeys {
   112  		jujuKeys[i] = ssh.EnsureJujuComment(key)
   113  	}
   114  	allKeys = append(allKeys, jujuKeys...)
   115  	return ssh.ReplaceKeys(SSHUser, allKeys...)
   116  }
   117  
   118  // Handle is defined on the worker.NotifyWatchHandler interface.
   119  func (kw *keyupdaterWorker) Handle(_ <-chan struct{}) error {
   120  	// Read the keys that Juju has.
   121  	newKeys, err := kw.st.AuthorisedKeys(kw.tag)
   122  	if err != nil {
   123  		err = errors.Annotatef(err, "reading Juju ssh keys for %q", kw.tag)
   124  		logger.Infof(err.Error())
   125  		return err
   126  	}
   127  	// Figure out if any keys have been added or deleted.
   128  	newJujuKeys := set.NewStrings(newKeys...)
   129  	deleted := kw.jujuKeys.Difference(newJujuKeys)
   130  	added := newJujuKeys.Difference(kw.jujuKeys)
   131  	if added.Size() > 0 || deleted.Size() > 0 {
   132  		logger.Debugf("adding ssh keys to authorised keys: %v", added)
   133  		logger.Debugf("deleting ssh keys from authorised keys: %v", deleted)
   134  		if err = kw.writeSSHKeys(newKeys); err != nil {
   135  			err = errors.Annotate(err, "updating ssh keys")
   136  			logger.Infof(err.Error())
   137  			return err
   138  		}
   139  	}
   140  	kw.jujuKeys = newJujuKeys
   141  	return nil
   142  }
   143  
   144  // TearDown is defined on the worker.NotifyWatchHandler interface.
   145  func (kw *keyupdaterWorker) TearDown() error {
   146  	// Nothing to do here.
   147  	return nil
   148  }