github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/network/ssh/reachable.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package ssh 5 6 import ( 7 "fmt" 8 "net" 9 "strings" 10 "time" 11 12 "github.com/juju/collections/set" 13 "github.com/juju/errors" 14 "github.com/juju/loggo" 15 "golang.org/x/crypto/ssh" 16 17 "github.com/juju/juju/network" 18 ) 19 20 var logger = loggo.GetLogger("juju.network.ssh") 21 22 // Dialer defines a Dial() method matching the signature of net.Dial(). 23 type Dialer interface { 24 Dial(network, address string) (net.Conn, error) 25 } 26 27 // ReachableChecker tries to find ssh hosts that have a public key that matches 28 // our expectations. 29 type ReachableChecker interface { 30 // FindHost tries to connect to all of the host+port combinations supplied, 31 // and tries to do an SSH key negotiation. The first successful negotiation 32 // that includes one of the public keys supplied will be returned. If none 33 // of them can be validated, then an error will be returned. 34 FindHost(hostPorts []network.HostPort, publicKeys []string) (network.HostPort, error) 35 } 36 37 // NewReachableChecker creates a ReachableChecker that can be used to check for 38 // Hosts that are viable SSH targets. 39 // When FindHost is called, we will dial the entries in the given hostPorts, in 40 // parallel, using the given dialer, closing successfully established 41 // connections after checking the ssh key. Individual connection errors are 42 // discarded, and an error is returned only if none of the hostPorts can be 43 // reached when the given timeout expires. 44 // If publicKeys is a non empty list, then the SSH host public key will be 45 // checked. If it is not in the list, that host is not considered valid. 46 // 47 // Usually, a net.Dialer initialized with a non-empty Timeout field is passed 48 // for dialer. 49 func NewReachableChecker(dialer Dialer, timeout time.Duration) *reachableChecker { 50 return &reachableChecker{ 51 dialer: dialer, 52 timeout: timeout, 53 } 54 } 55 56 // hostKeyChecker checks if this host matches one of allowed public keys 57 // it uses the golang/x/crypto/ssh/HostKeyCallback to find the host keys on a 58 // given connection. 59 type hostKeyChecker struct { 60 61 // AcceptedKeys is a set of the Marshalled PublicKey content. 62 AcceptedKeys set.Strings 63 64 // Stop will be polled for whether we should stop trying to do any work 65 Stop <-chan struct{} 66 67 // HostPort is the identifier that corresponds to this connection 68 HostPort network.HostPort 69 70 // Accepted will be passed HostPort if it validated the connection 71 Accepted chan network.HostPort 72 73 // Dialer is a Dialer that allows us to initiate the underlying TCP connection 74 Dialer Dialer 75 76 // Finished will be set an event when we've finished our check (success or failure) 77 Finished chan struct{} 78 } 79 80 var hostKeyNotInList = errors.New("host key not in expected set") 81 var hostKeyAccepted = errors.New("host key was accepted, retry") 82 var hostKeyAcceptedButStopped = errors.New("host key was accepted, but search was stopped") 83 84 func (h *hostKeyChecker) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { 85 // Note: we don't do any advanced checking of the PublicKey, like whether 86 // the key is revoked or expired. All we care about is whether it matches 87 // the public keys that we consider acceptable 88 authKeyForm := ssh.MarshalAuthorizedKey(key) 89 debugName := hostname 90 if hostname != remote.String() { 91 debugName = fmt.Sprintf("%s at %s", hostname, remote.String()) 92 } 93 logger.Tracef("checking host key for %s, with key %q", debugName, authKeyForm) 94 95 lookupKey := string(key.Marshal()) 96 if len(h.AcceptedKeys) == 0 || h.AcceptedKeys.Contains(lookupKey) { 97 logger.Debugf("accepted host key for: %s", debugName) 98 // This key was valid, so return it, but if someone else was found 99 // first, still exit. 100 select { 101 case h.Accepted <- h.HostPort: 102 // We have accepted a host, we won't need to call Finished. 103 h.Finished = nil 104 return hostKeyAccepted 105 case <-h.Stop: 106 return hostKeyAcceptedButStopped 107 } 108 } 109 logger.Debugf("host key for %s not in our accepted set: log at TRACE to see raw keys", debugName) 110 return hostKeyNotInList 111 } 112 113 // publicKeysToSet converts all the public key values (eg id_rsa.pub) into 114 // their short hash form. Problems with a key are logged at Warning level, but 115 // otherwise ignored. 116 func publicKeysToSet(publicKeys []string) set.Strings { 117 acceptedKeys := set.NewStrings() 118 for _, pubKey := range publicKeys { 119 // key, comment, options, rest, err 120 sshKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) 121 if err != nil { 122 logger.Warningf("unable to handle public key: %q\n", pubKey) 123 continue 124 } 125 acceptedKeys.Add(string(sshKey.Marshal())) 126 } 127 return acceptedKeys 128 } 129 130 // Check initiates a connection to HostPort and tries to do an SSH key 131 // exchange to determine the preferred public key of the remote host. 132 // It then checks if that key is in the accepted set of keys. 133 func (h *hostKeyChecker) Check() { 134 defer func() { 135 // send a finished message unless we're already stopped and nobody 136 // is listening 137 if h.Finished != nil { 138 select { 139 case h.Finished <- struct{}{}: 140 case <-h.Stop: 141 } 142 } 143 }() 144 // TODO(jam): 2017-01-24 One limitation of our algorithm, is that we don't 145 // try to limit the negotiation of the keys to our set of possible keys. 146 // For example, say we only know about the RSA key for the remote host, but 147 // it has been updated to use a ECDSA key as well. Gocrypto/ssh might 148 // negotiate to use the "more secure" ECDSA key and we will see that 149 // as an invalid key. 150 sshconfig := &ssh.ClientConfig{ 151 HostKeyCallback: h.hostKeyCallback, 152 } 153 addr := h.HostPort.NetAddr() 154 logger.Debugf("dialing %s to check host keys", addr) 155 conn, err := h.Dialer.Dial("tcp", addr) 156 if err != nil { 157 logger.Debugf("dial %s failed with: %v", addr, err) 158 return 159 } 160 // No need to do the key exchange if we're already stopping 161 select { 162 case <-h.Stop: 163 conn.Close() 164 return 165 default: 166 } 167 logger.Debugf("connected to %s, initiating ssh handshake", addr) 168 // NewClientConn will close the underlying net.Conn if it gets an error 169 client, _, _, err := ssh.NewClientConn(conn, addr, sshconfig) 170 if err == nil { 171 // We don't expect this case, because we don't support Auth, 172 // but make sure to close it anyway. 173 client.Close() 174 } else { 175 // no need to log these two messages, that's already been done 176 // in hostKeyCallback 177 if !strings.Contains(err.Error(), hostKeyAccepted.Error()) && 178 !strings.Contains(err.Error(), hostKeyNotInList.Error()) { 179 logger.Debugf("%v", err) 180 } 181 } 182 } 183 184 type reachableChecker struct { 185 dialer Dialer 186 timeout time.Duration 187 } 188 189 // FindHost takes a list of possible host+port combinations and possible public 190 // keys that the SSH server could be using. We make an attempt to connect to 191 // each of those addresses and do an SSH handshake negotiation. We then check 192 // if the SSH server's negotiated public key is in our allowed set. The first 193 // address to successfully negotiate will be returned. If none of them succeed, 194 // and error will be returned. 195 func (r *reachableChecker) FindHost(hostPorts []network.HostPort, publicKeys []string) (network.HostPort, error) { 196 uniqueHPs := network.UniqueHostPorts(hostPorts) 197 successful := make(chan network.HostPort) 198 stop := make(chan struct{}) 199 // We use a channel instead of a sync.WaitGroup so that we can return as 200 // soon as we get one connected. We'll signal the rest to stop via the 201 // 'stop' channel. 202 finished := make(chan struct{}, len(uniqueHPs)) 203 204 acceptedKeys := publicKeysToSet(publicKeys) 205 for _, hostPort := range uniqueHPs { 206 checker := &hostKeyChecker{ 207 AcceptedKeys: acceptedKeys, 208 Stop: stop, 209 Accepted: successful, 210 HostPort: hostPort, 211 Dialer: r.dialer, 212 Finished: finished, 213 } 214 go checker.Check() 215 } 216 217 timeout := time.After(r.timeout) 218 for finishedCount := 0; finishedCount < len(uniqueHPs); { 219 select { 220 case result := <-successful: 221 logger.Infof("found %v has an acceptable ssh key", result) 222 close(stop) 223 return result, nil 224 case <-finished: 225 finishedCount++ 226 case <-timeout: 227 break 228 } 229 } 230 close(stop) 231 return network.HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts) 232 }