// github.com/mponton/terratest@v0.44.0/modules/ssh/ssh.go

// Package ssh manages SSH connections and sends commands through them, including simple
// SCP-style file uploads and downloads.
package ssh

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/mponton/terratest/modules/files"
	"github.com/mponton/terratest/modules/logger"
	"github.com/mponton/terratest/modules/retry"
	"github.com/mponton/terratest/modules/testing"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

// Host is a remote host.
type Host struct {
	Hostname    string // host name or ip address
	SshUserName string // user name
	// Set one or more authentication methods. They are offered to the server in this order:
	// OverrideSshAgent, SshAgent, SshKeyPair, Password.
	// NOTE(review): golang.org/x/crypto/ssh only tries the first AuthMethod of each kind, and the
	// agent and key-pair options are all "publickey" methods, so an enabled agent may shadow
	// SshKeyPair — confirm against the x/crypto/ssh version in use.
	SshKeyPair       *KeyPair  // ssh key pair to use as authentication method (disabled by default)
	SshAgent         bool      // enable authentication using your existing local SSH agent (disabled by default)
	OverrideSshAgent *SshAgent // enable an in process `SshAgent` for connections to this host (disabled by default)
	Password         string    // plain text password (blank by default)
	CustomPort       int       // port number to use to connect to the host (port 22 will be used if unset)
}

// ScpDownloadOptions describes a directory download performed by ScpDirFrom / ScpDirFromE.
type ScpDownloadOptions struct {
	FileNameFilters []string // File names to match (each passed to find -name). May include bash-style wildcards. E.g., *.log. Empty matches every file.
	MaxFileSizeMB   int      // Skip large files: passed to find as -size -<MaxFileSizeMB>M (see find(1) for its rounding rules). 0 disables the limit.
	RemoteDir       string   // Copy from this directory on the remote machine
	LocalDir        string   // Copy RemoteDir to this directory on the local machine
	RemoteHost      Host     // Connection information for the remote machine
}

// ScpFileTo uploads the contents using SCP to the given host and fails the test if the connection fails.
46 func ScpFileTo(t testing.TestingT, host Host, mode os.FileMode, remotePath, contents string) { 47 err := ScpFileToE(t, host, mode, remotePath, contents) 48 if err != nil { 49 t.Fatal(err) 50 } 51 } 52 53 // ScpFileToE uploads the contents using SCP to the given host and return an error if the process fails. 54 func ScpFileToE(t testing.TestingT, host Host, mode os.FileMode, remotePath, contents string) error { 55 authMethods, err := createAuthMethodsForHost(host) 56 if err != nil { 57 return err 58 } 59 dir, file := filepath.Split(remotePath) 60 61 hostOptions := SshConnectionOptions{ 62 Username: host.SshUserName, 63 Address: host.Hostname, 64 Port: host.getPort(), 65 Command: "/usr/bin/scp -t " + dir, 66 AuthMethods: authMethods, 67 } 68 69 scp := sendScpCommandsToCopyFile(mode, file, contents) 70 71 sshSession := &SshSession{ 72 Options: &hostOptions, 73 JumpHost: &JumpHostSession{}, 74 Input: &scp, 75 } 76 77 defer sshSession.Cleanup(t) 78 79 _, err = runSSHCommand(t, sshSession) 80 return err 81 } 82 83 // ScpFileFrom downloads the file from remotePath on the given host using SCP. 84 func ScpFileFrom(t testing.TestingT, host Host, remotePath string, localDestination *os.File, useSudo bool) { 85 err := ScpFileFromE(t, host, remotePath, localDestination, useSudo) 86 87 if err != nil { 88 t.Fatal(err) 89 } 90 } 91 92 // ScpFileFromE downloads the file from remotePath on the given host using SCP and returns an error if the process fails. 
93 func ScpFileFromE(t testing.TestingT, host Host, remotePath string, localDestination *os.File, useSudo bool) error { 94 authMethods, err := createAuthMethodsForHost(host) 95 96 if err != nil { 97 return err 98 } 99 100 dir := filepath.Dir(remotePath) 101 102 hostOptions := SshConnectionOptions{ 103 Username: host.SshUserName, 104 Address: host.Hostname, 105 Port: host.getPort(), 106 Command: "/usr/bin/scp -t " + dir, 107 AuthMethods: authMethods, 108 } 109 110 sshSession := &SshSession{ 111 Options: &hostOptions, 112 JumpHost: &JumpHostSession{}, 113 } 114 115 defer sshSession.Cleanup(t) 116 117 return copyFileFromRemote(t, sshSession, localDestination, remotePath, useSudo) 118 } 119 120 // ScpDirFrom downloads all the files from remotePath on the given host using SCP. 121 func ScpDirFrom(t testing.TestingT, options ScpDownloadOptions, useSudo bool) { 122 err := ScpDirFromE(t, options, useSudo) 123 124 if err != nil { 125 t.Fatal(err) 126 } 127 } 128 129 // ScpDirFromE downloads all the files from remotePath on the given host using SCP 130 // and returns an error if the process fails. NOTE: only files within remotePath will 131 // be downloaded. This function will not recursively download subdirectories or follow 132 // symlinks. 
133 func ScpDirFromE(t testing.TestingT, options ScpDownloadOptions, useSudo bool) error { 134 authMethods, err := createAuthMethodsForHost(options.RemoteHost) 135 if err != nil { 136 return err 137 } 138 139 hostOptions := SshConnectionOptions{ 140 Username: options.RemoteHost.SshUserName, 141 Address: options.RemoteHost.Hostname, 142 Port: options.RemoteHost.getPort(), 143 Command: "/usr/bin/scp -t " + options.RemoteDir, 144 AuthMethods: authMethods, 145 } 146 147 sshSession := &SshSession{ 148 Options: &hostOptions, 149 JumpHost: &JumpHostSession{}, 150 } 151 152 defer sshSession.Cleanup(t) 153 154 filesInDir, err := listFileInRemoteDir(t, sshSession, options, useSudo) 155 156 if err != nil { 157 return err 158 } 159 160 if !files.FileExists(options.LocalDir) { 161 err := os.MkdirAll(options.LocalDir, 0755) 162 163 if err != nil { 164 return err 165 } 166 } 167 168 var errorsOccurred = new(multierror.Error) 169 170 for _, fullRemoteFilePath := range filesInDir { 171 fileName := filepath.Base(fullRemoteFilePath) 172 173 localFilePath := filepath.Join(options.LocalDir, fileName) 174 localFile, err := os.Create(localFilePath) 175 176 if err != nil { 177 return err 178 } 179 180 logger.Logf(t, "Copying remote file: %s to local path %s", fullRemoteFilePath, localFilePath) 181 182 err = copyFileFromRemote(t, sshSession, localFile, fullRemoteFilePath, useSudo) 183 errorsOccurred = multierror.Append(errorsOccurred, err) 184 } 185 186 return errorsOccurred.ErrorOrNil() 187 } 188 189 // CheckSshConnection checks that you can connect via SSH to the given host and fail the test if the connection fails. 190 func CheckSshConnection(t testing.TestingT, host Host) { 191 err := CheckSshConnectionE(t, host) 192 if err != nil { 193 t.Fatal(err) 194 } 195 } 196 197 // CheckSshConnectionE checks that you can connect via SSH to the given host and return an error if the connection fails. 
198 func CheckSshConnectionE(t testing.TestingT, host Host) error { 199 _, err := CheckSshCommandE(t, host, "'exit'") 200 return err 201 } 202 203 // CheckSshConnectionWithRetry attempts to connect via SSH until max retries has been exceeded and fails the test 204 // if the connection fails 205 func CheckSshConnectionWithRetry(t testing.TestingT, host Host, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host) error) { 206 handler := CheckSshConnectionE 207 if f != nil { 208 handler = f[0] 209 } 210 err := CheckSshConnectionWithRetryE(t, host, retries, sleepBetweenRetries, handler) 211 if err != nil { 212 t.Fatal(err) 213 } 214 } 215 216 // CheckSshConnectionWithRetryE attempts to connect via SSH until max retries has been exceeded and returns an error if 217 // the connection fails 218 func CheckSshConnectionWithRetryE(t testing.TestingT, host Host, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host) error) error { 219 handler := CheckSshConnectionE 220 if f != nil { 221 handler = f[0] 222 } 223 _, err := retry.DoWithRetryE(t, fmt.Sprintf("Checking SSH connection to %s", host.Hostname), retries, sleepBetweenRetries, func() (string, error) { 224 return "", handler(t, host) 225 }) 226 227 return err 228 } 229 230 // CheckSshCommand checks that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr. 231 func CheckSshCommand(t testing.TestingT, host Host, command string) string { 232 out, err := CheckSshCommandE(t, host, command) 233 if err != nil { 234 t.Fatal(err) 235 } 236 return out 237 } 238 239 // CheckSshCommandE checks that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr. 
240 func CheckSshCommandE(t testing.TestingT, host Host, command string) (string, error) { 241 authMethods, err := createAuthMethodsForHost(host) 242 if err != nil { 243 return "", err 244 } 245 246 hostOptions := SshConnectionOptions{ 247 Username: host.SshUserName, 248 Address: host.Hostname, 249 Port: host.getPort(), 250 Command: command, 251 AuthMethods: authMethods, 252 } 253 254 sshSession := &SshSession{ 255 Options: &hostOptions, 256 JumpHost: &JumpHostSession{}, 257 } 258 259 defer sshSession.Cleanup(t) 260 261 return runSSHCommand(t, sshSession) 262 } 263 264 // CheckSshCommandWithRetry checks that you can connect via SSH to the given host and run the given command until max retries have been exceeded. Returns the stdout/stderr. 265 func CheckSshCommandWithRetry(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) string { 266 handler := CheckSshCommandE 267 if f != nil { 268 handler = f[0] 269 } 270 out, err := CheckSshCommandWithRetryE(t, host, command, retries, sleepBetweenRetries, handler) 271 if err != nil { 272 t.Fatal(err) 273 } 274 return out 275 } 276 277 // CheckSshCommandWithRetryE checks that you can connect via SSH to the given host and run the given command until max retries has been exceeded. 278 // It return an error if the command fails after max retries has been exceeded. 
279 280 func CheckSshCommandWithRetryE(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) (string, error) { 281 handler := CheckSshCommandE 282 if f != nil { 283 handler = f[0] 284 } 285 return retry.DoWithRetryE(t, fmt.Sprintf("Checking SSH connection to %s", host.Hostname), retries, sleepBetweenRetries, func() (string, error) { 286 return handler(t, host, command) 287 }) 288 } 289 290 // CheckPrivateSshConnection attempts to connect to privateHost (which is not addressable from the Internet) via a 291 // separate publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns 292 // its output. It is useful for checking that it's possible to SSH from a Bastion Host to a private instance. 293 func CheckPrivateSshConnection(t testing.TestingT, publicHost Host, privateHost Host, command string) string { 294 out, err := CheckPrivateSshConnectionE(t, publicHost, privateHost, command) 295 if err != nil { 296 t.Fatal(err) 297 } 298 return out 299 } 300 301 // CheckPrivateSshConnectionE attempts to connect to privateHost (which is not addressable from the Internet) via a 302 // separate publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns 303 // its output. It is useful for checking that it's possible to SSH from a Bastion Host to a private instance. 
304 func CheckPrivateSshConnectionE(t testing.TestingT, publicHost Host, privateHost Host, command string) (string, error) { 305 jumpHostAuthMethods, err := createAuthMethodsForHost(publicHost) 306 if err != nil { 307 return "", err 308 } 309 310 jumpHostOptions := SshConnectionOptions{ 311 Username: publicHost.SshUserName, 312 Address: publicHost.Hostname, 313 Port: publicHost.getPort(), 314 AuthMethods: jumpHostAuthMethods, 315 } 316 317 hostAuthMethods, err := createAuthMethodsForHost(privateHost) 318 if err != nil { 319 return "", err 320 } 321 322 hostOptions := SshConnectionOptions{ 323 Username: privateHost.SshUserName, 324 Address: privateHost.Hostname, 325 Port: privateHost.getPort(), 326 Command: command, 327 AuthMethods: hostAuthMethods, 328 JumpHost: &jumpHostOptions, 329 } 330 331 sshSession := &SshSession{ 332 Options: &hostOptions, 333 JumpHost: &JumpHostSession{}, 334 } 335 336 defer sshSession.Cleanup(t) 337 338 return runSSHCommand(t, sshSession) 339 } 340 341 // FetchContentsOfFiles connects to the given host via SSH and fetches the contents of the files at the given filePaths. 342 // If useSudo is true, then the contents will be retrieved using sudo. This method returns a map from file path to 343 // contents. 344 func FetchContentsOfFiles(t testing.TestingT, host Host, useSudo bool, filePaths ...string) map[string]string { 345 out, err := FetchContentsOfFilesE(t, host, useSudo, filePaths...) 346 if err != nil { 347 t.Fatal(err) 348 } 349 return out 350 } 351 352 // FetchContentsOfFilesE connects to the given host via SSH and fetches the contents of the files at the given filePaths. 353 // If useSudo is true, then the contents will be retrieved using sudo. This method returns a map from file path to 354 // contents. 
355 func FetchContentsOfFilesE(t testing.TestingT, host Host, useSudo bool, filePaths ...string) (map[string]string, error) { 356 filePathToContents := map[string]string{} 357 358 for _, filePath := range filePaths { 359 contents, err := FetchContentsOfFileE(t, host, useSudo, filePath) 360 if err != nil { 361 return nil, err 362 } 363 364 filePathToContents[filePath] = contents 365 } 366 367 return filePathToContents, nil 368 } 369 370 // FetchContentsOfFile connects to the given host via SSH and fetches the contents of the file at the given filePath. 371 // If useSudo is true, then the contents will be retrieved using sudo. This method returns the contents of that file. 372 func FetchContentsOfFile(t testing.TestingT, host Host, useSudo bool, filePath string) string { 373 out, err := FetchContentsOfFileE(t, host, useSudo, filePath) 374 if err != nil { 375 t.Fatal(err) 376 } 377 return out 378 } 379 380 // FetchContentsOfFileE connects to the given host via SSH and fetches the contents of the file at the given filePath. 381 // If useSudo is true, then the contents will be retrieved using sudo. This method returns the contents of that file. 
382 func FetchContentsOfFileE(t testing.TestingT, host Host, useSudo bool, filePath string) (string, error) { 383 command := fmt.Sprintf("cat %s", filePath) 384 if useSudo { 385 command = fmt.Sprintf("sudo %s", command) 386 } 387 388 return CheckSshCommandE(t, host, command) 389 } 390 391 func listFileInRemoteDir(t testing.TestingT, sshSession *SshSession, options ScpDownloadOptions, useSudo bool) ([]string, error) { 392 logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address) 393 394 var result []string 395 var findCommandArgs []string 396 397 if useSudo { 398 findCommandArgs = append(findCommandArgs, "sudo") 399 } 400 401 findCommandArgs = append(findCommandArgs, "find", options.RemoteDir) 402 findCommandArgs = append(findCommandArgs, "-type", "f") 403 404 filtersLength := len(options.FileNameFilters) 405 if options.FileNameFilters != nil && filtersLength > 0 { 406 407 findCommandArgs = append(findCommandArgs, "\\(") 408 for i, curFilter := range options.FileNameFilters { 409 // due to inconsistent bash behavior we need to wrap the 410 // filter in single quotes 411 curFilter = fmt.Sprintf("'%s'", curFilter) 412 findCommandArgs = append(findCommandArgs, "-name", curFilter) 413 414 // only add the or flag if we're not the last element 415 if filtersLength-i > 1 { 416 findCommandArgs = append(findCommandArgs, "-o") 417 } 418 } 419 findCommandArgs = append(findCommandArgs, "\\)") 420 } 421 422 if options.MaxFileSizeMB != 0 { 423 findCommandArgs = append(findCommandArgs, "-size", fmt.Sprintf("-%dM", options.MaxFileSizeMB)) 424 } 425 426 finalCommandString := strings.Join(findCommandArgs, " ") 427 resultString, err := CheckSshCommandE(t, options.RemoteHost, finalCommandString) 428 429 if err != nil { 430 return result, err 431 } 432 433 // The last character returned is `\n` this results in an extra "" array 434 // member when we do the split below. 
Cut off the last character to avoid 435 // having to remove the blank entry in the array. 436 resultString = resultString[:len(resultString)-1] 437 438 result = append(result, strings.Split(resultString, "\n")...) 439 return result, nil 440 } 441 442 // Added based on code: https://github.com/bramvdbogaerde/go-scp/pull/6/files 443 func copyFileFromRemote(t testing.TestingT, sshSession *SshSession, file *os.File, remotePath string, useSudo bool) error { 444 logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address) 445 if err := setUpSSHClient(sshSession); err != nil { 446 return err 447 } 448 449 if err := setUpSSHSession(sshSession); err != nil { 450 return err 451 } 452 453 command := fmt.Sprintf("dd if=%s", remotePath) 454 if useSudo { 455 command = fmt.Sprintf("sudo %s", command) 456 } 457 458 r, err := sshSession.Session.Output(command) 459 if err != nil { 460 fmt.Printf("error reading from remote stdout: %s", err) 461 } 462 defer sshSession.Session.Close() 463 //write to local file 464 _, err = file.Write(r) 465 466 return err 467 } 468 469 func runSSHCommand(t testing.TestingT, sshSession *SshSession) (string, error) { 470 logger.Logf(t, "Running command %s on %s@%s", sshSession.Options.Command, sshSession.Options.Username, sshSession.Options.Address) 471 if err := setUpSSHClient(sshSession); err != nil { 472 return "", err 473 } 474 475 if err := setUpSSHSession(sshSession); err != nil { 476 return "", err 477 } 478 479 if sshSession.Input != nil { 480 w, err := sshSession.Session.StdinPipe() 481 if err != nil { 482 return "", err 483 } 484 go func() { 485 defer w.Close() 486 (*sshSession.Input)(w) 487 }() 488 } 489 490 bytes, err := sshSession.Session.CombinedOutput(sshSession.Options.Command) 491 if err != nil { 492 return string(bytes), err 493 } 494 495 return string(bytes), nil 496 } 497 498 func setUpSSHClient(sshSession *SshSession) error { 499 if sshSession.Options.JumpHost == 
nil { 500 return fillSSHClientForHost(sshSession) 501 } 502 return fillSSHClientForJumpHost(sshSession) 503 } 504 505 func fillSSHClientForHost(sshSession *SshSession) error { 506 client, err := createSSHClient(sshSession.Options) 507 508 if err != nil { 509 return err 510 } 511 512 sshSession.Client = client 513 return nil 514 } 515 516 func fillSSHClientForJumpHost(sshSession *SshSession) error { 517 jumpHostClient, err := createSSHClient(sshSession.Options.JumpHost) 518 if err != nil { 519 return err 520 } 521 sshSession.JumpHost.JumpHostClient = jumpHostClient 522 523 hostVirtualConn, err := jumpHostClient.Dial("tcp", sshSession.Options.ConnectionString()) 524 if err != nil { 525 return err 526 } 527 sshSession.JumpHost.HostVirtualConnection = hostVirtualConn 528 529 hostConn, hostIncomingChannels, hostIncomingRequests, err := ssh.NewClientConn(hostVirtualConn, sshSession.Options.ConnectionString(), createSSHClientConfig(sshSession.Options)) 530 if err != nil { 531 return err 532 } 533 sshSession.JumpHost.HostConnection = hostConn 534 535 sshSession.Client = ssh.NewClient(hostConn, hostIncomingChannels, hostIncomingRequests) 536 return nil 537 } 538 539 func setUpSSHSession(sshSession *SshSession) error { 540 session, err := sshSession.Client.NewSession() 541 if err != nil { 542 return err 543 } 544 545 sshSession.Session = session 546 return nil 547 } 548 549 func createSSHClient(options *SshConnectionOptions) (*ssh.Client, error) { 550 sshClientConfig := createSSHClientConfig(options) 551 return ssh.Dial("tcp", options.ConnectionString(), sshClientConfig) 552 } 553 554 func createSSHClientConfig(hostOptions *SshConnectionOptions) *ssh.ClientConfig { 555 clientConfig := &ssh.ClientConfig{ 556 User: hostOptions.Username, 557 Auth: hostOptions.AuthMethods, 558 // Do not do a host key check, as Terratest is only used for testing, not prod 559 HostKeyCallback: NoOpHostKeyCallback, 560 // By default, Go does not impose a timeout, so a SSH connection attempt can 
hang for a LONG time. 561 Timeout: 10 * time.Second, 562 } 563 clientConfig.SetDefaults() 564 return clientConfig 565 } 566 567 // NoOpHostKeyCallback is an ssh.HostKeyCallback that does nothing. Only use this when you're sure you don't want to check the host key at all 568 // (e.g., only for testing and non-production use cases). 569 func NoOpHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { 570 return nil 571 } 572 573 // Returns an array of authentication methods 574 func createAuthMethodsForHost(host Host) ([]ssh.AuthMethod, error) { 575 var methods []ssh.AuthMethod 576 577 // override local ssh agent with given sshAgent instance 578 if host.OverrideSshAgent != nil { 579 conn, err := net.Dial("unix", host.OverrideSshAgent.socketFile) 580 if err != nil { 581 fmt.Print("Failed to dial in memory ssh agent") 582 return methods, err 583 } 584 agentClient := agent.NewClient(conn) 585 methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...) 586 } 587 588 // use existing ssh agent socket 589 // if agent authentication is enabled and no agent is set up, returns an error 590 if host.SshAgent { 591 socket := os.Getenv("SSH_AUTH_SOCK") 592 conn, err := net.Dial("unix", socket) 593 if err != nil { 594 return methods, err 595 } 596 agentClient := agent.NewClient(conn) 597 methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...) 598 } 599 600 // use provided ssh key pair 601 if host.SshKeyPair != nil { 602 signer, err := ssh.ParsePrivateKey([]byte(host.SshKeyPair.PrivateKey)) 603 if err != nil { 604 return methods, err 605 } 606 methods = append(methods, []ssh.AuthMethod{ssh.PublicKeys(signer)}...) 607 } 608 609 // Use given password 610 if len(host.Password) > 0 { 611 methods = append(methods, []ssh.AuthMethod{ssh.Password(host.Password)}...) 
612 } 613 614 // no valid authentication method was provided 615 if len(methods) < 1 { 616 return methods, errors.New("no authentication method defined") 617 } 618 619 return methods, nil 620 } 621 622 // sendScpCommandsToCopyFile returns a function which will send commands to the SCP binary to output a file on the remote machine. 623 // A full explanation of the SCP protocol can be found at 624 // https://web.archive.org/web/20170215184048/https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works 625 func sendScpCommandsToCopyFile(mode os.FileMode, fileName, contents string) func(io.WriteCloser) { 626 return func(input io.WriteCloser) { 627 628 octalMode := "0" + strconv.FormatInt(int64(mode), 8) 629 630 // Create a file at <filename> with Unix permissions set to <octalMost> and the file will be <len(content)> bytes long. 631 fmt.Fprintln(input, "C"+octalMode, len(contents), fileName) 632 633 // Actually send the file 634 fmt.Fprint(input, contents) 635 636 // End of transfer 637 fmt.Fprint(input, "\x00") 638 } 639 } 640 641 // Gets the port that should be used to communicate with the host 642 func (h Host) getPort() int { 643 644 //If a CustomPort is not set use standard ssh port 645 if h.CustomPort == 0 { 646 return 22 647 } else { 648 return h.CustomPort 649 } 650 }