github.com/mponton/terratest@v0.44.0/modules/ssh/ssh.go (about)

     1  // Package ssh allows to manage SSH connections and send commands through them.
     2  package ssh
     3  
     4  import (
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/hashicorp/go-multierror"
    16  	"github.com/mponton/terratest/modules/files"
    17  	"github.com/mponton/terratest/modules/logger"
    18  	"github.com/mponton/terratest/modules/retry"
    19  	"github.com/mponton/terratest/modules/testing"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
    24  // Host is a remote host.
    25  type Host struct {
    26  	Hostname    string // host name or ip address
    27  	SshUserName string // user name
    28  	// set one or more authentication methods,
    29  	// the first valid method will be used
    30  	SshKeyPair       *KeyPair  // ssh key pair to use as authentication method (disabled by default)
    31  	SshAgent         bool      // enable authentication using your existing local SSH agent (disabled by default)
    32  	OverrideSshAgent *SshAgent // enable an in process `SshAgent` for connections to this host (disabled by default)
    33  	Password         string    // plain text password (blank by default)
    34  	CustomPort       int       // port number to use to connect to the host (port 22 will be used if unset)
    35  }
    36  
    37  type ScpDownloadOptions struct {
    38  	FileNameFilters []string //File names to match. May include bash-style wildcards. E.g., *.log.
    39  	MaxFileSizeMB   int      //Don't grab any files > MaxFileSizeMB
    40  	RemoteDir       string   //Copy from this directory on the remote machine
    41  	LocalDir        string   //Copy RemoteDir to this directory on the local machine
    42  	RemoteHost      Host     //Connection information for the remote machine
    43  }
    44  
    45  // ScpFileToE uploads the contents using SCP to the given host and fails the test if the connection fails.
    46  func ScpFileTo(t testing.TestingT, host Host, mode os.FileMode, remotePath, contents string) {
    47  	err := ScpFileToE(t, host, mode, remotePath, contents)
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  }
    52  
    53  // ScpFileToE uploads the contents using SCP to the given host and return an error if the process fails.
    54  func ScpFileToE(t testing.TestingT, host Host, mode os.FileMode, remotePath, contents string) error {
    55  	authMethods, err := createAuthMethodsForHost(host)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	dir, file := filepath.Split(remotePath)
    60  
    61  	hostOptions := SshConnectionOptions{
    62  		Username:    host.SshUserName,
    63  		Address:     host.Hostname,
    64  		Port:        host.getPort(),
    65  		Command:     "/usr/bin/scp -t " + dir,
    66  		AuthMethods: authMethods,
    67  	}
    68  
    69  	scp := sendScpCommandsToCopyFile(mode, file, contents)
    70  
    71  	sshSession := &SshSession{
    72  		Options:  &hostOptions,
    73  		JumpHost: &JumpHostSession{},
    74  		Input:    &scp,
    75  	}
    76  
    77  	defer sshSession.Cleanup(t)
    78  
    79  	_, err = runSSHCommand(t, sshSession)
    80  	return err
    81  }
    82  
    83  // ScpFileFrom downloads the file from remotePath on the given host using SCP.
    84  func ScpFileFrom(t testing.TestingT, host Host, remotePath string, localDestination *os.File, useSudo bool) {
    85  	err := ScpFileFromE(t, host, remotePath, localDestination, useSudo)
    86  
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  }
    91  
    92  // ScpFileFromE downloads the file from remotePath on the given host using SCP and returns an error if the process fails.
    93  func ScpFileFromE(t testing.TestingT, host Host, remotePath string, localDestination *os.File, useSudo bool) error {
    94  	authMethods, err := createAuthMethodsForHost(host)
    95  
    96  	if err != nil {
    97  		return err
    98  	}
    99  
   100  	dir := filepath.Dir(remotePath)
   101  
   102  	hostOptions := SshConnectionOptions{
   103  		Username:    host.SshUserName,
   104  		Address:     host.Hostname,
   105  		Port:        host.getPort(),
   106  		Command:     "/usr/bin/scp -t " + dir,
   107  		AuthMethods: authMethods,
   108  	}
   109  
   110  	sshSession := &SshSession{
   111  		Options:  &hostOptions,
   112  		JumpHost: &JumpHostSession{},
   113  	}
   114  
   115  	defer sshSession.Cleanup(t)
   116  
   117  	return copyFileFromRemote(t, sshSession, localDestination, remotePath, useSudo)
   118  }
   119  
   120  // ScpDirFrom downloads all the files from remotePath on the given host using SCP.
   121  func ScpDirFrom(t testing.TestingT, options ScpDownloadOptions, useSudo bool) {
   122  	err := ScpDirFromE(t, options, useSudo)
   123  
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  }
   128  
   129  // ScpDirFromE downloads all the files from remotePath on the given host using SCP
   130  // and returns an error if the process fails. NOTE: only files within remotePath will
   131  // be downloaded. This function will not recursively download subdirectories or follow
   132  // symlinks.
   133  func ScpDirFromE(t testing.TestingT, options ScpDownloadOptions, useSudo bool) error {
   134  	authMethods, err := createAuthMethodsForHost(options.RemoteHost)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	hostOptions := SshConnectionOptions{
   140  		Username:    options.RemoteHost.SshUserName,
   141  		Address:     options.RemoteHost.Hostname,
   142  		Port:        options.RemoteHost.getPort(),
   143  		Command:     "/usr/bin/scp -t " + options.RemoteDir,
   144  		AuthMethods: authMethods,
   145  	}
   146  
   147  	sshSession := &SshSession{
   148  		Options:  &hostOptions,
   149  		JumpHost: &JumpHostSession{},
   150  	}
   151  
   152  	defer sshSession.Cleanup(t)
   153  
   154  	filesInDir, err := listFileInRemoteDir(t, sshSession, options, useSudo)
   155  
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	if !files.FileExists(options.LocalDir) {
   161  		err := os.MkdirAll(options.LocalDir, 0755)
   162  
   163  		if err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	var errorsOccurred = new(multierror.Error)
   169  
   170  	for _, fullRemoteFilePath := range filesInDir {
   171  		fileName := filepath.Base(fullRemoteFilePath)
   172  
   173  		localFilePath := filepath.Join(options.LocalDir, fileName)
   174  		localFile, err := os.Create(localFilePath)
   175  
   176  		if err != nil {
   177  			return err
   178  		}
   179  
   180  		logger.Logf(t, "Copying remote file: %s to local path %s", fullRemoteFilePath, localFilePath)
   181  
   182  		err = copyFileFromRemote(t, sshSession, localFile, fullRemoteFilePath, useSudo)
   183  		errorsOccurred = multierror.Append(errorsOccurred, err)
   184  	}
   185  
   186  	return errorsOccurred.ErrorOrNil()
   187  }
   188  
   189  // CheckSshConnection checks that you can connect via SSH to the given host and fail the test if the connection fails.
   190  func CheckSshConnection(t testing.TestingT, host Host) {
   191  	err := CheckSshConnectionE(t, host)
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  }
   196  
   197  // CheckSshConnectionE checks that you can connect via SSH to the given host and return an error if the connection fails.
   198  func CheckSshConnectionE(t testing.TestingT, host Host) error {
   199  	_, err := CheckSshCommandE(t, host, "'exit'")
   200  	return err
   201  }
   202  
   203  // CheckSshConnectionWithRetry attempts to connect via SSH until max retries has been exceeded and fails the test
   204  // if the connection fails
   205  func CheckSshConnectionWithRetry(t testing.TestingT, host Host, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host) error) {
   206  	handler := CheckSshConnectionE
   207  	if f != nil {
   208  		handler = f[0]
   209  	}
   210  	err := CheckSshConnectionWithRetryE(t, host, retries, sleepBetweenRetries, handler)
   211  	if err != nil {
   212  		t.Fatal(err)
   213  	}
   214  }
   215  
   216  // CheckSshConnectionWithRetryE attempts to connect via SSH until max retries has been exceeded and returns an error if
   217  // the connection fails
   218  func CheckSshConnectionWithRetryE(t testing.TestingT, host Host, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host) error) error {
   219  	handler := CheckSshConnectionE
   220  	if f != nil {
   221  		handler = f[0]
   222  	}
   223  	_, err := retry.DoWithRetryE(t, fmt.Sprintf("Checking SSH connection to %s", host.Hostname), retries, sleepBetweenRetries, func() (string, error) {
   224  		return "", handler(t, host)
   225  	})
   226  
   227  	return err
   228  }
   229  
   230  // CheckSshCommand checks that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr.
   231  func CheckSshCommand(t testing.TestingT, host Host, command string) string {
   232  	out, err := CheckSshCommandE(t, host, command)
   233  	if err != nil {
   234  		t.Fatal(err)
   235  	}
   236  	return out
   237  }
   238  
   239  // CheckSshCommandE checks that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr.
   240  func CheckSshCommandE(t testing.TestingT, host Host, command string) (string, error) {
   241  	authMethods, err := createAuthMethodsForHost(host)
   242  	if err != nil {
   243  		return "", err
   244  	}
   245  
   246  	hostOptions := SshConnectionOptions{
   247  		Username:    host.SshUserName,
   248  		Address:     host.Hostname,
   249  		Port:        host.getPort(),
   250  		Command:     command,
   251  		AuthMethods: authMethods,
   252  	}
   253  
   254  	sshSession := &SshSession{
   255  		Options:  &hostOptions,
   256  		JumpHost: &JumpHostSession{},
   257  	}
   258  
   259  	defer sshSession.Cleanup(t)
   260  
   261  	return runSSHCommand(t, sshSession)
   262  }
   263  
   264  // CheckSshCommandWithRetry checks that you can connect via SSH to the given host and run the given command until max retries have been exceeded. Returns the stdout/stderr.
   265  func CheckSshCommandWithRetry(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) string {
   266  	handler := CheckSshCommandE
   267  	if f != nil {
   268  		handler = f[0]
   269  	}
   270  	out, err := CheckSshCommandWithRetryE(t, host, command, retries, sleepBetweenRetries, handler)
   271  	if err != nil {
   272  		t.Fatal(err)
   273  	}
   274  	return out
   275  }
   276  
   277  // CheckSshCommandWithRetryE checks that you can connect via SSH to the given host and run the given command until max retries has been exceeded.
   278  // It return an error if the command fails after max retries has been exceeded.
   279  
   280  func CheckSshCommandWithRetryE(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) (string, error) {
   281  	handler := CheckSshCommandE
   282  	if f != nil {
   283  		handler = f[0]
   284  	}
   285  	return retry.DoWithRetryE(t, fmt.Sprintf("Checking SSH connection to %s", host.Hostname), retries, sleepBetweenRetries, func() (string, error) {
   286  		return handler(t, host, command)
   287  	})
   288  }
   289  
   290  // CheckPrivateSshConnection attempts to connect to privateHost (which is not addressable from the Internet) via a
   291  // separate publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns
   292  // its output. It is useful for checking that it's possible to SSH from a Bastion Host to a private instance.
   293  func CheckPrivateSshConnection(t testing.TestingT, publicHost Host, privateHost Host, command string) string {
   294  	out, err := CheckPrivateSshConnectionE(t, publicHost, privateHost, command)
   295  	if err != nil {
   296  		t.Fatal(err)
   297  	}
   298  	return out
   299  }
   300  
   301  // CheckPrivateSshConnectionE attempts to connect to privateHost (which is not addressable from the Internet) via a
   302  // separate publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns
   303  // its output. It is useful for checking that it's possible to SSH from a Bastion Host to a private instance.
   304  func CheckPrivateSshConnectionE(t testing.TestingT, publicHost Host, privateHost Host, command string) (string, error) {
   305  	jumpHostAuthMethods, err := createAuthMethodsForHost(publicHost)
   306  	if err != nil {
   307  		return "", err
   308  	}
   309  
   310  	jumpHostOptions := SshConnectionOptions{
   311  		Username:    publicHost.SshUserName,
   312  		Address:     publicHost.Hostname,
   313  		Port:        publicHost.getPort(),
   314  		AuthMethods: jumpHostAuthMethods,
   315  	}
   316  
   317  	hostAuthMethods, err := createAuthMethodsForHost(privateHost)
   318  	if err != nil {
   319  		return "", err
   320  	}
   321  
   322  	hostOptions := SshConnectionOptions{
   323  		Username:    privateHost.SshUserName,
   324  		Address:     privateHost.Hostname,
   325  		Port:        privateHost.getPort(),
   326  		Command:     command,
   327  		AuthMethods: hostAuthMethods,
   328  		JumpHost:    &jumpHostOptions,
   329  	}
   330  
   331  	sshSession := &SshSession{
   332  		Options:  &hostOptions,
   333  		JumpHost: &JumpHostSession{},
   334  	}
   335  
   336  	defer sshSession.Cleanup(t)
   337  
   338  	return runSSHCommand(t, sshSession)
   339  }
   340  
   341  // FetchContentsOfFiles connects to the given host via SSH and fetches the contents of the files at the given filePaths.
   342  // If useSudo is true, then the contents will be retrieved using sudo. This method returns a map from file path to
   343  // contents.
   344  func FetchContentsOfFiles(t testing.TestingT, host Host, useSudo bool, filePaths ...string) map[string]string {
   345  	out, err := FetchContentsOfFilesE(t, host, useSudo, filePaths...)
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  	return out
   350  }
   351  
   352  // FetchContentsOfFilesE connects to the given host via SSH and fetches the contents of the files at the given filePaths.
   353  // If useSudo is true, then the contents will be retrieved using sudo. This method returns a map from file path to
   354  // contents.
   355  func FetchContentsOfFilesE(t testing.TestingT, host Host, useSudo bool, filePaths ...string) (map[string]string, error) {
   356  	filePathToContents := map[string]string{}
   357  
   358  	for _, filePath := range filePaths {
   359  		contents, err := FetchContentsOfFileE(t, host, useSudo, filePath)
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  
   364  		filePathToContents[filePath] = contents
   365  	}
   366  
   367  	return filePathToContents, nil
   368  }
   369  
   370  // FetchContentsOfFile connects to the given host via SSH and fetches the contents of the file at the given filePath.
   371  // If useSudo is true, then the contents will be retrieved using sudo. This method returns the contents of that file.
   372  func FetchContentsOfFile(t testing.TestingT, host Host, useSudo bool, filePath string) string {
   373  	out, err := FetchContentsOfFileE(t, host, useSudo, filePath)
   374  	if err != nil {
   375  		t.Fatal(err)
   376  	}
   377  	return out
   378  }
   379  
   380  // FetchContentsOfFileE connects to the given host via SSH and fetches the contents of the file at the given filePath.
   381  // If useSudo is true, then the contents will be retrieved using sudo. This method returns the contents of that file.
   382  func FetchContentsOfFileE(t testing.TestingT, host Host, useSudo bool, filePath string) (string, error) {
   383  	command := fmt.Sprintf("cat %s", filePath)
   384  	if useSudo {
   385  		command = fmt.Sprintf("sudo %s", command)
   386  	}
   387  
   388  	return CheckSshCommandE(t, host, command)
   389  }
   390  
   391  func listFileInRemoteDir(t testing.TestingT, sshSession *SshSession, options ScpDownloadOptions, useSudo bool) ([]string, error) {
   392  	logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address)
   393  
   394  	var result []string
   395  	var findCommandArgs []string
   396  
   397  	if useSudo {
   398  		findCommandArgs = append(findCommandArgs, "sudo")
   399  	}
   400  
   401  	findCommandArgs = append(findCommandArgs, "find", options.RemoteDir)
   402  	findCommandArgs = append(findCommandArgs, "-type", "f")
   403  
   404  	filtersLength := len(options.FileNameFilters)
   405  	if options.FileNameFilters != nil && filtersLength > 0 {
   406  
   407  		findCommandArgs = append(findCommandArgs, "\\(")
   408  		for i, curFilter := range options.FileNameFilters {
   409  			// due to inconsistent bash behavior we need to wrap the
   410  			// filter in single quotes
   411  			curFilter = fmt.Sprintf("'%s'", curFilter)
   412  			findCommandArgs = append(findCommandArgs, "-name", curFilter)
   413  
   414  			// only add the or flag if we're not the last element
   415  			if filtersLength-i > 1 {
   416  				findCommandArgs = append(findCommandArgs, "-o")
   417  			}
   418  		}
   419  		findCommandArgs = append(findCommandArgs, "\\)")
   420  	}
   421  
   422  	if options.MaxFileSizeMB != 0 {
   423  		findCommandArgs = append(findCommandArgs, "-size", fmt.Sprintf("-%dM", options.MaxFileSizeMB))
   424  	}
   425  
   426  	finalCommandString := strings.Join(findCommandArgs, " ")
   427  	resultString, err := CheckSshCommandE(t, options.RemoteHost, finalCommandString)
   428  
   429  	if err != nil {
   430  		return result, err
   431  	}
   432  
   433  	// The last character returned is `\n` this results in an extra "" array
   434  	// member when we do the split below. Cut off the last character to avoid
   435  	// having to remove the blank entry in the array.
   436  	resultString = resultString[:len(resultString)-1]
   437  
   438  	result = append(result, strings.Split(resultString, "\n")...)
   439  	return result, nil
   440  }
   441  
   442  // Added based on code: https://github.com/bramvdbogaerde/go-scp/pull/6/files
   443  func copyFileFromRemote(t testing.TestingT, sshSession *SshSession, file *os.File, remotePath string, useSudo bool) error {
   444  	logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address)
   445  	if err := setUpSSHClient(sshSession); err != nil {
   446  		return err
   447  	}
   448  
   449  	if err := setUpSSHSession(sshSession); err != nil {
   450  		return err
   451  	}
   452  
   453  	command := fmt.Sprintf("dd if=%s", remotePath)
   454  	if useSudo {
   455  		command = fmt.Sprintf("sudo %s", command)
   456  	}
   457  
   458  	r, err := sshSession.Session.Output(command)
   459  	if err != nil {
   460  		fmt.Printf("error reading from remote stdout: %s", err)
   461  	}
   462  	defer sshSession.Session.Close()
   463  	//write to local file
   464  	_, err = file.Write(r)
   465  
   466  	return err
   467  }
   468  
   469  func runSSHCommand(t testing.TestingT, sshSession *SshSession) (string, error) {
   470  	logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address)
   471  	if err := setUpSSHClient(sshSession); err != nil {
   472  		return "", err
   473  	}
   474  
   475  	if err := setUpSSHSession(sshSession); err != nil {
   476  		return "", err
   477  	}
   478  
   479  	if sshSession.Input != nil {
   480  		w, err := sshSession.Session.StdinPipe()
   481  		if err != nil {
   482  			return "", err
   483  		}
   484  		go func() {
   485  			defer w.Close()
   486  			(*sshSession.Input)(w)
   487  		}()
   488  	}
   489  
   490  	bytes, err := sshSession.Session.CombinedOutput(sshSession.Options.Command)
   491  	if err != nil {
   492  		return string(bytes), err
   493  	}
   494  
   495  	return string(bytes), nil
   496  }
   497  
   498  func setUpSSHClient(sshSession *SshSession) error {
   499  	if sshSession.Options.JumpHost == nil {
   500  		return fillSSHClientForHost(sshSession)
   501  	}
   502  	return fillSSHClientForJumpHost(sshSession)
   503  }
   504  
   505  func fillSSHClientForHost(sshSession *SshSession) error {
   506  	client, err := createSSHClient(sshSession.Options)
   507  
   508  	if err != nil {
   509  		return err
   510  	}
   511  
   512  	sshSession.Client = client
   513  	return nil
   514  }
   515  
   516  func fillSSHClientForJumpHost(sshSession *SshSession) error {
   517  	jumpHostClient, err := createSSHClient(sshSession.Options.JumpHost)
   518  	if err != nil {
   519  		return err
   520  	}
   521  	sshSession.JumpHost.JumpHostClient = jumpHostClient
   522  
   523  	hostVirtualConn, err := jumpHostClient.Dial("tcp", sshSession.Options.ConnectionString())
   524  	if err != nil {
   525  		return err
   526  	}
   527  	sshSession.JumpHost.HostVirtualConnection = hostVirtualConn
   528  
   529  	hostConn, hostIncomingChannels, hostIncomingRequests, err := ssh.NewClientConn(hostVirtualConn, sshSession.Options.ConnectionString(), createSSHClientConfig(sshSession.Options))
   530  	if err != nil {
   531  		return err
   532  	}
   533  	sshSession.JumpHost.HostConnection = hostConn
   534  
   535  	sshSession.Client = ssh.NewClient(hostConn, hostIncomingChannels, hostIncomingRequests)
   536  	return nil
   537  }
   538  
   539  func setUpSSHSession(sshSession *SshSession) error {
   540  	session, err := sshSession.Client.NewSession()
   541  	if err != nil {
   542  		return err
   543  	}
   544  
   545  	sshSession.Session = session
   546  	return nil
   547  }
   548  
   549  func createSSHClient(options *SshConnectionOptions) (*ssh.Client, error) {
   550  	sshClientConfig := createSSHClientConfig(options)
   551  	return ssh.Dial("tcp", options.ConnectionString(), sshClientConfig)
   552  }
   553  
   554  func createSSHClientConfig(hostOptions *SshConnectionOptions) *ssh.ClientConfig {
   555  	clientConfig := &ssh.ClientConfig{
   556  		User: hostOptions.Username,
   557  		Auth: hostOptions.AuthMethods,
   558  		// Do not do a host key check, as Terratest is only used for testing, not prod
   559  		HostKeyCallback: NoOpHostKeyCallback,
   560  		// By default, Go does not impose a timeout, so a SSH connection attempt can hang for a LONG time.
   561  		Timeout: 10 * time.Second,
   562  	}
   563  	clientConfig.SetDefaults()
   564  	return clientConfig
   565  }
   566  
   567  // NoOpHostKeyCallback is an ssh.HostKeyCallback that does nothing. Only use this when you're sure you don't want to check the host key at all
   568  // (e.g., only for testing and non-production use cases).
   569  func NoOpHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
   570  	return nil
   571  }
   572  
   573  // Returns an array of authentication methods
   574  func createAuthMethodsForHost(host Host) ([]ssh.AuthMethod, error) {
   575  	var methods []ssh.AuthMethod
   576  
   577  	// override local ssh agent with given sshAgent instance
   578  	if host.OverrideSshAgent != nil {
   579  		conn, err := net.Dial("unix", host.OverrideSshAgent.socketFile)
   580  		if err != nil {
   581  			fmt.Print("Failed to dial in memory ssh agent")
   582  			return methods, err
   583  		}
   584  		agentClient := agent.NewClient(conn)
   585  		methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...)
   586  	}
   587  
   588  	// use existing ssh agent socket
   589  	// if agent authentication is enabled and no agent is set up, returns an error
   590  	if host.SshAgent {
   591  		socket := os.Getenv("SSH_AUTH_SOCK")
   592  		conn, err := net.Dial("unix", socket)
   593  		if err != nil {
   594  			return methods, err
   595  		}
   596  		agentClient := agent.NewClient(conn)
   597  		methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...)
   598  	}
   599  
   600  	// use provided ssh key pair
   601  	if host.SshKeyPair != nil {
   602  		signer, err := ssh.ParsePrivateKey([]byte(host.SshKeyPair.PrivateKey))
   603  		if err != nil {
   604  			return methods, err
   605  		}
   606  		methods = append(methods, []ssh.AuthMethod{ssh.PublicKeys(signer)}...)
   607  	}
   608  
   609  	// Use given password
   610  	if len(host.Password) > 0 {
   611  		methods = append(methods, []ssh.AuthMethod{ssh.Password(host.Password)}...)
   612  	}
   613  
   614  	// no valid authentication method was provided
   615  	if len(methods) < 1 {
   616  		return methods, errors.New("no authentication method defined")
   617  	}
   618  
   619  	return methods, nil
   620  }
   621  
   622  // sendScpCommandsToCopyFile returns a function which will send commands to the SCP binary to output a file on the remote machine.
   623  // A full explanation of the SCP protocol can be found at
   624  // https://web.archive.org/web/20170215184048/https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works
   625  func sendScpCommandsToCopyFile(mode os.FileMode, fileName, contents string) func(io.WriteCloser) {
   626  	return func(input io.WriteCloser) {
   627  
   628  		octalMode := "0" + strconv.FormatInt(int64(mode), 8)
   629  
   630  		// Create a file at <filename> with Unix permissions set to <octalMost> and the file will be <len(content)> bytes long.
   631  		fmt.Fprintln(input, "C"+octalMode, len(contents), fileName)
   632  
   633  		// Actually send the file
   634  		fmt.Fprint(input, contents)
   635  
   636  		// End of transfer
   637  		fmt.Fprint(input, "\x00")
   638  	}
   639  }
   640  
   641  // Gets the port that should be used to communicate with the host
   642  func (h Host) getPort() int {
   643  
   644  	//If a CustomPort is not set use standard ssh port
   645  	if h.CustomPort == 0 {
   646  		return 22
   647  	} else {
   648  		return h.CustomPort
   649  	}
   650  }