github.com/kjmkznr/terraform@v0.5.2-0.20180216194316-1d0f5fdac99e/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "math/rand" 12 "net" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "sync" 18 "time" 19 20 "github.com/hashicorp/terraform/communicator/remote" 21 "github.com/hashicorp/terraform/terraform" 22 "golang.org/x/crypto/ssh" 23 "golang.org/x/crypto/ssh/agent" 24 ) 25 26 const ( 27 // DefaultShebang is added at the top of a SSH script file 28 DefaultShebang = "#!/bin/sh\n" 29 ) 30 31 // randShared is a global random generator object that is shared. 32 // This must be shared since it is seeded by the current time and 33 // creating multiple can result in the same values. By using a shared 34 // RNG we assure different numbers per call. 35 var randLock sync.Mutex 36 var randShared *rand.Rand 37 38 // Communicator represents the SSH communicator 39 type Communicator struct { 40 connInfo *connectionInfo 41 client *ssh.Client 42 config *sshConfig 43 conn net.Conn 44 address string 45 46 lock sync.Mutex 47 } 48 49 type sshConfig struct { 50 // The configuration of the Go SSH connection 51 config *ssh.ClientConfig 52 53 // connection returns a new connection. The current connection 54 // in use will be closed as part of the Close method, or in the 55 // case an error occurs. 56 connection func() (net.Conn, error) 57 58 // noPty, if true, will not request a pty from the remote end. 59 noPty bool 60 61 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 62 // to the SSH Agent. It is nil if no SSH agent is configured 63 sshAgent *sshAgent 64 } 65 66 type fatalError struct { 67 error 68 } 69 70 func (e fatalError) FatalError() error { 71 return e.error 72 } 73 74 // New creates a new communicator implementation over SSH. 75 func New(s *terraform.InstanceState) (*Communicator, error) { 76 connInfo, err := parseConnectionInfo(s) 77 if err != nil { 78 return nil, err 79 } 80 81 config, err := prepareSSHConfig(connInfo) 82 if err != nil { 83 return nil, err 84 } 85 86 // Setup the random number generator once. The seed value is the 87 // time multiplied by the PID. This can overflow the int64 but that 88 // is okay. We multiply by the PID in case we have multiple processes 89 // grabbing this at the same time. This is possible with Terraform and 90 // if we communicate to the same host at the same instance, we could 91 // overwrite the same files. Multiplying by the PID prevents this. 92 randLock.Lock() 93 defer randLock.Unlock() 94 if randShared == nil { 95 randShared = rand.New(rand.NewSource( 96 time.Now().UnixNano() * int64(os.Getpid()))) 97 } 98 99 comm := &Communicator{ 100 connInfo: connInfo, 101 config: config, 102 } 103 104 return comm, nil 105 } 106 107 // Connect implementation of communicator.Communicator interface 108 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 109 // Grab a lock so we can modify our internal attributes 110 c.lock.Lock() 111 defer c.lock.Unlock() 112 113 if c.conn != nil { 114 c.conn.Close() 115 } 116 117 // Set the conn and client to nil since we'll recreate it 118 c.conn = nil 119 c.client = nil 120 121 if o != nil { 122 o.Output(fmt.Sprintf( 123 "Connecting to remote host via SSH...\n"+ 124 " Host: %s\n"+ 125 " User: %s\n"+ 126 " Password: %t\n"+ 127 " Private key: %t\n"+ 128 " SSH Agent: %t\n"+ 129 " Checking Host Key: %t", 130 c.connInfo.Host, c.connInfo.User, 131 c.connInfo.Password != "", 132 c.connInfo.PrivateKey != "", 133 c.connInfo.Agent, 134 c.connInfo.HostKey != "", 135 )) 136 137 if c.connInfo.BastionHost != "" { 138 o.Output(fmt.Sprintf( 139 "Using configured bastion host...\n"+ 140 " Host: %s\n"+ 141 " User: %s\n"+ 142 " Password: %t\n"+ 143 " Private key: %t\n"+ 144 " SSH Agent: %t\n"+ 145 " Checking Host Key: %t", 146 c.connInfo.BastionHost, c.connInfo.BastionUser, 147 c.connInfo.BastionPassword != "", 148 c.connInfo.BastionPrivateKey != "", 149 c.connInfo.Agent, 150 c.connInfo.BastionHostKey != "", 151 )) 152 } 153 } 154 155 log.Printf("connecting to TCP connection for SSH") 156 c.conn, err = c.config.connection() 157 if err != nil { 158 // Explicitly set this to the REAL nil. Connection() can return 159 // a nil implementation of net.Conn which will make the 160 // "if c.conn == nil" check fail above. Read here for more information 161 // on this psychotic language feature: 162 // 163 // http://golang.org/doc/faq#nil_error 164 c.conn = nil 165 166 log.Printf("connection error: %s", err) 167 return err 168 } 169 170 log.Printf("handshaking with SSH") 171 host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 172 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) 173 if err != nil { 174 log.Printf("fatal handshake error: %s", err) 175 return fatalError{err} 176 } 177 178 c.client = ssh.NewClient(sshConn, sshChan, req) 179 180 if c.config.sshAgent != nil { 181 log.Printf("[DEBUG] Telling SSH config to forward to agent") 182 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 183 return fatalError{err} 184 } 185 186 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 187 session, err := c.newSession() 188 if err != nil { 189 return err 190 } 191 defer session.Close() 192 193 err = agent.RequestAgentForwarding(session) 194 195 if err == nil { 196 log.Printf("[INFO] agent forwarding enabled") 197 } else { 198 log.Printf("[WARN] error forwarding agent: %s", err) 199 } 200 } 201 202 if o != nil { 203 o.Output("Connected!") 204 } 205 206 return err 207 } 208 209 // Disconnect implementation of communicator.Communicator interface 210 func (c *Communicator) Disconnect() error { 211 c.lock.Lock() 212 defer c.lock.Unlock() 213 214 if c.config.sshAgent != nil { 215 if err := c.config.sshAgent.Close(); err != nil { 216 return err 217 } 218 } 219 220 if c.conn != nil { 221 conn := c.conn 222 c.conn = nil 223 return conn.Close() 224 } 225 226 return nil 227 } 228 229 // Timeout implementation of communicator.Communicator interface 230 func (c *Communicator) Timeout() time.Duration { 231 return c.connInfo.TimeoutVal 232 } 233 234 // ScriptPath implementation of communicator.Communicator interface 235 func (c *Communicator) ScriptPath() string { 236 randLock.Lock() 237 defer randLock.Unlock() 238 239 return strings.Replace( 240 c.connInfo.ScriptPath, "%RAND%", 241 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 242 } 243 244 // Start implementation of communicator.Communicator interface 245 func (c *Communicator) Start(cmd *remote.Cmd) error { 246 session, err := c.newSession() 247 if err != nil { 248 return err 249 } 250 251 // Setup our session 252 session.Stdin = cmd.Stdin 253 session.Stdout = cmd.Stdout 254 session.Stderr = cmd.Stderr 255 256 if !c.config.noPty { 257 // Request a PTY 258 termModes := ssh.TerminalModes{ 259 ssh.ECHO: 0, // do not echo 260 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 261 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 262 } 263 264 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 265 return err 266 } 267 } 268 269 log.Printf("starting remote command: %s", cmd.Command) 270 err = session.Start(cmd.Command + "\n") 271 if err != nil { 272 return err 273 } 274 275 // Start a goroutine to wait for the session to end and set the 276 // exit boolean and status. 277 go func() { 278 defer session.Close() 279 280 err := session.Wait() 281 exitStatus := 0 282 if err != nil { 283 exitErr, ok := err.(*ssh.ExitError) 284 if ok { 285 exitStatus = exitErr.ExitStatus() 286 } 287 } 288 289 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 290 cmd.SetExited(exitStatus) 291 }() 292 293 return nil 294 } 295 296 // Upload implementation of communicator.Communicator interface 297 func (c *Communicator) Upload(path string, input io.Reader) error { 298 // The target directory and file for talking the SCP protocol 299 targetDir := filepath.Dir(path) 300 targetFile := filepath.Base(path) 301 302 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 303 // This does not work when the target host is unix. Switch to forward slash 304 // which works for unix and windows 305 targetDir = filepath.ToSlash(targetDir) 306 307 // Skip copying if we can get the file size directly from common io.Readers 308 size := int64(0) 309 310 switch src := input.(type) { 311 case *os.File: 312 fi, err := src.Stat() 313 if err != nil { 314 size = fi.Size() 315 } 316 case *bytes.Buffer: 317 size = int64(src.Len()) 318 case *bytes.Reader: 319 size = int64(src.Len()) 320 case *strings.Reader: 321 size = int64(src.Len()) 322 } 323 324 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 325 return scpUploadFile(targetFile, input, w, stdoutR, size) 326 } 327 328 return c.scpSession("scp -vt "+targetDir, scpFunc) 329 } 330 331 // UploadScript implementation of communicator.Communicator interface 332 func (c *Communicator) UploadScript(path string, input io.Reader) error { 333 reader := bufio.NewReader(input) 334 prefix, err := reader.Peek(2) 335 if err != nil { 336 return fmt.Errorf("Error reading script: %s", err) 337 } 338 339 var script bytes.Buffer 340 if string(prefix) != "#!" { 341 script.WriteString(DefaultShebang) 342 } 343 344 script.ReadFrom(reader) 345 if err := c.Upload(path, &script); err != nil { 346 return err 347 } 348 349 var stdout, stderr bytes.Buffer 350 cmd := &remote.Cmd{ 351 Command: fmt.Sprintf("chmod 0777 %s", path), 352 Stdout: &stdout, 353 Stderr: &stderr, 354 } 355 if err := c.Start(cmd); err != nil { 356 return fmt.Errorf( 357 "Error chmodding script file to 0777 in remote "+ 358 "machine: %s", err) 359 } 360 cmd.Wait() 361 if cmd.ExitStatus != 0 { 362 return fmt.Errorf( 363 "Error chmodding script file to 0777 in remote "+ 364 "machine %d: %s %s", cmd.ExitStatus, stdout.String(), stderr.String()) 365 } 366 367 return nil 368 } 369 370 // UploadDir implementation of communicator.Communicator interface 371 func (c *Communicator) UploadDir(dst string, src string) error { 372 log.Printf("Uploading dir '%s' to '%s'", src, dst) 373 scpFunc := func(w io.Writer, r *bufio.Reader) error { 374 uploadEntries := func() error { 375 f, err := os.Open(src) 376 if err != nil { 377 return err 378 } 379 defer f.Close() 380 381 entries, err := f.Readdir(-1) 382 if err != nil { 383 return err 384 } 385 386 return scpUploadDir(src, entries, w, r) 387 } 388 389 if src[len(src)-1] != '/' { 390 log.Printf("No trailing slash, creating the source directory name") 391 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 392 } 393 // Trailing slash, so only upload the contents 394 return uploadEntries() 395 } 396 397 return c.scpSession("scp -rvt "+dst, scpFunc) 398 } 399 400 func (c *Communicator) newSession() (session *ssh.Session, err error) { 401 log.Println("opening new ssh session") 402 if c.client == nil { 403 err = errors.New("client not available") 404 } else { 405 session, err = c.client.NewSession() 406 } 407 408 if err != nil { 409 log.Printf("ssh session open error: '%s', attempting reconnect", err) 410 if err := c.Connect(nil); err != nil { 411 return nil, err 412 } 413 414 return c.client.NewSession() 415 } 416 417 return session, nil 418 } 419 420 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 421 session, err := c.newSession() 422 if err != nil { 423 return err 424 } 425 defer session.Close() 426 427 // Get a pipe to stdin so that we can send data down 428 stdinW, err := session.StdinPipe() 429 if err != nil { 430 return err 431 } 432 433 // We only want to close once, so we nil w after we close it, 434 // and only close in the defer if it hasn't been closed already. 435 defer func() { 436 if stdinW != nil { 437 stdinW.Close() 438 } 439 }() 440 441 // Get a pipe to stdout so that we can get responses back 442 stdoutPipe, err := session.StdoutPipe() 443 if err != nil { 444 return err 445 } 446 stdoutR := bufio.NewReader(stdoutPipe) 447 448 // Set stderr to a bytes buffer 449 stderr := new(bytes.Buffer) 450 session.Stderr = stderr 451 452 // Start the sink mode on the other side 453 // TODO(mitchellh): There are probably issues with shell escaping the path 454 log.Println("Starting remote scp process: ", scpCommand) 455 if err := session.Start(scpCommand); err != nil { 456 return err 457 } 458 459 // Call our callback that executes in the context of SCP. We ignore 460 // EOF errors if they occur because it usually means that SCP prematurely 461 // ended on the other side. 462 log.Println("Started SCP session, beginning transfers...") 463 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 464 return err 465 } 466 467 // Close the stdin, which sends an EOF, and then set w to nil so that 468 // our defer func doesn't close it again since that is unsafe with 469 // the Go SSH package. 470 log.Println("SCP session complete, closing stdin pipe.") 471 stdinW.Close() 472 stdinW = nil 473 474 // Wait for the SCP connection to close, meaning it has consumed all 475 // our data and has completed. Or has errored. 476 log.Println("Waiting for SSH session to complete.") 477 err = session.Wait() 478 if err != nil { 479 if exitErr, ok := err.(*ssh.ExitError); ok { 480 // Otherwise, we have an ExitErorr, meaning we can just read 481 // the exit status 482 log.Printf(exitErr.String()) 483 484 // If we exited with status 127, it means SCP isn't available. 485 // Return a more descriptive error for that. 486 if exitErr.ExitStatus() == 127 { 487 return errors.New( 488 "SCP failed to start. This usually means that SCP is not\n" + 489 "properly installed on the remote system.") 490 } 491 } 492 493 return err 494 } 495 496 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 497 return nil 498 } 499 500 // checkSCPStatus checks that a prior command sent to SCP completed 501 // successfully. If it did not complete successfully, an error will 502 // be returned. 503 func checkSCPStatus(r *bufio.Reader) error { 504 code, err := r.ReadByte() 505 if err != nil { 506 return err 507 } 508 509 if code != 0 { 510 // Treat any non-zero (really 1 and 2) as fatal errors 511 message, _, err := r.ReadLine() 512 if err != nil { 513 return fmt.Errorf("Error reading error message: %s", err) 514 } 515 516 return errors.New(string(message)) 517 } 518 519 return nil 520 } 521 522 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 523 if size == 0 { 524 // Create a temporary file where we can copy the contents of the src 525 // so that we can determine the length, since SCP is length-prefixed. 526 tf, err := ioutil.TempFile("", "terraform-upload") 527 if err != nil { 528 return fmt.Errorf("Error creating temporary file for upload: %s", err) 529 } 530 defer os.Remove(tf.Name()) 531 defer tf.Close() 532 533 log.Println("Copying input data into temporary file so we can read the length") 534 if _, err := io.Copy(tf, src); err != nil { 535 return err 536 } 537 538 // Sync the file so that the contents are definitely on disk, then 539 // read the length of it. 540 if err := tf.Sync(); err != nil { 541 return fmt.Errorf("Error creating temporary file for upload: %s", err) 542 } 543 544 // Seek the file to the beginning so we can re-read all of it 545 if _, err := tf.Seek(0, 0); err != nil { 546 return fmt.Errorf("Error creating temporary file for upload: %s", err) 547 } 548 549 fi, err := tf.Stat() 550 if err != nil { 551 return fmt.Errorf("Error creating temporary file for upload: %s", err) 552 } 553 554 src = tf 555 size = fi.Size() 556 } 557 558 // Start the protocol 559 log.Println("Beginning file upload...") 560 fmt.Fprintln(w, "C0644", size, dst) 561 if err := checkSCPStatus(r); err != nil { 562 return err 563 } 564 565 if _, err := io.Copy(w, src); err != nil { 566 return err 567 } 568 569 fmt.Fprint(w, "\x00") 570 if err := checkSCPStatus(r); err != nil { 571 return err 572 } 573 574 return nil 575 } 576 577 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 578 log.Printf("SCP: starting directory upload: %s", name) 579 fmt.Fprintln(w, "D0755 0", name) 580 err := checkSCPStatus(r) 581 if err != nil { 582 return err 583 } 584 585 if err := f(); err != nil { 586 return err 587 } 588 589 fmt.Fprintln(w, "E") 590 if err != nil { 591 return err 592 } 593 594 return nil 595 } 596 597 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 598 for _, fi := range fs { 599 realPath := filepath.Join(root, fi.Name()) 600 601 // Track if this is actually a symlink to a directory. If it is 602 // a symlink to a file we don't do any special behavior because uploading 603 // a file just works. If it is a directory, we need to know so we 604 // treat it as such. 605 isSymlinkToDir := false 606 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 607 symPath, err := filepath.EvalSymlinks(realPath) 608 if err != nil { 609 return err 610 } 611 612 symFi, err := os.Lstat(symPath) 613 if err != nil { 614 return err 615 } 616 617 isSymlinkToDir = symFi.IsDir() 618 } 619 620 if !fi.IsDir() && !isSymlinkToDir { 621 // It is a regular file (or symlink to a file), just upload it 622 f, err := os.Open(realPath) 623 if err != nil { 624 return err 625 } 626 627 err = func() error { 628 defer f.Close() 629 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 630 }() 631 632 if err != nil { 633 return err 634 } 635 636 continue 637 } 638 639 // It is a directory, recursively upload 640 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 641 f, err := os.Open(realPath) 642 if err != nil { 643 return err 644 } 645 defer f.Close() 646 647 entries, err := f.Readdir(-1) 648 if err != nil { 649 return err 650 } 651 652 return scpUploadDir(realPath, entries, w, r) 653 }) 654 if err != nil { 655 return err 656 } 657 } 658 659 return nil 660 } 661 662 // ConnectFunc is a convenience method for returning a function 663 // that just uses net.Dial to communicate with the remote end that 664 // is suitable for use with the SSH communicator configuration. 665 func ConnectFunc(network, addr string) func() (net.Conn, error) { 666 return func() (net.Conn, error) { 667 c, err := net.DialTimeout(network, addr, 15*time.Second) 668 if err != nil { 669 return nil, err 670 } 671 672 if tcpConn, ok := c.(*net.TCPConn); ok { 673 tcpConn.SetKeepAlive(true) 674 } 675 676 return c, nil 677 } 678 } 679 680 // BastionConnectFunc is a convenience method for returning a function 681 // that connects to a host over a bastion connection. 682 func BastionConnectFunc( 683 bProto string, 684 bAddr string, 685 bConf *ssh.ClientConfig, 686 proto string, 687 addr string) func() (net.Conn, error) { 688 return func() (net.Conn, error) { 689 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 690 bastion, err := ssh.Dial(bProto, bAddr, bConf) 691 if err != nil { 692 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 693 } 694 695 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 696 conn, err := bastion.Dial(proto, addr) 697 if err != nil { 698 bastion.Close() 699 return nil, err 700 } 701 702 // Wrap it up so we close both things properly 703 return &bastionConn{ 704 Conn: conn, 705 Bastion: bastion, 706 }, nil 707 } 708 } 709 710 type bastionConn struct { 711 net.Conn 712 Bastion *ssh.Client 713 } 714 715 func (c *bastionConn) Close() error { 716 c.Conn.Close() 717 return c.Bastion.Close() 718 }