github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/roachprod/install/session.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package install 12 13 import ( 14 "context" 15 "io" 16 "os" 17 "os/exec" 18 "path/filepath" 19 "sync" 20 21 "github.com/cockroachdb/cockroach/pkg/cmd/roachprod/config" 22 "github.com/cockroachdb/errors" 23 ) 24 25 type session interface { 26 CombinedOutput(cmd string) ([]byte, error) 27 Run(cmd string) error 28 SetStdin(r io.Reader) 29 SetStdout(w io.Writer) 30 SetStderr(w io.Writer) 31 Start(cmd string) error 32 StdinPipe() (io.WriteCloser, error) 33 StdoutPipe() (io.Reader, error) 34 StderrPipe() (io.Reader, error) 35 RequestPty() error 36 Wait() error 37 Close() 38 } 39 40 type remoteSession struct { 41 *exec.Cmd 42 cancel func() 43 logfile string // captures ssh -vvv 44 } 45 46 func newRemoteSession(user, host string, logdir string) (*remoteSession, error) { 47 // TODO(tbg): this is disabled at the time of writing. It was difficult 48 // to assign the logfiles to the roachtest and as a bonus our CI harness 49 // never actually managed to collect the files since they had wrong 50 // permissions; instead they clogged up the roachprod dir. 51 // logfile := filepath.Join( 52 // logdir, 53 // fmt.Sprintf("ssh_%s_%s", host, timeutil.Now().Format(time.RFC3339)), 54 // ) 55 const logfile = "" 56 args := []string{ 57 user + "@" + host, 58 59 // TODO(tbg): see above. 60 //"-vvv", "-E", logfile, 61 // NB: -q suppresses -E, at least on OSX. Difficult decisions will have 62 // to be made if omitting -q leads to annoyance on stdout/stderr. 63 64 "-q", 65 "-o", "UserKnownHostsFile=/dev/null", 66 "-o", "StrictHostKeyChecking=no", 67 // Send keep alives every minute to prevent connections without activity 68 // from dropping. (It is speculative that the absence of this caused 69 // problems, though there's some indication that we need them: 70 // 71 // https://github.com/cockroachdb/cockroach/issues/35337 72 "-o", "ServerAliveInterval=60", 73 } 74 args = append(args, sshAuthArgs()...) 75 ctx, cancel := context.WithCancel(context.Background()) 76 cmd := exec.CommandContext(ctx, "ssh", args...) 77 return &remoteSession{cmd, cancel, logfile}, nil 78 } 79 80 func (s *remoteSession) errWithDebug(err error) error { 81 if err != nil && s.logfile != "" { 82 err = errors.Wrapf(err, "ssh verbose log retained in %s", s.logfile) 83 s.logfile = "" // prevent removal on close 84 } 85 return err 86 } 87 88 func (s *remoteSession) CombinedOutput(cmd string) ([]byte, error) { 89 s.Cmd.Args = append(s.Cmd.Args, cmd) 90 b, err := s.Cmd.CombinedOutput() 91 return b, s.errWithDebug(err) 92 } 93 94 func (s *remoteSession) Run(cmd string) error { 95 s.Cmd.Args = append(s.Cmd.Args, cmd) 96 return s.errWithDebug(s.Cmd.Run()) 97 } 98 99 func (s *remoteSession) Start(cmd string) error { 100 s.Cmd.Args = append(s.Cmd.Args, cmd) 101 return s.Cmd.Start() 102 } 103 104 func (s *remoteSession) SetStdin(r io.Reader) { 105 s.Stdin = r 106 } 107 108 func (s *remoteSession) SetStdout(w io.Writer) { 109 s.Stdout = w 110 } 111 112 func (s *remoteSession) SetStderr(w io.Writer) { 113 s.Stderr = w 114 } 115 116 func (s *remoteSession) StdoutPipe() (io.Reader, error) { 117 // NB: exec.Cmd.StdoutPipe returns a io.ReadCloser, hence the need for the 118 // temporary storage into local variables. 119 r, err := s.Cmd.StdoutPipe() 120 return r, err 121 } 122 123 func (s *remoteSession) StderrPipe() (io.Reader, error) { 124 // NB: exec.Cmd.StderrPipe returns a io.ReadCloser, hence the need for the 125 // temporary storage into local variables. 126 r, err := s.Cmd.StderrPipe() 127 return r, err 128 } 129 130 func (s *remoteSession) RequestPty() error { 131 s.Cmd.Args = append(s.Cmd.Args, "-t") 132 return nil 133 } 134 135 func (s *remoteSession) Wait() error { 136 return s.Cmd.Wait() 137 } 138 139 func (s *remoteSession) Close() { 140 s.cancel() 141 if s.logfile != "" { 142 _ = os.Remove(s.logfile) 143 } 144 } 145 146 type localSession struct { 147 *exec.Cmd 148 cancel func() 149 } 150 151 func newLocalSession() *localSession { 152 ctx, cancel := context.WithCancel(context.Background()) 153 cmd := exec.CommandContext(ctx, "/bin/bash", "-c") 154 return &localSession{cmd, cancel} 155 } 156 157 func (s *localSession) CombinedOutput(cmd string) ([]byte, error) { 158 s.Cmd.Args = append(s.Cmd.Args, cmd) 159 return s.Cmd.CombinedOutput() 160 } 161 162 func (s *localSession) Run(cmd string) error { 163 s.Cmd.Args = append(s.Cmd.Args, cmd) 164 return s.Cmd.Run() 165 } 166 167 func (s *localSession) Start(cmd string) error { 168 s.Cmd.Args = append(s.Cmd.Args, cmd) 169 return s.Cmd.Start() 170 } 171 172 func (s *localSession) SetStdin(r io.Reader) { 173 s.Stdin = r 174 } 175 176 func (s *localSession) SetStdout(w io.Writer) { 177 s.Stdout = w 178 } 179 180 func (s *localSession) SetStderr(w io.Writer) { 181 s.Stderr = w 182 } 183 184 func (s *localSession) StdoutPipe() (io.Reader, error) { 185 // NB: exec.Cmd.StdoutPipe returns a io.ReadCloser, hence the need for the 186 // temporary storage into local variables. 187 r, err := s.Cmd.StdoutPipe() 188 return r, err 189 } 190 191 func (s *localSession) StderrPipe() (io.Reader, error) { 192 // NB: exec.Cmd.StderrPipe returns a io.ReadCloser, hence the need for the 193 // temporary storage into local variables. 194 r, err := s.Cmd.StderrPipe() 195 return r, err 196 } 197 198 func (s *localSession) RequestPty() error { 199 return nil 200 } 201 202 func (s *localSession) Wait() error { 203 return s.Cmd.Wait() 204 } 205 206 func (s *localSession) Close() { 207 s.cancel() 208 } 209 210 var sshAuthArgsVal []string 211 var sshAuthArgsOnce sync.Once 212 213 func sshAuthArgs() []string { 214 sshAuthArgsOnce.Do(func() { 215 paths := []string{ 216 filepath.Join(config.OSUser.HomeDir, ".ssh", "id_rsa"), 217 filepath.Join(config.OSUser.HomeDir, ".ssh", "google_compute_engine"), 218 } 219 for _, p := range paths { 220 if _, err := os.Stat(p); err == nil { 221 sshAuthArgsVal = append(sshAuthArgsVal, "-i", p) 222 } 223 } 224 }) 225 return sshAuthArgsVal 226 }