github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/network/ssh/reachable.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"fmt"
     8  	"net"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/juju/collections/set"
    13  	"github.com/juju/errors"
    14  	"github.com/juju/loggo"
    15  	"golang.org/x/crypto/ssh"
    16  
    17  	"github.com/juju/juju/network"
    18  )
    19  
// logger is the package-level logger, registered under the name
// "juju.network.ssh".
var logger = loggo.GetLogger("juju.network.ssh")
    21  
// Dialer defines a Dial() method matching the signature of net.Dial().
// It is the seam through which the host key checks in this package open
// their underlying connections (they always dial with network "tcp").
type Dialer interface {
	// Dial opens a connection to address over the named network and returns
	// the resulting net.Conn, or an error if the connection could not be
	// established.
	Dial(network, address string) (net.Conn, error)
}
    26  
// ReachableChecker tries to find ssh hosts that have a public key that matches
// our expectations.
type ReachableChecker interface {
	// FindHost tries to connect to all of the host+port combinations supplied,
	// and tries to do an SSH key negotiation. The first successful negotiation
	// that includes one of the public keys supplied will be returned. If none
	// of them can be validated, then an error will be returned.
	//
	// NOTE(review): the reachableChecker implementation in this file accepts
	// any host key when the set of usable public keys is empty.
	FindHost(hostPorts []network.HostPort, publicKeys []string) (network.HostPort, error)
}
    36  
    37  // NewReachableChecker creates a ReachableChecker that can be used to check for
    38  // Hosts that are viable SSH targets.
    39  // When FindHost is called, we will dial the entries in the given hostPorts, in
    40  // parallel, using the given dialer, closing successfully established
    41  // connections after checking the ssh key. Individual connection errors are
    42  // discarded, and an error is returned only if none of the hostPorts can be
    43  // reached when the given timeout expires.
    44  // If publicKeys is a non empty list, then the SSH host public key will be
    45  // checked. If it is not in the list, that host is not considered valid.
    46  //
    47  // Usually, a net.Dialer initialized with a non-empty Timeout field is passed
    48  // for dialer.
    49  func NewReachableChecker(dialer Dialer, timeout time.Duration) *reachableChecker {
    50  	return &reachableChecker{
    51  		dialer:  dialer,
    52  		timeout: timeout,
    53  	}
    54  }
    55  
// hostKeyChecker checks whether a single host presents one of the allowed
// public keys. It plugs into golang.org/x/crypto/ssh as a HostKeyCallback to
// observe the host key offered on a given connection.
type hostKeyChecker struct {

	// AcceptedKeys is a set of the Marshalled PublicKey content. When the
	// set is empty, hostKeyCallback accepts any host key.
	AcceptedKeys set.Strings

	// Stop is polled to find out whether we should stop trying to do any
	// work; it also unblocks sends on Accepted and Finished.
	Stop <-chan struct{}

	// HostPort is the identifier that corresponds to this connection
	HostPort network.HostPort

	// Accepted will be passed HostPort if it validated the connection
	Accepted chan network.HostPort

	// Dialer is a Dialer that allows us to initiate the underlying TCP connection
	Dialer Dialer

	// Finished is sent an event by Check when the check has ended without
	// this host being accepted. hostKeyCallback sets it to nil once the host
	// has been reported on Accepted, so that no event is sent in that case.
	Finished chan struct{}
}
    79  
    80  var hostKeyNotInList = errors.New("host key not in expected set")
    81  var hostKeyAccepted = errors.New("host key was accepted, retry")
    82  var hostKeyAcceptedButStopped = errors.New("host key was accepted, but search was stopped")
    83  
    84  func (h *hostKeyChecker) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
    85  	// Note: we don't do any advanced checking of the PublicKey, like whether
    86  	// the key is revoked or expired. All we care about is whether it matches
    87  	// the public keys that we consider acceptable
    88  	authKeyForm := ssh.MarshalAuthorizedKey(key)
    89  	debugName := hostname
    90  	if hostname != remote.String() {
    91  		debugName = fmt.Sprintf("%s at %s", hostname, remote.String())
    92  	}
    93  	logger.Tracef("checking host key for %s, with key %q", debugName, authKeyForm)
    94  
    95  	lookupKey := string(key.Marshal())
    96  	if len(h.AcceptedKeys) == 0 || h.AcceptedKeys.Contains(lookupKey) {
    97  		logger.Debugf("accepted host key for: %s", debugName)
    98  		// This key was valid, so return it, but if someone else was found
    99  		// first, still exit.
   100  		select {
   101  		case h.Accepted <- h.HostPort:
   102  			// We have accepted a host, we won't need to call Finished.
   103  			h.Finished = nil
   104  			return hostKeyAccepted
   105  		case <-h.Stop:
   106  			return hostKeyAcceptedButStopped
   107  		}
   108  	}
   109  	logger.Debugf("host key for %s not in our accepted set: log at TRACE to see raw keys", debugName)
   110  	return hostKeyNotInList
   111  }
   112  
   113  // publicKeysToSet converts all the public key values (eg id_rsa.pub) into
   114  // their short hash form. Problems with a key are logged at Warning level, but
   115  // otherwise ignored.
   116  func publicKeysToSet(publicKeys []string) set.Strings {
   117  	acceptedKeys := set.NewStrings()
   118  	for _, pubKey := range publicKeys {
   119  		// key, comment, options, rest, err
   120  		sshKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
   121  		if err != nil {
   122  			logger.Warningf("unable to handle public key: %q\n", pubKey)
   123  			continue
   124  		}
   125  		acceptedKeys.Add(string(sshKey.Marshal()))
   126  	}
   127  	return acceptedKeys
   128  }
   129  
   130  // Check initiates a connection to HostPort and tries to do an SSH key
   131  // exchange to determine the preferred public key of the remote host.
   132  // It then checks if that key is in the accepted set of keys.
   133  func (h *hostKeyChecker) Check() {
   134  	defer func() {
   135  		// send a finished message unless we're already stopped and nobody
   136  		// is listening
   137  		if h.Finished != nil {
   138  			select {
   139  			case h.Finished <- struct{}{}:
   140  			case <-h.Stop:
   141  			}
   142  		}
   143  	}()
   144  	// TODO(jam): 2017-01-24 One limitation of our algorithm, is that we don't
   145  	// try to limit the negotiation of the keys to our set of possible keys.
   146  	// For example, say we only know about the RSA key for the remote host, but
   147  	// it has been updated to use a ECDSA key as well. Gocrypto/ssh might
   148  	// negotiate to use the "more secure" ECDSA key and we will see that
   149  	// as an invalid key.
   150  	sshconfig := &ssh.ClientConfig{
   151  		HostKeyCallback: h.hostKeyCallback,
   152  	}
   153  	addr := h.HostPort.NetAddr()
   154  	logger.Debugf("dialing %s to check host keys", addr)
   155  	conn, err := h.Dialer.Dial("tcp", addr)
   156  	if err != nil {
   157  		logger.Debugf("dial %s failed with: %v", addr, err)
   158  		return
   159  	}
   160  	// No need to do the key exchange if we're already stopping
   161  	select {
   162  	case <-h.Stop:
   163  		conn.Close()
   164  		return
   165  	default:
   166  	}
   167  	logger.Debugf("connected to %s, initiating ssh handshake", addr)
   168  	// NewClientConn will close the underlying net.Conn if it gets an error
   169  	client, _, _, err := ssh.NewClientConn(conn, addr, sshconfig)
   170  	if err == nil {
   171  		// We don't expect this case, because we don't support Auth,
   172  		// but make sure to close it anyway.
   173  		client.Close()
   174  	} else {
   175  		// no need to log these two messages, that's already been done
   176  		// in hostKeyCallback
   177  		if !strings.Contains(err.Error(), hostKeyAccepted.Error()) &&
   178  			!strings.Contains(err.Error(), hostKeyNotInList.Error()) {
   179  			logger.Debugf("%v", err)
   180  		}
   181  	}
   182  }
   183  
// reachableChecker is the ReachableChecker implementation returned by
// NewReachableChecker.
type reachableChecker struct {
	// dialer is handed to every hostKeyChecker started by FindHost and is
	// used to open the underlying connections.
	dialer Dialer
	// timeout is the overall time limit that FindHost arms when it starts
	// waiting for results.
	timeout time.Duration
}
   188  
   189  // FindHost takes a list of possible host+port combinations and possible public
   190  // keys that the SSH server could be using. We make an attempt to connect to
   191  // each of those addresses and do an SSH handshake negotiation. We then check
   192  // if the SSH server's negotiated public key is in our allowed set. The first
   193  // address to successfully negotiate will be returned. If none of them succeed,
   194  // and error will be returned.
   195  func (r *reachableChecker) FindHost(hostPorts []network.HostPort, publicKeys []string) (network.HostPort, error) {
   196  	uniqueHPs := network.UniqueHostPorts(hostPorts)
   197  	successful := make(chan network.HostPort)
   198  	stop := make(chan struct{})
   199  	// We use a channel instead of a sync.WaitGroup so that we can return as
   200  	// soon as we get one connected. We'll signal the rest to stop via the
   201  	// 'stop' channel.
   202  	finished := make(chan struct{}, len(uniqueHPs))
   203  
   204  	acceptedKeys := publicKeysToSet(publicKeys)
   205  	for _, hostPort := range uniqueHPs {
   206  		checker := &hostKeyChecker{
   207  			AcceptedKeys: acceptedKeys,
   208  			Stop:         stop,
   209  			Accepted:     successful,
   210  			HostPort:     hostPort,
   211  			Dialer:       r.dialer,
   212  			Finished:     finished,
   213  		}
   214  		go checker.Check()
   215  	}
   216  
   217  	timeout := time.After(r.timeout)
   218  	for finishedCount := 0; finishedCount < len(uniqueHPs); {
   219  		select {
   220  		case result := <-successful:
   221  			logger.Infof("found %v has an acceptable ssh key", result)
   222  			close(stop)
   223  			return result, nil
   224  		case <-finished:
   225  			finishedCount++
   226  		case <-timeout:
   227  			break
   228  		}
   229  	}
   230  	close(stop)
   231  	return network.HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts)
   232  }