github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/roachprod/install/session.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package install
    12  
    13  import (
    14  	"context"
    15  	"io"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"sync"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/cmd/roachprod/config"
    22  	"github.com/cockroachdb/errors"
    23  )
    24  
    25  type session interface {
    26  	CombinedOutput(cmd string) ([]byte, error)
    27  	Run(cmd string) error
    28  	SetStdin(r io.Reader)
    29  	SetStdout(w io.Writer)
    30  	SetStderr(w io.Writer)
    31  	Start(cmd string) error
    32  	StdinPipe() (io.WriteCloser, error)
    33  	StdoutPipe() (io.Reader, error)
    34  	StderrPipe() (io.Reader, error)
    35  	RequestPty() error
    36  	Wait() error
    37  	Close()
    38  }
    39  
    40  type remoteSession struct {
    41  	*exec.Cmd
    42  	cancel  func()
    43  	logfile string // captures ssh -vvv
    44  }
    45  
    46  func newRemoteSession(user, host string, logdir string) (*remoteSession, error) {
    47  	// TODO(tbg): this is disabled at the time of writing. It was difficult
    48  	// to assign the logfiles to the roachtest and as a bonus our CI harness
    49  	// never actually managed to collect the files since they had wrong
    50  	// permissions; instead they clogged up the roachprod dir.
    51  	// logfile := filepath.Join(
    52  	//	logdir,
    53  	// 	fmt.Sprintf("ssh_%s_%s", host, timeutil.Now().Format(time.RFC3339)),
    54  	// )
    55  	const logfile = ""
    56  	args := []string{
    57  		user + "@" + host,
    58  
    59  		// TODO(tbg): see above.
    60  		//"-vvv", "-E", logfile,
    61  		// NB: -q suppresses -E, at least on OSX. Difficult decisions will have
    62  		// to be made if omitting -q leads to annoyance on stdout/stderr.
    63  
    64  		"-q",
    65  		"-o", "UserKnownHostsFile=/dev/null",
    66  		"-o", "StrictHostKeyChecking=no",
    67  		// Send keep alives every minute to prevent connections without activity
    68  		// from dropping. (It is speculative that the absence of this caused
    69  		// problems, though there's some indication that we need them:
    70  		//
    71  		// https://github.com/cockroachdb/cockroach/issues/35337
    72  		"-o", "ServerAliveInterval=60",
    73  	}
    74  	args = append(args, sshAuthArgs()...)
    75  	ctx, cancel := context.WithCancel(context.Background())
    76  	cmd := exec.CommandContext(ctx, "ssh", args...)
    77  	return &remoteSession{cmd, cancel, logfile}, nil
    78  }
    79  
    80  func (s *remoteSession) errWithDebug(err error) error {
    81  	if err != nil && s.logfile != "" {
    82  		err = errors.Wrapf(err, "ssh verbose log retained in %s", s.logfile)
    83  		s.logfile = "" // prevent removal on close
    84  	}
    85  	return err
    86  }
    87  
    88  func (s *remoteSession) CombinedOutput(cmd string) ([]byte, error) {
    89  	s.Cmd.Args = append(s.Cmd.Args, cmd)
    90  	b, err := s.Cmd.CombinedOutput()
    91  	return b, s.errWithDebug(err)
    92  }
    93  
    94  func (s *remoteSession) Run(cmd string) error {
    95  	s.Cmd.Args = append(s.Cmd.Args, cmd)
    96  	return s.errWithDebug(s.Cmd.Run())
    97  }
    98  
    99  func (s *remoteSession) Start(cmd string) error {
   100  	s.Cmd.Args = append(s.Cmd.Args, cmd)
   101  	return s.Cmd.Start()
   102  }
   103  
   104  func (s *remoteSession) SetStdin(r io.Reader) {
   105  	s.Stdin = r
   106  }
   107  
   108  func (s *remoteSession) SetStdout(w io.Writer) {
   109  	s.Stdout = w
   110  }
   111  
   112  func (s *remoteSession) SetStderr(w io.Writer) {
   113  	s.Stderr = w
   114  }
   115  
   116  func (s *remoteSession) StdoutPipe() (io.Reader, error) {
   117  	// NB: exec.Cmd.StdoutPipe returns a io.ReadCloser, hence the need for the
   118  	// temporary storage into local variables.
   119  	r, err := s.Cmd.StdoutPipe()
   120  	return r, err
   121  }
   122  
   123  func (s *remoteSession) StderrPipe() (io.Reader, error) {
   124  	// NB: exec.Cmd.StderrPipe returns a io.ReadCloser, hence the need for the
   125  	// temporary storage into local variables.
   126  	r, err := s.Cmd.StderrPipe()
   127  	return r, err
   128  }
   129  
   130  func (s *remoteSession) RequestPty() error {
   131  	s.Cmd.Args = append(s.Cmd.Args, "-t")
   132  	return nil
   133  }
   134  
   135  func (s *remoteSession) Wait() error {
   136  	return s.Cmd.Wait()
   137  }
   138  
   139  func (s *remoteSession) Close() {
   140  	s.cancel()
   141  	if s.logfile != "" {
   142  		_ = os.Remove(s.logfile)
   143  	}
   144  }
   145  
   146  type localSession struct {
   147  	*exec.Cmd
   148  	cancel func()
   149  }
   150  
   151  func newLocalSession() *localSession {
   152  	ctx, cancel := context.WithCancel(context.Background())
   153  	cmd := exec.CommandContext(ctx, "/bin/bash", "-c")
   154  	return &localSession{cmd, cancel}
   155  }
   156  
   157  func (s *localSession) CombinedOutput(cmd string) ([]byte, error) {
   158  	s.Cmd.Args = append(s.Cmd.Args, cmd)
   159  	return s.Cmd.CombinedOutput()
   160  }
   161  
   162  func (s *localSession) Run(cmd string) error {
   163  	s.Cmd.Args = append(s.Cmd.Args, cmd)
   164  	return s.Cmd.Run()
   165  }
   166  
   167  func (s *localSession) Start(cmd string) error {
   168  	s.Cmd.Args = append(s.Cmd.Args, cmd)
   169  	return s.Cmd.Start()
   170  }
   171  
   172  func (s *localSession) SetStdin(r io.Reader) {
   173  	s.Stdin = r
   174  }
   175  
   176  func (s *localSession) SetStdout(w io.Writer) {
   177  	s.Stdout = w
   178  }
   179  
   180  func (s *localSession) SetStderr(w io.Writer) {
   181  	s.Stderr = w
   182  }
   183  
   184  func (s *localSession) StdoutPipe() (io.Reader, error) {
   185  	// NB: exec.Cmd.StdoutPipe returns a io.ReadCloser, hence the need for the
   186  	// temporary storage into local variables.
   187  	r, err := s.Cmd.StdoutPipe()
   188  	return r, err
   189  }
   190  
   191  func (s *localSession) StderrPipe() (io.Reader, error) {
   192  	// NB: exec.Cmd.StderrPipe returns a io.ReadCloser, hence the need for the
   193  	// temporary storage into local variables.
   194  	r, err := s.Cmd.StderrPipe()
   195  	return r, err
   196  }
   197  
   198  func (s *localSession) RequestPty() error {
   199  	return nil
   200  }
   201  
   202  func (s *localSession) Wait() error {
   203  	return s.Cmd.Wait()
   204  }
   205  
   206  func (s *localSession) Close() {
   207  	s.cancel()
   208  }
   209  
   210  var sshAuthArgsVal []string
   211  var sshAuthArgsOnce sync.Once
   212  
   213  func sshAuthArgs() []string {
   214  	sshAuthArgsOnce.Do(func() {
   215  		paths := []string{
   216  			filepath.Join(config.OSUser.HomeDir, ".ssh", "id_rsa"),
   217  			filepath.Join(config.OSUser.HomeDir, ".ssh", "google_compute_engine"),
   218  		}
   219  		for _, p := range paths {
   220  			if _, err := os.Stat(p); err == nil {
   221  				sshAuthArgsVal = append(sshAuthArgsVal, "-i", p)
   222  			}
   223  		}
   224  	})
   225  	return sshAuthArgsVal
   226  }