github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/ssh.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package sshutils defines several functions and types used across the
    18  // Teleport API and other Teleport packages when working with SSH.
    19  package sshutils
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"crypto"
    25  	"crypto/subtle"
    26  	"errors"
    27  	"io"
    28  	"net"
    29  	"regexp"
    30  	"strings"
    31  
    32  	"github.com/gravitational/trace"
    33  	"golang.org/x/crypto/ssh"
    34  
    35  	"github.com/gravitational/teleport/api/defaults"
    36  )
    37  
    38  // HandshakePayload structure is sent as a JSON blob by the teleport
    39  // proxy to every SSH server who identifies itself as Teleport server
    40  //
    41  // It allows teleport proxies to communicate additional data to server
    42  type HandshakePayload struct {
    43  	// ClientAddr is the IP address of the remote client
    44  	ClientAddr string `json:"clientAddr,omitempty"`
    45  	// TracingContext contains tracing information so that spans can be correlated
    46  	// across ssh boundaries
    47  	TracingContext map[string]string `json:"tracingContext,omitempty"`
    48  }
    49  
    50  // ParseCertificate parses an SSH certificate from the authorized_keys format.
    51  func ParseCertificate(buf []byte) (*ssh.Certificate, error) {
    52  	k, _, _, _, err := ssh.ParseAuthorizedKey(buf)
    53  	if err != nil {
    54  		return nil, trace.Wrap(err)
    55  	}
    56  
    57  	cert, ok := k.(*ssh.Certificate)
    58  	if !ok {
    59  		return nil, trace.BadParameter("not an SSH certificate")
    60  	}
    61  
    62  	return cert, nil
    63  }
    64  
    65  // ParseKnownHosts parses provided known_hosts entries into ssh.PublicKey list.
    66  // If one or more hostnames are provided, only keys that have at least one match
    67  // will be returned.
    68  func ParseKnownHosts(knownHosts [][]byte, matchHostnames ...string) ([]ssh.PublicKey, error) {
    69  	var keys []ssh.PublicKey
    70  	for _, line := range knownHosts {
    71  		for {
    72  			_, hosts, publicKey, _, bytes, err := ssh.ParseKnownHosts(line)
    73  			if errors.Is(err, io.EOF) {
    74  				break
    75  			} else if err != nil {
    76  				return nil, trace.Wrap(err, "failed parsing known hosts: %v; raw line: %q", err, line)
    77  			}
    78  
    79  			if len(matchHostnames) == 0 || HostNameMatch(matchHostnames, hosts) {
    80  				keys = append(keys, publicKey)
    81  			}
    82  
    83  			line = bytes
    84  		}
    85  	}
    86  	return keys, nil
    87  }
    88  
    89  // HostNameMatch returns whether at least one of the given hosts matches one
    90  // of the given matchHosts. If a host has a wildcard prefix "*.", it will be
    91  // used to match. Ex: "*.example.com" will  match "proxy.example.com".
    92  func HostNameMatch(matchHosts []string, hosts []string) bool {
    93  	for _, matchHost := range matchHosts {
    94  		for _, host := range hosts {
    95  			if host == matchHost || matchesWildcard(matchHost, host) {
    96  				return true
    97  			}
    98  		}
    99  	}
   100  	return false
   101  }
   102  
   103  // matchesWildcard ensures the given `hostname` matches the given `pattern`.
   104  // The `pattern` should be prefixed with `*.` which will match exactly one domain
   105  // segment, meaning `*.example.com` will match `foo.example.com` but not
   106  // `foo.bar.example.com`.
   107  func matchesWildcard(hostname, pattern string) bool {
   108  	pattern = strings.TrimSpace(pattern)
   109  
   110  	// Don't allow non-wildcard or empty patterns.
   111  	if !strings.HasPrefix(pattern, "*.") || len(pattern) < 3 {
   112  		return false
   113  	}
   114  	matchHost := pattern[2:]
   115  
   116  	// Trim any trailing "." in case of an absolute domain.
   117  	hostname = strings.TrimSuffix(hostname, ".")
   118  
   119  	_, hostnameRoot, found := strings.Cut(hostname, ".")
   120  	if !found {
   121  		return false
   122  	}
   123  
   124  	return hostnameRoot == matchHost
   125  }
   126  
   127  // ParseAuthorizedKeys parses provided authorized_keys entries into ssh.PublicKey list.
   128  func ParseAuthorizedKeys(authorizedKeys [][]byte) ([]ssh.PublicKey, error) {
   129  	var keys []ssh.PublicKey
   130  	for _, line := range authorizedKeys {
   131  		publicKey, _, _, _, err := ssh.ParseAuthorizedKey(line)
   132  		if err != nil {
   133  			return nil, trace.Wrap(err, "failed parsing authorized keys: %v; raw line: %q", err, line)
   134  		}
   135  		keys = append(keys, publicKey)
   136  	}
   137  	return keys, nil
   138  }
   139  
   140  // ProxyClientSSHConfig returns an ssh.ClientConfig from the given ssh.AuthMethod.
   141  // If known_hosts are provided, they will be used in the config's HostKeyCallback.
   142  //
   143  // The config is set up to authenticate to proxy with the first available principal.
   144  func ProxyClientSSHConfig(sshCert *ssh.Certificate, priv crypto.Signer, knownHosts ...[]byte) (*ssh.ClientConfig, error) {
   145  	authMethod, err := AsAuthMethod(sshCert, priv)
   146  	if err != nil {
   147  		return nil, trace.Wrap(err)
   148  	}
   149  
   150  	cfg := &ssh.ClientConfig{
   151  		Auth:    []ssh.AuthMethod{authMethod},
   152  		Timeout: defaults.DefaultIOTimeout,
   153  	}
   154  
   155  	// The KeyId is not always a valid principal, so we use the first valid principal instead.
   156  	cfg.User = sshCert.KeyId
   157  	if len(sshCert.ValidPrincipals) > 0 {
   158  		cfg.User = sshCert.ValidPrincipals[0]
   159  	}
   160  
   161  	if len(knownHosts) > 0 {
   162  		trustedKeys, err := ParseKnownHosts(knownHosts)
   163  		if err != nil {
   164  			return nil, trace.Wrap(err)
   165  		}
   166  
   167  		cfg.HostKeyCallback, err = HostKeyCallback(trustedKeys, false)
   168  		if err != nil {
   169  			return nil, trace.Wrap(err, "failed to convert certificate authorities to HostKeyCallback")
   170  		}
   171  	}
   172  
   173  	return cfg, nil
   174  }
   175  
   176  // SSHSigner returns an ssh.Signer from certificate and private key
   177  func SSHSigner(sshCert *ssh.Certificate, signer crypto.Signer) (ssh.Signer, error) {
   178  	sshSigner, err := ssh.NewSignerFromKey(signer)
   179  	if err != nil {
   180  		return nil, trace.Wrap(err)
   181  	}
   182  	sshSigner, err = ssh.NewCertSigner(sshCert, sshSigner)
   183  	if err != nil {
   184  		return nil, trace.Wrap(err)
   185  	}
   186  	return sshSigner, nil
   187  }
   188  
   189  // AsAuthMethod returns an "auth method" interface, a common abstraction
   190  // used by Golang SSH library. This is how you actually use a Key to feed
   191  // it into the SSH lib.
   192  func AsAuthMethod(sshCert *ssh.Certificate, signer crypto.Signer) (ssh.AuthMethod, error) {
   193  	sshSigner, err := SSHSigner(sshCert, signer)
   194  	if err != nil {
   195  		return nil, trace.Wrap(err)
   196  	}
   197  	return ssh.PublicKeys(sshSigner), nil
   198  }
   199  
   200  // HostKeyCallback returns an ssh.HostKeyCallback that validates host
   201  // keys/certs against trusted host keys, usually associated with trusted CAs.
   202  //
   203  // If no trusted keys are provided, the returned ssh.HostKeyCallback is nil.
   204  // This causes golang.org/x/crypto/ssh to prompt the user to verify host key
   205  // fingerprint (same as OpenSSH does for an unknown host).
   206  func HostKeyCallback(trustedKeys []ssh.PublicKey, withHostKeyFallback bool) (ssh.HostKeyCallback, error) {
   207  	// No trusted keys are provided, return a nil callback which will prompt the user for trust.
   208  	if len(trustedKeys) == 0 {
   209  		return nil, nil
   210  	}
   211  
   212  	callbackConfig := HostKeyCallbackConfig{
   213  		GetHostCheckers: func() ([]ssh.PublicKey, error) {
   214  			return trustedKeys, nil
   215  		},
   216  	}
   217  
   218  	if withHostKeyFallback {
   219  		callbackConfig.HostKeyFallback = hostKeyFallbackFunc(trustedKeys)
   220  	}
   221  
   222  	callback, err := NewHostKeyCallback(callbackConfig)
   223  	if err != nil {
   224  		return nil, trace.Wrap(err)
   225  	}
   226  
   227  	return callback, nil
   228  }
   229  
   230  func hostKeyFallbackFunc(knownHosts []ssh.PublicKey) func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   231  	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   232  		for _, knownHost := range knownHosts {
   233  			if KeysEqual(key, knownHost) {
   234  				return nil
   235  			}
   236  		}
   237  		return trace.AccessDenied("host %v presented a public key instead of a host certificate which isn't among known hosts", hostname)
   238  	}
   239  }
   240  
   241  // KeysEqual is constant time compare of the keys to avoid timing attacks
   242  func KeysEqual(ak, bk ssh.PublicKey) bool {
   243  	a := ak.Marshal()
   244  	b := bk.Marshal()
   245  	return subtle.ConstantTimeCompare(a, b) == 1
   246  }
   247  
   248  // OpenSSH cert types look like "<key-type>-cert-v<version>@openssh.com".
   249  var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`)
   250  
   251  // IsSSHCertType checks if the given string looks like an ssh cert type.
   252  // e.g. ssh-rsa-cert-v01@openssh.com.
   253  func IsSSHCertType(val string) bool {
   254  	return sshCertTypeRegex.MatchString(val)
   255  }
   256  
   257  type contextDialer func(ctx context.Context, network, addr string) (net.Conn, error)
   258  
   259  type runSSHOpts struct {
   260  	dialContext contextDialer
   261  }
   262  
   263  // RunSSHOption allows setting options as functional arguments to RunSSH.
   264  type RunSSHOption func(*runSSHOpts)
   265  
   266  // WithDialer connects to an SSH server with a custom dialer.
   267  func WithDialer(dialer contextDialer) RunSSHOption {
   268  	return func(opts *runSSHOpts) {
   269  		opts.dialContext = dialer
   270  	}
   271  }
   272  
   273  // RunSSH runs a command on an SSH server and returns the output.
   274  func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, []byte, error) {
   275  	var options runSSHOpts
   276  	for _, opt := range opts {
   277  		opt(&options)
   278  	}
   279  
   280  	conn, err := options.dialContext(ctx, "tcp", addr)
   281  	if err != nil {
   282  		return nil, nil, trace.Wrap(err)
   283  	}
   284  
   285  	clientConn, newCh, requestsCh, err := ssh.NewClientConn(conn, addr, cfg)
   286  	if err != nil {
   287  		return nil, nil, trace.Wrap(err)
   288  	}
   289  	sshClient := ssh.NewClient(clientConn, newCh, requestsCh)
   290  	defer sshClient.Close()
   291  	session, err := sshClient.NewSession()
   292  	if err != nil {
   293  		return nil, nil, trace.Wrap(err)
   294  	}
   295  	defer session.Close()
   296  
   297  	// Execute the command.
   298  	var stdout bytes.Buffer
   299  	session.Stdout = &stdout
   300  	var stderr bytes.Buffer
   301  	session.Stderr = &stderr
   302  	err = session.Run(command)
   303  	return stdout.Bytes(), stderr.Bytes(), trace.Wrap(err)
   304  }
   305  
   306  // ChannelReadWriter represents the data streams of an ssh.Channel-like object.
   307  type ChannelReadWriter interface {
   308  	io.ReadWriter
   309  	Stderr() io.ReadWriter
   310  }
   311  
   312  // DiscardChannelData discards all data received from an ssh channel in the
   313  // background.
   314  func DiscardChannelData(ch ChannelReadWriter) {
   315  	if ch == nil {
   316  		return
   317  	}
   318  	go io.Copy(io.Discard, ch)
   319  	go io.Copy(io.Discard, ch.Stderr())
   320  }