github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/ssh.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // Package sshutils defines several functions and types used across the 18 // Teleport API and other Teleport packages when working with SSH. 19 package sshutils 20 21 import ( 22 "bytes" 23 "context" 24 "crypto" 25 "crypto/subtle" 26 "errors" 27 "io" 28 "net" 29 "regexp" 30 "strings" 31 32 "github.com/gravitational/trace" 33 "golang.org/x/crypto/ssh" 34 35 "github.com/gravitational/teleport/api/defaults" 36 ) 37 38 // HandshakePayload structure is sent as a JSON blob by the teleport 39 // proxy to every SSH server who identifies itself as Teleport server 40 // 41 // It allows teleport proxies to communicate additional data to server 42 type HandshakePayload struct { 43 // ClientAddr is the IP address of the remote client 44 ClientAddr string `json:"clientAddr,omitempty"` 45 // TracingContext contains tracing information so that spans can be correlated 46 // across ssh boundaries 47 TracingContext map[string]string `json:"tracingContext,omitempty"` 48 } 49 50 // ParseCertificate parses an SSH certificate from the authorized_keys format. 51 func ParseCertificate(buf []byte) (*ssh.Certificate, error) { 52 k, _, _, _, err := ssh.ParseAuthorizedKey(buf) 53 if err != nil { 54 return nil, trace.Wrap(err) 55 } 56 57 cert, ok := k.(*ssh.Certificate) 58 if !ok { 59 return nil, trace.BadParameter("not an SSH certificate") 60 } 61 62 return cert, nil 63 } 64 65 // ParseKnownHosts parses provided known_hosts entries into ssh.PublicKey list. 66 // If one or more hostnames are provided, only keys that have at least one match 67 // will be returned. 68 func ParseKnownHosts(knownHosts [][]byte, matchHostnames ...string) ([]ssh.PublicKey, error) { 69 var keys []ssh.PublicKey 70 for _, line := range knownHosts { 71 for { 72 _, hosts, publicKey, _, bytes, err := ssh.ParseKnownHosts(line) 73 if errors.Is(err, io.EOF) { 74 break 75 } else if err != nil { 76 return nil, trace.Wrap(err, "failed parsing known hosts: %v; raw line: %q", err, line) 77 } 78 79 if len(matchHostnames) == 0 || HostNameMatch(matchHostnames, hosts) { 80 keys = append(keys, publicKey) 81 } 82 83 line = bytes 84 } 85 } 86 return keys, nil 87 } 88 89 // HostNameMatch returns whether at least one of the given hosts matches one 90 // of the given matchHosts. If a host has a wildcard prefix "*.", it will be 91 // used to match. Ex: "*.example.com" will match "proxy.example.com". 92 func HostNameMatch(matchHosts []string, hosts []string) bool { 93 for _, matchHost := range matchHosts { 94 for _, host := range hosts { 95 if host == matchHost || matchesWildcard(matchHost, host) { 96 return true 97 } 98 } 99 } 100 return false 101 } 102 103 // matchesWildcard ensures the given `hostname` matches the given `pattern`. 104 // The `pattern` should be prefixed with `*.` which will match exactly one domain 105 // segment, meaning `*.example.com` will match `foo.example.com` but not 106 // `foo.bar.example.com`. 107 func matchesWildcard(hostname, pattern string) bool { 108 pattern = strings.TrimSpace(pattern) 109 110 // Don't allow non-wildcard or empty patterns. 111 if !strings.HasPrefix(pattern, "*.") || len(pattern) < 3 { 112 return false 113 } 114 matchHost := pattern[2:] 115 116 // Trim any trailing "." in case of an absolute domain. 117 hostname = strings.TrimSuffix(hostname, ".") 118 119 _, hostnameRoot, found := strings.Cut(hostname, ".") 120 if !found { 121 return false 122 } 123 124 return hostnameRoot == matchHost 125 } 126 127 // ParseAuthorizedKeys parses provided authorized_keys entries into ssh.PublicKey list. 128 func ParseAuthorizedKeys(authorizedKeys [][]byte) ([]ssh.PublicKey, error) { 129 var keys []ssh.PublicKey 130 for _, line := range authorizedKeys { 131 publicKey, _, _, _, err := ssh.ParseAuthorizedKey(line) 132 if err != nil { 133 return nil, trace.Wrap(err, "failed parsing authorized keys: %v; raw line: %q", err, line) 134 } 135 keys = append(keys, publicKey) 136 } 137 return keys, nil 138 } 139 140 // ProxyClientSSHConfig returns an ssh.ClientConfig from the given ssh.AuthMethod. 141 // If known_hosts are provided, they will be used in the config's HostKeyCallback. 142 // 143 // The config is set up to authenticate to proxy with the first available principal. 144 func ProxyClientSSHConfig(sshCert *ssh.Certificate, priv crypto.Signer, knownHosts ...[]byte) (*ssh.ClientConfig, error) { 145 authMethod, err := AsAuthMethod(sshCert, priv) 146 if err != nil { 147 return nil, trace.Wrap(err) 148 } 149 150 cfg := &ssh.ClientConfig{ 151 Auth: []ssh.AuthMethod{authMethod}, 152 Timeout: defaults.DefaultIOTimeout, 153 } 154 155 // The KeyId is not always a valid principal, so we use the first valid principal instead. 156 cfg.User = sshCert.KeyId 157 if len(sshCert.ValidPrincipals) > 0 { 158 cfg.User = sshCert.ValidPrincipals[0] 159 } 160 161 if len(knownHosts) > 0 { 162 trustedKeys, err := ParseKnownHosts(knownHosts) 163 if err != nil { 164 return nil, trace.Wrap(err) 165 } 166 167 cfg.HostKeyCallback, err = HostKeyCallback(trustedKeys, false) 168 if err != nil { 169 return nil, trace.Wrap(err, "failed to convert certificate authorities to HostKeyCallback") 170 } 171 } 172 173 return cfg, nil 174 } 175 176 // SSHSigner returns an ssh.Signer from certificate and private key 177 func SSHSigner(sshCert *ssh.Certificate, signer crypto.Signer) (ssh.Signer, error) { 178 sshSigner, err := ssh.NewSignerFromKey(signer) 179 if err != nil { 180 return nil, trace.Wrap(err) 181 } 182 sshSigner, err = ssh.NewCertSigner(sshCert, sshSigner) 183 if err != nil { 184 return nil, trace.Wrap(err) 185 } 186 return sshSigner, nil 187 } 188 189 // AsAuthMethod returns an "auth method" interface, a common abstraction 190 // used by Golang SSH library. This is how you actually use a Key to feed 191 // it into the SSH lib. 192 func AsAuthMethod(sshCert *ssh.Certificate, signer crypto.Signer) (ssh.AuthMethod, error) { 193 sshSigner, err := SSHSigner(sshCert, signer) 194 if err != nil { 195 return nil, trace.Wrap(err) 196 } 197 return ssh.PublicKeys(sshSigner), nil 198 } 199 200 // HostKeyCallback returns an ssh.HostKeyCallback that validates host 201 // keys/certs against trusted host keys, usually associated with trusted CAs. 202 // 203 // If no trusted keys are provided, the returned ssh.HostKeyCallback is nil. 204 // This causes golang.org/x/crypto/ssh to prompt the user to verify host key 205 // fingerprint (same as OpenSSH does for an unknown host). 206 func HostKeyCallback(trustedKeys []ssh.PublicKey, withHostKeyFallback bool) (ssh.HostKeyCallback, error) { 207 // No trusted keys are provided, return a nil callback which will prompt the user for trust. 208 if len(trustedKeys) == 0 { 209 return nil, nil 210 } 211 212 callbackConfig := HostKeyCallbackConfig{ 213 GetHostCheckers: func() ([]ssh.PublicKey, error) { 214 return trustedKeys, nil 215 }, 216 } 217 218 if withHostKeyFallback { 219 callbackConfig.HostKeyFallback = hostKeyFallbackFunc(trustedKeys) 220 } 221 222 callback, err := NewHostKeyCallback(callbackConfig) 223 if err != nil { 224 return nil, trace.Wrap(err) 225 } 226 227 return callback, nil 228 } 229 230 func hostKeyFallbackFunc(knownHosts []ssh.PublicKey) func(hostname string, remote net.Addr, key ssh.PublicKey) error { 231 return func(hostname string, remote net.Addr, key ssh.PublicKey) error { 232 for _, knownHost := range knownHosts { 233 if KeysEqual(key, knownHost) { 234 return nil 235 } 236 } 237 return trace.AccessDenied("host %v presented a public key instead of a host certificate which isn't among known hosts", hostname) 238 } 239 } 240 241 // KeysEqual is constant time compare of the keys to avoid timing attacks 242 func KeysEqual(ak, bk ssh.PublicKey) bool { 243 a := ak.Marshal() 244 b := bk.Marshal() 245 return subtle.ConstantTimeCompare(a, b) == 1 246 } 247 248 // OpenSSH cert types look like "<key-type>-cert-v<version>@openssh.com". 249 var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`) 250 251 // IsSSHCertType checks if the given string looks like an ssh cert type. 252 // e.g. ssh-rsa-cert-v01@openssh.com. 253 func IsSSHCertType(val string) bool { 254 return sshCertTypeRegex.MatchString(val) 255 } 256 257 type contextDialer func(ctx context.Context, network, addr string) (net.Conn, error) 258 259 type runSSHOpts struct { 260 dialContext contextDialer 261 } 262 263 // RunSSHOption allows setting options as functional arguments to RunSSH. 264 type RunSSHOption func(*runSSHOpts) 265 266 // WithDialer connects to an SSH server with a custom dialer. 267 func WithDialer(dialer contextDialer) RunSSHOption { 268 return func(opts *runSSHOpts) { 269 opts.dialContext = dialer 270 } 271 } 272 273 // RunSSH runs a command on an SSH server and returns the output. 274 func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, []byte, error) { 275 var options runSSHOpts 276 for _, opt := range opts { 277 opt(&options) 278 } 279 280 conn, err := options.dialContext(ctx, "tcp", addr) 281 if err != nil { 282 return nil, nil, trace.Wrap(err) 283 } 284 285 clientConn, newCh, requestsCh, err := ssh.NewClientConn(conn, addr, cfg) 286 if err != nil { 287 return nil, nil, trace.Wrap(err) 288 } 289 sshClient := ssh.NewClient(clientConn, newCh, requestsCh) 290 defer sshClient.Close() 291 session, err := sshClient.NewSession() 292 if err != nil { 293 return nil, nil, trace.Wrap(err) 294 } 295 defer session.Close() 296 297 // Execute the command. 298 var stdout bytes.Buffer 299 session.Stdout = &stdout 300 var stderr bytes.Buffer 301 session.Stderr = &stderr 302 err = session.Run(command) 303 return stdout.Bytes(), stderr.Bytes(), trace.Wrap(err) 304 } 305 306 // ChannelReadWriter represents the data streams of an ssh.Channel-like object. 307 type ChannelReadWriter interface { 308 io.ReadWriter 309 Stderr() io.ReadWriter 310 } 311 312 // DiscardChannelData discards all data received from an ssh channel in the 313 // background. 314 func DiscardChannelData(ch ChannelReadWriter) { 315 if ch == nil { 316 return 317 } 318 go io.Copy(io.Discard, ch) 319 go io.Copy(io.Discard, ch.Stderr()) 320 }