k8s.io/kubernetes@v1.29.3/test/e2e/framework/ssh/ssh.go (about)

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package ssh
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"net"
    24  	"os"
    25  	"path/filepath"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/onsi/gomega"
    30  
    31  	"golang.org/x/crypto/ssh"
    32  
    33  	v1 "k8s.io/api/core/v1"
    34  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    35  	"k8s.io/apimachinery/pkg/fields"
    36  	"k8s.io/apimachinery/pkg/util/wait"
    37  	clientset "k8s.io/client-go/kubernetes"
    38  	"k8s.io/kubernetes/test/e2e/framework"
    39  )
    40  
const (
	// SSHPort is the TCP port number of SSH. It is a string so that it can be
	// handed straight to net.JoinHostPort.
	SSHPort = "22"

	// pollNodeInterval is how often to poll while listing nodes.
	pollNodeInterval = 2 * time.Second

	// singleCallTimeout is how long to try single API calls (like 'get' or 'list'). Used to prevent
	// transient failures from failing tests.
	singleCallTimeout = 5 * time.Minute

	// sshBastionEnvKey is the environment variable key for running SSH commands via bastion.
	sshBastionEnvKey = "KUBE_SSH_BASTION"
)
    55  
    56  // GetSigner returns an ssh.Signer for the provider ("gce", etc.) that can be
    57  // used to SSH to their nodes.
    58  func GetSigner(provider string) (ssh.Signer, error) {
    59  	// honor a consistent SSH key across all providers
    60  	if path := os.Getenv("KUBE_SSH_KEY_PATH"); len(path) > 0 {
    61  		return makePrivateKeySignerFromFile(path)
    62  	}
    63  
    64  	// Select the key itself to use. When implementing more providers here,
    65  	// please also add them to any SSH tests that are disabled because of signer
    66  	// support.
    67  	keyfile := ""
    68  	switch provider {
    69  	case "gce", "gke", "kubemark":
    70  		keyfile = os.Getenv("GCE_SSH_KEY")
    71  		if keyfile == "" {
    72  			keyfile = "google_compute_engine"
    73  		}
    74  	case "aws", "eks":
    75  		keyfile = os.Getenv("AWS_SSH_KEY")
    76  		if keyfile == "" {
    77  			keyfile = "kube_aws_rsa"
    78  		}
    79  	case "local", "vsphere":
    80  		keyfile = os.Getenv("LOCAL_SSH_KEY")
    81  		if keyfile == "" {
    82  			keyfile = "id_rsa"
    83  		}
    84  	case "skeleton":
    85  		keyfile = os.Getenv("KUBE_SSH_KEY")
    86  		if keyfile == "" {
    87  			keyfile = "id_rsa"
    88  		}
    89  	case "azure":
    90  		keyfile = os.Getenv("AZURE_SSH_KEY")
    91  		if keyfile == "" {
    92  			keyfile = "id_rsa"
    93  		}
    94  	default:
    95  		return nil, fmt.Errorf("GetSigner(...) not implemented for %s", provider)
    96  	}
    97  
    98  	// Respect absolute paths for keys given by user, fallback to assuming
    99  	// relative paths are in ~/.ssh
   100  	if !filepath.IsAbs(keyfile) {
   101  		keydir := filepath.Join(os.Getenv("HOME"), ".ssh")
   102  		keyfile = filepath.Join(keydir, keyfile)
   103  	}
   104  
   105  	return makePrivateKeySignerFromFile(keyfile)
   106  }
   107  
   108  func makePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
   109  	buffer, err := os.ReadFile(key)
   110  	if err != nil {
   111  		return nil, fmt.Errorf("error reading SSH key %s: %w", key, err)
   112  	}
   113  
   114  	signer, err := ssh.ParsePrivateKey(buffer)
   115  	if err != nil {
   116  		return nil, fmt.Errorf("error parsing SSH key: %w", err)
   117  	}
   118  
   119  	return signer, err
   120  }
   121  
   122  // NodeSSHHosts returns SSH-able host names for all schedulable nodes.
   123  // If it can't find any external IPs, it falls back to
   124  // looking for internal IPs. If it can't find an internal IP for every node it
   125  // returns an error, though it still returns all hosts that it found in that
   126  // case.
   127  func NodeSSHHosts(ctx context.Context, c clientset.Interface) ([]string, error) {
   128  	nodelist := waitListSchedulableNodesOrDie(ctx, c)
   129  
   130  	hosts := nodeAddresses(nodelist, v1.NodeExternalIP)
   131  	// If  ExternalIPs aren't available for all nodes, try falling back to the InternalIPs.
   132  	if len(hosts) < len(nodelist.Items) {
   133  		framework.Logf("No external IP address on nodes, falling back to internal IPs")
   134  		hosts = nodeAddresses(nodelist, v1.NodeInternalIP)
   135  	}
   136  
   137  	// Error if neither External nor Internal IPs weren't available for all nodes.
   138  	if len(hosts) != len(nodelist.Items) {
   139  		return hosts, fmt.Errorf(
   140  			"only found %d IPs on nodes, but found %d nodes. Nodelist: %v",
   141  			len(hosts), len(nodelist.Items), nodelist)
   142  	}
   143  
   144  	lenHosts := len(hosts)
   145  	wg := &sync.WaitGroup{}
   146  	wg.Add(lenHosts)
   147  	sshHosts := make([]string, 0, lenHosts)
   148  	var sshHostsLock sync.Mutex
   149  
   150  	for _, host := range hosts {
   151  		go func(host string) {
   152  			defer wg.Done()
   153  			if canConnect(host) {
   154  				framework.Logf("Assuming SSH on host %s", host)
   155  				sshHostsLock.Lock()
   156  				sshHosts = append(sshHosts, net.JoinHostPort(host, SSHPort))
   157  				sshHostsLock.Unlock()
   158  			} else {
   159  				framework.Logf("Skipping host %s because it does not run anything on port %s", host, SSHPort)
   160  			}
   161  		}(host)
   162  	}
   163  	wg.Wait()
   164  
   165  	return sshHosts, nil
   166  }
   167  
   168  // canConnect returns true if a network connection is possible to the SSHPort.
   169  func canConnect(host string) bool {
   170  	if _, ok := os.LookupEnv(sshBastionEnvKey); ok {
   171  		return true
   172  	}
   173  	hostPort := net.JoinHostPort(host, SSHPort)
   174  	conn, err := net.DialTimeout("tcp", hostPort, 3*time.Second)
   175  	if err != nil {
   176  		framework.Logf("cannot dial %s: %v", hostPort, err)
   177  		return false
   178  	}
   179  	conn.Close()
   180  	return true
   181  }
   182  
// Result holds the execution result of SSH command
type Result struct {
	// User is the login name used for the connection. SSH fills it from
	// KUBE_SSH_USER, falling back to USER.
	User string
	// Host is the address the command was sent to, as passed to SSH.
	Host string
	// Cmd is the command line that was executed.
	Cmd string
	// Stdout is the captured standard output of the command.
	Stdout string
	// Stderr is the captured standard error of the command.
	Stderr string
	// Code is the exit status of the command.
	Code int
}
   192  
   193  // NodeExec execs the given cmd on node via SSH. Note that the nodeName is an sshable name,
   194  // eg: the name returned by framework.GetMasterHost(). This is also not guaranteed to work across
   195  // cloud providers since it involves ssh.
   196  func NodeExec(ctx context.Context, nodeName, cmd, provider string) (Result, error) {
   197  	return SSH(ctx, cmd, net.JoinHostPort(nodeName, SSHPort), provider)
   198  }
   199  
   200  // SSH synchronously SSHs to a node running on provider and runs cmd. If there
   201  // is no error performing the SSH, the stdout, stderr, and exit code are
   202  // returned.
   203  func SSH(ctx context.Context, cmd, host, provider string) (Result, error) {
   204  	result := Result{Host: host, Cmd: cmd}
   205  
   206  	// Get a signer for the provider.
   207  	signer, err := GetSigner(provider)
   208  	if err != nil {
   209  		return result, fmt.Errorf("error getting signer for provider %s: %w", provider, err)
   210  	}
   211  
   212  	// RunSSHCommand will default to Getenv("USER") if user == "", but we're
   213  	// defaulting here as well for logging clarity.
   214  	result.User = os.Getenv("KUBE_SSH_USER")
   215  	if result.User == "" {
   216  		result.User = os.Getenv("USER")
   217  	}
   218  
   219  	if bastion := os.Getenv(sshBastionEnvKey); len(bastion) > 0 {
   220  		stdout, stderr, code, err := runSSHCommandViaBastion(ctx, cmd, result.User, bastion, host, signer)
   221  		result.Stdout = stdout
   222  		result.Stderr = stderr
   223  		result.Code = code
   224  		return result, err
   225  	}
   226  
   227  	stdout, stderr, code, err := runSSHCommand(ctx, cmd, result.User, host, signer)
   228  	result.Stdout = stdout
   229  	result.Stderr = stderr
   230  	result.Code = code
   231  
   232  	return result, err
   233  }
   234  
   235  // runSSHCommandViaBastion returns the stdout, stderr, and exit code from running cmd on
   236  // host as specific user, along with any SSH-level error.
   237  func runSSHCommand(ctx context.Context, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
   238  	if user == "" {
   239  		user = os.Getenv("USER")
   240  	}
   241  	// Setup the config, dial the server, and open a session.
   242  	config := &ssh.ClientConfig{
   243  		User:            user,
   244  		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
   245  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
   246  	}
   247  	client, err := ssh.Dial("tcp", host, config)
   248  	if err != nil {
   249  		err = wait.PollWithContext(ctx, 5*time.Second, 20*time.Second, func(ctx context.Context) (bool, error) {
   250  			fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err)
   251  			if client, err = ssh.Dial("tcp", host, config); err != nil {
   252  				return false, nil // retrying, error will be logged above
   253  			}
   254  			return true, nil
   255  		})
   256  	}
   257  	if err != nil {
   258  		return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: %w", user, host, err)
   259  	}
   260  	defer client.Close()
   261  	session, err := client.NewSession()
   262  	if err != nil {
   263  		return "", "", 0, fmt.Errorf("error creating session to %s@%s: %w", user, host, err)
   264  	}
   265  	defer session.Close()
   266  
   267  	// Run the command.
   268  	code := 0
   269  	var bout, berr bytes.Buffer
   270  	session.Stdout, session.Stderr = &bout, &berr
   271  	if err = session.Run(cmd); err != nil {
   272  		// Check whether the command failed to run or didn't complete.
   273  		if exiterr, ok := err.(*ssh.ExitError); ok {
   274  			// If we got an ExitError and the exit code is nonzero, we'll
   275  			// consider the SSH itself successful (just that the command run
   276  			// errored on the host).
   277  			if code = exiterr.ExitStatus(); code != 0 {
   278  				err = nil
   279  			}
   280  		} else {
   281  			// Some other kind of error happened (e.g. an IOError); consider the
   282  			// SSH unsuccessful.
   283  			err = fmt.Errorf("failed running `%s` on %s@%s: %w", cmd, user, host, err)
   284  		}
   285  	}
   286  	return bout.String(), berr.String(), code, err
   287  }
   288  
   289  // runSSHCommandViaBastion returns the stdout, stderr, and exit code from running cmd on
   290  // host as specific user, along with any SSH-level error. It uses an SSH proxy to connect
   291  // to bastion, then via that tunnel connects to the remote host. Similar to
   292  // sshutil.RunSSHCommand but scoped to the needs of the test infrastructure.
   293  func runSSHCommandViaBastion(ctx context.Context, cmd, user, bastion, host string, signer ssh.Signer) (string, string, int, error) {
   294  	// Setup the config, dial the server, and open a session.
   295  	config := &ssh.ClientConfig{
   296  		User:            user,
   297  		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
   298  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
   299  		Timeout:         150 * time.Second,
   300  	}
   301  	bastionClient, err := ssh.Dial("tcp", bastion, config)
   302  	if err != nil {
   303  		err = wait.PollWithContext(ctx, 5*time.Second, 20*time.Second, func(ctx context.Context) (bool, error) {
   304  			fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, bastion, err)
   305  			if bastionClient, err = ssh.Dial("tcp", bastion, config); err != nil {
   306  				return false, err
   307  			}
   308  			return true, nil
   309  		})
   310  	}
   311  	if err != nil {
   312  		return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: %w", user, bastion, err)
   313  	}
   314  	defer bastionClient.Close()
   315  
   316  	conn, err := bastionClient.Dial("tcp", host)
   317  	if err != nil {
   318  		return "", "", 0, fmt.Errorf("error dialing %s from bastion: %w", host, err)
   319  	}
   320  	defer conn.Close()
   321  
   322  	ncc, chans, reqs, err := ssh.NewClientConn(conn, host, config)
   323  	if err != nil {
   324  		return "", "", 0, fmt.Errorf("error creating forwarding connection %s from bastion: %w", host, err)
   325  	}
   326  	client := ssh.NewClient(ncc, chans, reqs)
   327  	defer client.Close()
   328  
   329  	session, err := client.NewSession()
   330  	if err != nil {
   331  		return "", "", 0, fmt.Errorf("error creating session to %s@%s from bastion: %w", user, host, err)
   332  	}
   333  	defer session.Close()
   334  
   335  	// Run the command.
   336  	code := 0
   337  	var bout, berr bytes.Buffer
   338  	session.Stdout, session.Stderr = &bout, &berr
   339  	if err = session.Run(cmd); err != nil {
   340  		// Check whether the command failed to run or didn't complete.
   341  		if exiterr, ok := err.(*ssh.ExitError); ok {
   342  			// If we got an ExitError and the exit code is nonzero, we'll
   343  			// consider the SSH itself successful (just that the command run
   344  			// errored on the host).
   345  			if code = exiterr.ExitStatus(); code != 0 {
   346  				err = nil
   347  			}
   348  		} else {
   349  			// Some other kind of error happened (e.g. an IOError); consider the
   350  			// SSH unsuccessful.
   351  			err = fmt.Errorf("failed running `%s` on %s@%s: %w", cmd, user, host, err)
   352  		}
   353  	}
   354  	return bout.String(), berr.String(), code, err
   355  }
   356  
   357  // LogResult records result log
   358  func LogResult(result Result) {
   359  	remote := fmt.Sprintf("%s@%s", result.User, result.Host)
   360  	framework.Logf("ssh %s: command:   %s", remote, result.Cmd)
   361  	framework.Logf("ssh %s: stdout:    %q", remote, result.Stdout)
   362  	framework.Logf("ssh %s: stderr:    %q", remote, result.Stderr)
   363  	framework.Logf("ssh %s: exit code: %d", remote, result.Code)
   364  }
   365  
   366  // IssueSSHCommandWithResult tries to execute a SSH command and returns the execution result
   367  func IssueSSHCommandWithResult(ctx context.Context, cmd, provider string, node *v1.Node) (*Result, error) {
   368  	framework.Logf("Getting external IP address for %s", node.Name)
   369  	host := ""
   370  	for _, a := range node.Status.Addresses {
   371  		if a.Type == v1.NodeExternalIP && a.Address != "" {
   372  			host = net.JoinHostPort(a.Address, SSHPort)
   373  			break
   374  		}
   375  	}
   376  
   377  	if host == "" {
   378  		// No external IPs were found, let's try to use internal as plan B
   379  		for _, a := range node.Status.Addresses {
   380  			if a.Type == v1.NodeInternalIP && a.Address != "" {
   381  				host = net.JoinHostPort(a.Address, SSHPort)
   382  				break
   383  			}
   384  		}
   385  	}
   386  
   387  	if host == "" {
   388  		return nil, fmt.Errorf("couldn't find any IP address for node %s", node.Name)
   389  	}
   390  
   391  	framework.Logf("SSH %q on %s(%s)", cmd, node.Name, host)
   392  	result, err := SSH(ctx, cmd, host, provider)
   393  	LogResult(result)
   394  
   395  	if result.Code != 0 || err != nil {
   396  		return nil, fmt.Errorf("failed running %q: %v (exit code %d, stderr %v)",
   397  			cmd, err, result.Code, result.Stderr)
   398  	}
   399  
   400  	return &result, nil
   401  }
   402  
   403  // IssueSSHCommand tries to execute a SSH command
   404  func IssueSSHCommand(ctx context.Context, cmd, provider string, node *v1.Node) error {
   405  	_, err := IssueSSHCommandWithResult(ctx, cmd, provider, node)
   406  	if err != nil {
   407  		return err
   408  	}
   409  	return nil
   410  }
   411  
   412  // nodeAddresses returns the first address of the given type of each node.
   413  func nodeAddresses(nodelist *v1.NodeList, addrType v1.NodeAddressType) []string {
   414  	hosts := []string{}
   415  	for _, n := range nodelist.Items {
   416  		for _, addr := range n.Status.Addresses {
   417  			if addr.Type == addrType && addr.Address != "" {
   418  				hosts = append(hosts, addr.Address)
   419  				break
   420  			}
   421  		}
   422  	}
   423  	return hosts
   424  }
   425  
   426  // waitListSchedulableNodes is a wrapper around listing nodes supporting retries.
   427  func waitListSchedulableNodes(ctx context.Context, c clientset.Interface) (*v1.NodeList, error) {
   428  	var nodes *v1.NodeList
   429  	var err error
   430  	if wait.PollUntilContextTimeout(ctx, pollNodeInterval, singleCallTimeout, true, func(ctx context.Context) (bool, error) {
   431  		nodes, err = c.CoreV1().Nodes().List(ctx, metav1.ListOptions{FieldSelector: fields.Set{
   432  			"spec.unschedulable": "false",
   433  		}.AsSelector().String()})
   434  		if err != nil {
   435  			return false, err
   436  		}
   437  		return true, nil
   438  	}) != nil {
   439  		return nodes, err
   440  	}
   441  	return nodes, nil
   442  }
   443  
   444  // waitListSchedulableNodesOrDie is a wrapper around listing nodes supporting retries.
   445  func waitListSchedulableNodesOrDie(ctx context.Context, c clientset.Interface) *v1.NodeList {
   446  	nodes, err := waitListSchedulableNodes(ctx, c)
   447  	if err != nil {
   448  		expectNoError(err, "Non-retryable failure or timed out while listing nodes for e2e cluster.")
   449  	}
   450  	return nodes
   451  }
   452  
// expectNoError checks if "err" is set, and if so, fails assertion while logging the error.
// It delegates to expectNoErrorWithOffset with an offset of 1 so that the
// failure is attributed to the caller of expectNoError rather than to this
// helper. The optional explain arguments are passed through unchanged.
func expectNoError(err error, explain ...interface{}) {
	expectNoErrorWithOffset(1, err, explain...)
}
   457  
// expectNoErrorWithOffset checks if "err" is set, and if so, fails assertion while logging the error at "offset" levels above its caller
// (for example, for call chain f -> g -> expectNoErrorWithOffset(1, ...) error would be logged for "f").
// The optional explain arguments are forwarded to the gomega assertion.
func expectNoErrorWithOffset(offset int, err error, explain ...interface{}) {
	if err != nil {
		framework.Logf("Unexpected error occurred: %v", err)
	}
	// The extra 1 skips this function's own stack frame.
	gomega.ExpectWithOffset(1+offset, err).NotTo(gomega.HaveOccurred(), explain...)
}