github.com/mfpierre/corectl@v0.5.6/ssh.go (about)

     1  // Copyright 2015 - António Meireles  <antonio.meireles@reformi.st>
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package main
    17  
    18  import (
    19  	"fmt"
    20  	"io"
    21  	"log"
    22  	"os"
    23  	"path/filepath"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/pkg/sftp"
    28  	"github.com/rakyll/pb"
    29  	"github.com/spf13/cobra"
    30  	"golang.org/x/crypto/ssh"
    31  	"golang.org/x/crypto/ssh/terminal"
    32  )
    33  
    34  var (
    35  	sshCmd = &cobra.Command{
    36  		Use:     "ssh VMid [\"command1;...\"]",
    37  		Aliases: []string{"attach"},
    38  		Short:   "Attach to or run commands inside a running CoreOS instance",
    39  		PreRunE: func(cmd *cobra.Command, args []string) (err error) {
    40  			engine.rawArgs.BindPFlags(cmd.Flags())
    41  			if len(args) < 1 {
    42  				return fmt.Errorf("This command requires at least " +
    43  					"one argument to work ")
    44  			}
    45  			return
    46  		},
    47  		RunE: sshCommand,
    48  		Example: `  corectl ssh VMid                 // logins into VMid
    49    corectl ssh VMid "some commands" // runs 'some commands' inside VMid and exits`,
    50  	}
    51  	scpCmd = &cobra.Command{
    52  		Use:     "put path/to/file VMid:/file/path/on/destination",
    53  		Aliases: []string{"copy", "cp", "scp"},
    54  		Short:   "copy file to inside VM",
    55  		PreRunE: func(cmd *cobra.Command, args []string) (err error) {
    56  			engine.rawArgs.BindPFlags(cmd.Flags())
    57  			if len(args) < 2 {
    58  				return fmt.Errorf("This command requires at least " +
    59  					"two argument to work ")
    60  			}
    61  			return
    62  		},
    63  		RunE: scpCommand,
    64  		Example: `  // copies 'filePath' into '/destinationPath' inside VMid
    65    corectl put filePath VMid:/destinationPath`,
    66  	}
    67  )
    68  
    69  func sshCommand(cmd *cobra.Command, args []string) (err error) {
    70  	var sshSession = &sshClient{}
    71  	vm := VMInfo{}
    72  
    73  	if vm, err = vmInfo(args[0]); err != nil {
    74  		return
    75  	}
    76  
    77  	if sshSession, err = vm.startSSHsession(); err != nil {
    78  		return
    79  	}
    80  	defer sshSession.close()
    81  
    82  	if len(args) > 1 {
    83  		return sshSession.executeRemoteCommand(strings.Join(args[1:], " "))
    84  	}
    85  	return sshSession.remoteShell()
    86  }
    87  
    88  type sshClient struct {
    89  	session                   *ssh.Session
    90  	conn                      *ssh.Client
    91  	oldState                  *terminal.State
    92  	termWidth, termHeight, fd int
    93  }
    94  
    95  func (c *sshClient) close() {
    96  	c.conn.Close()
    97  	c.session.Close()
    98  	terminal.Restore(c.fd, c.oldState)
    99  }
   100  
   101  func (vm VMInfo) startSSHsession() (c *sshClient, err error) {
   102  	var secret ssh.Signer
   103  	c = &sshClient{}
   104  
   105  	if secret, err = ssh.ParsePrivateKey(
   106  		[]byte(vm.InternalSSHprivKey)); err != nil {
   107  		return
   108  	}
   109  
   110  	config := &ssh.ClientConfig{
   111  		User: "core", Auth: []ssh.AuthMethod{
   112  			ssh.PublicKeys(secret),
   113  		},
   114  	}
   115  
   116  	//wait a bit for VM's ssh to be available...
   117  	for {
   118  		var e error
   119  		if c.conn, e = ssh.Dial("tcp", vm.PublicIP+":22", config); e == nil {
   120  			break
   121  		}
   122  		time.Sleep(100 * time.Millisecond)
   123  		select {
   124  		case <-time.After(time.Second * 5):
   125  			return c, fmt.Errorf("%s unreachable", vm.PublicIP+":22")
   126  		}
   127  	}
   128  
   129  	if c.session, err = c.conn.NewSession(); err != nil {
   130  		return c, fmt.Errorf("unable to create session: %s", err)
   131  	}
   132  
   133  	c.fd = int(os.Stdin.Fd())
   134  	if c.oldState, err = terminal.MakeRaw(c.fd); err != nil {
   135  		return
   136  	}
   137  
   138  	c.session.Stdout, c.session.Stderr, c.session.Stdin =
   139  		os.Stdout, os.Stderr, os.Stdin
   140  
   141  	if c.termWidth, c.termHeight, err = terminal.GetSize(c.fd); err != nil {
   142  		return
   143  	}
   144  
   145  	modes := ssh.TerminalModes{
   146  		ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400,
   147  	}
   148  
   149  	if err = c.session.RequestPty("xterm-256color",
   150  		c.termHeight, c.termWidth, modes); err != nil {
   151  		return c, fmt.Errorf("request for pseudo terminal failed: %s", err)
   152  	}
   153  	return
   154  }
   155  
   156  func (c *sshClient) executeRemoteCommand(run string) (err error) {
   157  	if err = c.session.Run(run); err != nil && !strings.HasSuffix(err.Error(),
   158  		"exited without exit status or exit signal") {
   159  		return
   160  	}
   161  	return nil
   162  }
   163  
   164  func (c *sshClient) remoteShell() (err error) {
   165  	if err = c.session.Shell(); err != nil {
   166  		return
   167  	}
   168  
   169  	if err = c.session.Wait(); err != nil && !strings.HasSuffix(err.Error(),
   170  		"exited without exit status or exit signal") {
   171  		return err
   172  	}
   173  	return nil
   174  }
   175  
   176  func vmInfo(id string) (vm VMInfo, err error) {
   177  	var up []VMInfo
   178  	if up, err = allRunningInstances(); err != nil {
   179  		return
   180  	}
   181  	for _, v := range up {
   182  		if v.Name == id || v.UUID == id {
   183  			return v, err
   184  		}
   185  	}
   186  	return vm, fmt.Errorf("'%s' not found, or dead", id)
   187  }
   188  
   189  func (c *sshClient) sCopy(source, destination, target string) (err error) {
   190  	var (
   191  		ftp         *sftp.Client
   192  		src         *os.File
   193  		srcS, destS os.FileInfo
   194  		dest        *sftp.File
   195  		bar         *pb.ProgressBar
   196  	)
   197  
   198  	if ftp, err = sftp.NewClient(c.conn); err != nil {
   199  		return
   200  	}
   201  	defer ftp.Close()
   202  
   203  	if src, err = os.Open(source); err != nil {
   204  		return
   205  	}
   206  	defer src.Close()
   207  	if srcS, err = os.Stat(source); err != nil {
   208  		return
   209  	}
   210  	if _, err = ftp.ReadDir(filepath.Dir(destination)); err != nil {
   211  		err = fmt.Errorf("unable to upload %v as parent %v "+
   212  			"not in target", source, filepath.Dir(destination))
   213  		return
   214  	}
   215  	if _, err = ftp.ReadDir(destination); err == nil {
   216  		destination = ftp.Join(destination, "/", filepath.Base(source))
   217  	}
   218  	if dest, err = ftp.Create(destination); err != nil {
   219  		return
   220  	}
   221  	defer dest.Close()
   222  	log.Println("uploading '" + source + "' to '" +
   223  		target + ":" + destination + "'")
   224  	bar = pb.New(int(srcS.Size())).SetUnits(pb.U_BYTES)
   225  	bar.Start()
   226  	writer := io.MultiWriter(bar, dest)
   227  	defer bar.Finish()
   228  	if _, err = io.Copy(writer, src); err != nil {
   229  		return
   230  	}
   231  
   232  	if destS, err = ftp.Stat(destination); err != nil {
   233  		return
   234  	}
   235  	if srcS.Size() != destS.Size() {
   236  		err = fmt.Errorf("something went wrong. " +
   237  			"destination file size != from sources'")
   238  	}
   239  	return
   240  }
   241  
   242  func scpCommand(cmd *cobra.Command, args []string) (err error) {
   243  	var (
   244  		session, vm                 = &sshClient{}, VMInfo{}
   245  		split                       = strings.Split(args[1], ":")
   246  		source, destination, target = args[0], split[1], split[0]
   247  	)
   248  	if vm, err = vmInfo(target); err != nil {
   249  		return
   250  	}
   251  	if session, err = vm.startSSHsession(); err != nil {
   252  		return
   253  	}
   254  	defer session.close()
   255  	return session.sCopy(source, destination, target)
   256  }
   257  
   258  func init() {
   259  	RootCmd.AddCommand(sshCmd)
   260  	RootCmd.AddCommand(scpCmd)
   261  }