github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/network/ssh/reachable.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"fmt"
     8  	"net"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/juju/collections/set"
    13  	"github.com/juju/errors"
    14  	"github.com/juju/loggo"
    15  	"golang.org/x/crypto/ssh"
    16  
    17  	"github.com/juju/juju/core/network"
    18  )
    19  
// logger is this package's logger, registered with loggo under the name
// "juju.network.ssh".
var logger = loggo.GetLogger("juju.network.ssh")
    21  
    22  // Dialer defines a Dial() method matching the signature of net.Dial().
    23  type Dialer interface {
    24  	Dial(network, address string) (net.Conn, error)
    25  }
    26  
    27  // ReachableChecker tries to find ssh hosts that have a public key that matches
    28  // our expectations.
    29  type ReachableChecker interface {
    30  	// FindHost tries to connect to all of the host+port combinations supplied,
    31  	// and tries to do an SSH key negotiation. The first successful negotiation
    32  	// that includes one of the public keys supplied will be returned. If none
    33  	// of them can be validated, then an error will be returned.
    34  	FindHost(hostPorts network.HostPorts, publicKeys []string) (network.HostPort, error)
    35  }
    36  
    37  // NewReachableChecker creates a ReachableChecker that can be used to check for
    38  // Hosts that are viable SSH targets.
    39  // When FindHost is called, we will dial the entries in the given hostPorts, in
    40  // parallel, using the given dialer, closing successfully established
    41  // connections after checking the ssh key. Individual connection errors are
    42  // discarded, and an error is returned only if none of the hostPorts can be
    43  // reached when the given timeout expires.
    44  // If publicKeys is a non empty list, then the SSH host public key will be
    45  // checked. If it is not in the list, that host is not considered valid.
    46  //
    47  // Usually, a net.Dialer initialized with a non-empty Timeout field is passed
    48  // for dialer.
    49  func NewReachableChecker(dialer Dialer, timeout time.Duration) *reachableChecker {
    50  	return &reachableChecker{
    51  		dialer:  dialer,
    52  		timeout: timeout,
    53  	}
    54  }
    55  
    56  // hostKeyChecker checks if this host matches one of allowed public keys
    57  // it uses the golang/x/crypto/ssh/HostKeyCallback to find the host keys on a
    58  // given connection.
    59  type hostKeyChecker struct {
    60  
    61  	// AcceptedKeys is a set of the Marshalled PublicKey content.
    62  	AcceptedKeys set.Strings
    63  
    64  	// Stop will be polled for whether we should stop trying to do any work.
    65  	Stop <-chan struct{}
    66  
    67  	// HostPort is the identifier that corresponds to this connection.
    68  	HostPort network.HostPort
    69  
    70  	// Accepted will be populated with a HostPort if the checker successfully
    71  	// validated a collection.
    72  	Accepted chan network.HostPort
    73  
    74  	// Dialer is a Dialer that allows us to initiate the underlying TCP connection.
    75  	Dialer Dialer
    76  
    77  	// Finished will be set an event when we've finished our check (success or failure).
    78  	Finished chan struct{}
    79  }
    80  
    81  var hostKeyNotInList = errors.New("host key not in expected set")
    82  var hostKeyAccepted = errors.New("host key was accepted, retry")
    83  var hostKeyAcceptedButStopped = errors.New("host key was accepted, but search was stopped")
    84  
    85  func (h *hostKeyChecker) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
    86  	// Note: we don't do any advanced checking of the PublicKey, like whether
    87  	// the key is revoked or expired. All we care about is whether it matches
    88  	// the public keys that we consider acceptable
    89  	authKeyForm := ssh.MarshalAuthorizedKey(key)
    90  	debugName := hostname
    91  	if hostname != remote.String() {
    92  		debugName = fmt.Sprintf("%s at %s", hostname, remote.String())
    93  	}
    94  	logger.Tracef("checking host key for %s, with key %q", debugName, authKeyForm)
    95  
    96  	lookupKey := string(key.Marshal())
    97  	if len(h.AcceptedKeys) == 0 || h.AcceptedKeys.Contains(lookupKey) {
    98  		logger.Debugf("accepted host key for: %s", debugName)
    99  		// This key was valid, so return it, but if someone else was found
   100  		// first, still exit.
   101  		select {
   102  		case h.Accepted <- h.HostPort:
   103  			// We have accepted a host, we won't need to call Finished.
   104  			h.Finished = nil
   105  			return hostKeyAccepted
   106  		case <-h.Stop:
   107  			return hostKeyAcceptedButStopped
   108  		}
   109  	}
   110  	logger.Debugf("host key for %s not in our accepted set: log at TRACE to see raw keys", debugName)
   111  	return hostKeyNotInList
   112  }
   113  
   114  // publicKeysToSet converts all the public key values (eg id_rsa.pub) into
   115  // their short hash form. Problems with a key are logged at Warning level, but
   116  // otherwise ignored.
   117  func publicKeysToSet(publicKeys []string) set.Strings {
   118  	acceptedKeys := set.NewStrings()
   119  	for _, pubKey := range publicKeys {
   120  		// key, comment, options, rest, err
   121  		sshKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
   122  		if err != nil {
   123  			logger.Warningf("unable to handle public key: %q\n", pubKey)
   124  			continue
   125  		}
   126  		acceptedKeys.Add(string(sshKey.Marshal()))
   127  	}
   128  	return acceptedKeys
   129  }
   130  
   131  // Check initiates a connection to address described by the checker's HostPort
   132  // member and tries to do an SSH key exchange to determine the preferred public
   133  // key of the remote host.
   134  // It then checks if that key is in the accepted set of keys.
   135  func (h *hostKeyChecker) Check() {
   136  	defer func() {
   137  		// send a finished message unless we're already stopped and nobody
   138  		// is listening
   139  		if h.Finished != nil {
   140  			select {
   141  			case h.Finished <- struct{}{}:
   142  			case <-h.Stop:
   143  			}
   144  		}
   145  	}()
   146  	// TODO(jam): 2017-01-24 One limitation of our algorithm, is that we don't
   147  	// try to limit the negotiation of the keys to our set of possible keys.
   148  	// For example, say we only know about the RSA key for the remote host, but
   149  	// it has been updated to use a ECDSA key as well. Gocrypto/ssh might
   150  	// negotiate to use the "more secure" ECDSA key and we will see that
   151  	// as an invalid key.
   152  	sshConfig := &ssh.ClientConfig{
   153  		HostKeyCallback: h.hostKeyCallback,
   154  	}
   155  	addr := network.DialAddress(h.HostPort)
   156  	logger.Debugf("dialing %s to check host keys", addr)
   157  	conn, err := h.Dialer.Dial("tcp", addr)
   158  	if err != nil {
   159  		logger.Debugf("dial %s failed with: %v", addr, err)
   160  		return
   161  	}
   162  	// No need to do the key exchange if we're already stopping
   163  	select {
   164  	case <-h.Stop:
   165  		_ = conn.Close()
   166  		return
   167  	default:
   168  	}
   169  	logger.Debugf("connected to %s, initiating ssh handshake", addr)
   170  	// NewClientConn will close the underlying net.Conn if it gets an error
   171  	client, _, _, err := ssh.NewClientConn(conn, addr, sshConfig)
   172  	if err == nil {
   173  		// We don't expect this case, because we don't support Auth,
   174  		// but make sure to close it anyway.
   175  		_ = client.Close()
   176  	} else {
   177  		// no need to log these two messages, that's already been done
   178  		// in hostKeyCallback
   179  		if !strings.Contains(err.Error(), hostKeyAccepted.Error()) &&
   180  			!strings.Contains(err.Error(), hostKeyNotInList.Error()) {
   181  			logger.Debugf("%v", err)
   182  		}
   183  	}
   184  }
   185  
   186  type reachableChecker struct {
   187  	dialer  Dialer
   188  	timeout time.Duration
   189  }
   190  
   191  // FindHost takes a list of possible host+port combinations and possible public
   192  // keys that the SSH server could be using. We make an attempt to connect to
   193  // each of those addresses and do an SSH handshake negotiation. We then check
   194  // if the SSH server's negotiated public key is in our allowed set. The first
   195  // address to successfully negotiate will be returned. If none of them succeed,
   196  // and error will be returned.
   197  func (r *reachableChecker) FindHost(hostPorts network.HostPorts, publicKeys []string) (network.HostPort, error) {
   198  	uniqueHPs := hostPorts.Unique()
   199  	successful := make(chan network.HostPort)
   200  	stop := make(chan struct{})
   201  	// We use a channel instead of a sync.WaitGroup so that we can return as
   202  	// soon as we get one connected. We'll signal the rest to stop via the
   203  	// 'stop' channel.
   204  	finished := make(chan struct{}, len(uniqueHPs))
   205  
   206  	acceptedKeys := publicKeysToSet(publicKeys)
   207  	for _, hostPort := range uniqueHPs {
   208  		checker := &hostKeyChecker{
   209  			AcceptedKeys: acceptedKeys,
   210  			Stop:         stop,
   211  			Accepted:     successful,
   212  			HostPort:     hostPort,
   213  			Dialer:       r.dialer,
   214  			Finished:     finished,
   215  		}
   216  		go checker.Check()
   217  	}
   218  
   219  	timeout := time.After(r.timeout)
   220  	for finishedCount := 0; finishedCount < len(uniqueHPs); {
   221  		select {
   222  		case result := <-successful:
   223  			logger.Infof("found %v has an acceptable ssh key", result)
   224  			close(stop)
   225  			return result, nil
   226  		case <-finished:
   227  			finishedCount++
   228  		case <-timeout:
   229  			break
   230  		}
   231  	}
   232  	close(stop)
   233  	return nil, errors.Errorf("cannot connect to any address: %v", hostPorts)
   234  }