github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/network/ssh/reachable.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package ssh 5 6 import ( 7 "fmt" 8 "net" 9 "strings" 10 "time" 11 12 "github.com/juju/collections/set" 13 "github.com/juju/errors" 14 "github.com/juju/loggo" 15 "golang.org/x/crypto/ssh" 16 17 "github.com/juju/juju/core/network" 18 ) 19 20 var logger = loggo.GetLogger("juju.network.ssh") 21 22 // Dialer defines a Dial() method matching the signature of net.Dial(). 23 type Dialer interface { 24 Dial(network, address string) (net.Conn, error) 25 } 26 27 // ReachableChecker tries to find ssh hosts that have a public key that matches 28 // our expectations. 29 type ReachableChecker interface { 30 // FindHost tries to connect to all of the host+port combinations supplied, 31 // and tries to do an SSH key negotiation. The first successful negotiation 32 // that includes one of the public keys supplied will be returned. If none 33 // of them can be validated, then an error will be returned. 34 FindHost(hostPorts network.HostPorts, publicKeys []string) (network.HostPort, error) 35 } 36 37 // NewReachableChecker creates a ReachableChecker that can be used to check for 38 // Hosts that are viable SSH targets. 39 // When FindHost is called, we will dial the entries in the given hostPorts, in 40 // parallel, using the given dialer, closing successfully established 41 // connections after checking the ssh key. Individual connection errors are 42 // discarded, and an error is returned only if none of the hostPorts can be 43 // reached when the given timeout expires. 44 // If publicKeys is a non empty list, then the SSH host public key will be 45 // checked. If it is not in the list, that host is not considered valid. 46 // 47 // Usually, a net.Dialer initialized with a non-empty Timeout field is passed 48 // for dialer. 49 func NewReachableChecker(dialer Dialer, timeout time.Duration) *reachableChecker { 50 return &reachableChecker{ 51 dialer: dialer, 52 timeout: timeout, 53 } 54 } 55 56 // hostKeyChecker checks if this host matches one of allowed public keys 57 // it uses the golang/x/crypto/ssh/HostKeyCallback to find the host keys on a 58 // given connection. 59 type hostKeyChecker struct { 60 61 // AcceptedKeys is a set of the Marshalled PublicKey content. 62 AcceptedKeys set.Strings 63 64 // Stop will be polled for whether we should stop trying to do any work. 65 Stop <-chan struct{} 66 67 // HostPort is the identifier that corresponds to this connection. 68 HostPort network.HostPort 69 70 // Accepted will be populated with a HostPort if the checker successfully 71 // validated a collection. 72 Accepted chan network.HostPort 73 74 // Dialer is a Dialer that allows us to initiate the underlying TCP connection. 75 Dialer Dialer 76 77 // Finished will be set an event when we've finished our check (success or failure). 78 Finished chan struct{} 79 } 80 81 var hostKeyNotInList = errors.New("host key not in expected set") 82 var hostKeyAccepted = errors.New("host key was accepted, retry") 83 var hostKeyAcceptedButStopped = errors.New("host key was accepted, but search was stopped") 84 85 func (h *hostKeyChecker) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { 86 // Note: we don't do any advanced checking of the PublicKey, like whether 87 // the key is revoked or expired. All we care about is whether it matches 88 // the public keys that we consider acceptable 89 authKeyForm := ssh.MarshalAuthorizedKey(key) 90 debugName := hostname 91 if hostname != remote.String() { 92 debugName = fmt.Sprintf("%s at %s", hostname, remote.String()) 93 } 94 logger.Tracef("checking host key for %s, with key %q", debugName, authKeyForm) 95 96 lookupKey := string(key.Marshal()) 97 if len(h.AcceptedKeys) == 0 || h.AcceptedKeys.Contains(lookupKey) { 98 logger.Debugf("accepted host key for: %s", debugName) 99 // This key was valid, so return it, but if someone else was found 100 // first, still exit. 101 select { 102 case h.Accepted <- h.HostPort: 103 // We have accepted a host, we won't need to call Finished. 104 h.Finished = nil 105 return hostKeyAccepted 106 case <-h.Stop: 107 return hostKeyAcceptedButStopped 108 } 109 } 110 logger.Debugf("host key for %s not in our accepted set: log at TRACE to see raw keys", debugName) 111 return hostKeyNotInList 112 } 113 114 // publicKeysToSet converts all the public key values (eg id_rsa.pub) into 115 // their short hash form. Problems with a key are logged at Warning level, but 116 // otherwise ignored. 117 func publicKeysToSet(publicKeys []string) set.Strings { 118 acceptedKeys := set.NewStrings() 119 for _, pubKey := range publicKeys { 120 // key, comment, options, rest, err 121 sshKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) 122 if err != nil { 123 logger.Warningf("unable to handle public key: %q\n", pubKey) 124 continue 125 } 126 acceptedKeys.Add(string(sshKey.Marshal())) 127 } 128 return acceptedKeys 129 } 130 131 // Check initiates a connection to address described by the checker's HostPort 132 // member and tries to do an SSH key exchange to determine the preferred public 133 // key of the remote host. 134 // It then checks if that key is in the accepted set of keys. 135 func (h *hostKeyChecker) Check() { 136 defer func() { 137 // send a finished message unless we're already stopped and nobody 138 // is listening 139 if h.Finished != nil { 140 select { 141 case h.Finished <- struct{}{}: 142 case <-h.Stop: 143 } 144 } 145 }() 146 // TODO(jam): 2017-01-24 One limitation of our algorithm, is that we don't 147 // try to limit the negotiation of the keys to our set of possible keys. 148 // For example, say we only know about the RSA key for the remote host, but 149 // it has been updated to use a ECDSA key as well. Gocrypto/ssh might 150 // negotiate to use the "more secure" ECDSA key and we will see that 151 // as an invalid key. 152 sshConfig := &ssh.ClientConfig{ 153 HostKeyCallback: h.hostKeyCallback, 154 } 155 addr := network.DialAddress(h.HostPort) 156 logger.Debugf("dialing %s to check host keys", addr) 157 conn, err := h.Dialer.Dial("tcp", addr) 158 if err != nil { 159 logger.Debugf("dial %s failed with: %v", addr, err) 160 return 161 } 162 // No need to do the key exchange if we're already stopping 163 select { 164 case <-h.Stop: 165 _ = conn.Close() 166 return 167 default: 168 } 169 logger.Debugf("connected to %s, initiating ssh handshake", addr) 170 // NewClientConn will close the underlying net.Conn if it gets an error 171 client, _, _, err := ssh.NewClientConn(conn, addr, sshConfig) 172 if err == nil { 173 // We don't expect this case, because we don't support Auth, 174 // but make sure to close it anyway. 175 _ = client.Close() 176 } else { 177 // no need to log these two messages, that's already been done 178 // in hostKeyCallback 179 if !strings.Contains(err.Error(), hostKeyAccepted.Error()) && 180 !strings.Contains(err.Error(), hostKeyNotInList.Error()) { 181 logger.Debugf("%v", err) 182 } 183 } 184 } 185 186 type reachableChecker struct { 187 dialer Dialer 188 timeout time.Duration 189 } 190 191 // FindHost takes a list of possible host+port combinations and possible public 192 // keys that the SSH server could be using. We make an attempt to connect to 193 // each of those addresses and do an SSH handshake negotiation. We then check 194 // if the SSH server's negotiated public key is in our allowed set. The first 195 // address to successfully negotiate will be returned. If none of them succeed, 196 // and error will be returned. 197 func (r *reachableChecker) FindHost(hostPorts network.HostPorts, publicKeys []string) (network.HostPort, error) { 198 uniqueHPs := hostPorts.Unique() 199 successful := make(chan network.HostPort) 200 stop := make(chan struct{}) 201 // We use a channel instead of a sync.WaitGroup so that we can return as 202 // soon as we get one connected. We'll signal the rest to stop via the 203 // 'stop' channel. 204 finished := make(chan struct{}, len(uniqueHPs)) 205 206 acceptedKeys := publicKeysToSet(publicKeys) 207 for _, hostPort := range uniqueHPs { 208 checker := &hostKeyChecker{ 209 AcceptedKeys: acceptedKeys, 210 Stop: stop, 211 Accepted: successful, 212 HostPort: hostPort, 213 Dialer: r.dialer, 214 Finished: finished, 215 } 216 go checker.Check() 217 } 218 219 timeout := time.After(r.timeout) 220 for finishedCount := 0; finishedCount < len(uniqueHPs); { 221 select { 222 case result := <-successful: 223 logger.Infof("found %v has an acceptable ssh key", result) 224 close(stop) 225 return result, nil 226 case <-finished: 227 finishedCount++ 228 case <-timeout: 229 break 230 } 231 } 232 close(stop) 233 return nil, errors.Errorf("cannot connect to any address: %v", hostPorts) 234 }