github.com/viant/toolbox@v0.34.5/ssh/session.go (about) 1 package ssh 2 3 import ( 4 "bytes" 5 "fmt" 6 "github.com/lunixbochs/vtclean" 7 "github.com/pkg/errors" 8 "github.com/viant/toolbox" 9 "golang.org/x/crypto/ssh" 10 "io" 11 "path" 12 "strings" 13 "sync" 14 "sync/atomic" 15 "time" 16 ) 17 18 type TerminatedError struct { 19 Err error 20 } 21 22 func (t *TerminatedError) Error() string { 23 return fmt.Sprintf("terminated due to %v", t.Err) 24 } 25 26 //ErrTerminated - command session terminated 27 var ErrTerminated = &TerminatedError{} 28 29 const defaultShell = "/bin/bash" 30 31 const ( 32 drainTimeoutMs = 10 33 defaultTimeoutMs = 20000 34 stdoutFlashFrequencyMs = 1000 35 initTimeoutMs = 300 36 defaultTickFrequency = 100 37 ) 38 39 //Listener represent command listener (it will send stdout fragments as thier being available on stdout) 40 type Listener func(stdout string, hasMore bool) 41 42 //MultiCommandSession represents a multi command session 43 type MultiCommandSession interface { 44 Run(command string, listener Listener, timeoutMs int, terminators ...string) (string, error) 45 46 ShellPrompt() string 47 48 System() string 49 50 Reconnect() error 51 52 Close() 53 } 54 55 //multiCommandSession represents a multi command session 56 //a new command are send vi stdin 57 type multiCommandSession struct { 58 service *service 59 config *SessionConfig 60 replayCommands *ReplayCommands 61 recordSession bool 62 session *ssh.Session 63 stdOutput chan string 64 stdError chan string 65 stdInput io.WriteCloser 66 promptSequence string 67 shellPrompt string 68 escapedShellPrompt string 69 system string 70 running int32 71 stdin string 72 } 73 74 func (s *multiCommandSession) Run(command string, listener Listener, timeoutMs int, terminators ...string) (string, error) { 75 if atomic.LoadInt32(&s.running) == 0 { 76 return "", ErrTerminated 77 } 78 s.drainStdout() 79 if !strings.HasSuffix(command, "\n") { 80 command += "\n" 81 } 82 var stdin = command 83 s.stdin = stdin 84 _, err := s.stdInput.Write([]byte(stdin)) 85 if err != nil { 86 return "", fmt.Errorf("failed to execute command: %v, err: %v", command, err) 87 } 88 var output string 89 output, _, err = s.readResponse(timeoutMs, listener, terminators...) 90 if s.recordSession { 91 s.replayCommands.Register(stdin, output) 92 } 93 return output, err 94 } 95 96 //ShellPrompt returns a shell prompt 97 func (s *multiCommandSession) ShellPrompt() string { 98 return s.shellPrompt 99 } 100 101 //System returns a system name 102 func (s *multiCommandSession) System() string { 103 return s.system 104 } 105 106 //Close closes the session with its resources 107 func (s *multiCommandSession) Close() { 108 atomic.StoreInt32(&s.running, 0) 109 _ = s.stdInput.Close() 110 if s.session != nil { 111 _ = s.session.Close() 112 } 113 114 } 115 116 func (s *multiCommandSession) closeIfError(err error) bool { 117 if err != nil { 118 if err != ErrTerminated { 119 ErrTerminated.Err = err 120 } 121 s.Close() 122 return true 123 } 124 return false 125 } 126 127 func (s *multiCommandSession) start(shell string) (output string, err error) { 128 var reader, errReader io.Reader 129 reader, err = s.session.StdoutPipe() 130 if err != nil { 131 return "", err 132 } 133 errReader, err = s.session.StderrPipe() 134 if err != nil { 135 return "", err 136 } 137 138 waitGroup := sync.WaitGroup{} 139 waitGroup.Add(2) 140 go func() { 141 waitGroup.Done() 142 s.copy(reader, s.stdOutput) 143 }() 144 go func() { 145 waitGroup.Done() 146 s.copy(errReader, s.stdError) 147 }() 148 149 if shell == "" { 150 shell = defaultShell 151 } 152 waitGroup.Wait() 153 s.stdin = shell 154 err = s.session.Start(shell) 155 if err != nil { 156 return "", err 157 } 158 _, name := path.Split(shell) 159 output, _, err = s.readResponse(defaultTimeoutMs, nil, name) 160 return output, err 161 } 162 163 //copy copy data from reader to channel 164 func (s *multiCommandSession) copy(reader io.Reader, out chan string) { 165 var written int64 = 0 166 buf := make([]byte, 128*1024) 167 var err error 168 var bytesRead int 169 for { 170 writer := new(bytes.Buffer) 171 if atomic.LoadInt32(&s.running) == 0 { 172 return 173 } 174 bytesRead, err = reader.Read(buf) 175 if bytesRead > 0 { 176 bytesWritten, writeError := writer.Write(buf[:bytesRead]) 177 if s.closeIfError(writeError) { 178 return 179 } 180 if bytesWritten > 0 { 181 written += int64(bytesWritten) 182 } 183 184 if bytesRead != bytesWritten { 185 if s.closeIfError(io.ErrShortWrite) { 186 return 187 } 188 } 189 out <- string(writer.Bytes()) 190 } 191 192 if s.closeIfError(err) { 193 return 194 } 195 } 196 } 197 198 func escapeInput(input string) string { 199 input = vtclean.Clean(input, false) 200 if input == "" { 201 return input 202 } 203 return strings.Trim(input, "\n\r\t ") 204 } 205 206 func (s *multiCommandSession) Reconnect() (err error) { 207 atomic.StoreInt32(&s.running, 1) 208 s.service.Reconnect() 209 s.session, err = s.service.client.NewSession() 210 defer func() { 211 if err != nil { 212 s.service.client.Close() 213 } 214 }() 215 if err != nil { 216 return err 217 } 218 return s.init() 219 } 220 221 func (s *multiCommandSession) hasPrompt(input string) bool { 222 escapedInput := escapeInput(input) 223 var shellPrompt = s.shellPrompt 224 if shellPrompt == "" { 225 shellPrompt = "$" 226 } 227 if s.escapedShellPrompt == "" && s.shellPrompt != "" { 228 s.escapedShellPrompt = escapeInput(s.shellPrompt) 229 } 230 231 if s.escapedShellPrompt != "" && strings.HasSuffix(escapedInput, s.escapedShellPrompt) || strings.HasSuffix(input, s.shellPrompt) { 232 return true 233 } 234 return false 235 } 236 237 func (s *multiCommandSession) hasTerminator(input string, terminators ...string) bool { 238 escapedInput := escapeInput(input) 239 input = escapedInput 240 for _, candidate := range terminators { 241 candidateLen := len(candidate) 242 if candidateLen == 0 { 243 continue 244 } 245 if candidate[0:1] == "^" && strings.HasPrefix(input, candidate[1:]) { 246 return true 247 } 248 if candidate[candidateLen-1:] == "$" && strings.HasSuffix(input, candidate[:candidateLen-1]) { 249 return true 250 } 251 if strings.Contains(input, candidate) { 252 return true 253 } 254 } 255 return false 256 } 257 258 func (s *multiCommandSession) removePromptIfNeeded(stdout string) string { 259 if strings.Contains(stdout, s.shellPrompt) { 260 stdout = strings.Replace(stdout, s.shellPrompt, "", 1) 261 var lines = []string{} 262 for _, line := range strings.Split(stdout, "\n") { 263 if strings.TrimSpace(line) == "" { 264 continue 265 } 266 lines = append(lines, line) 267 } 268 stdout = strings.Join(lines, "\n") 269 } 270 return stdout 271 } 272 273 func (s *multiCommandSession) readResponse(timeoutMs int, listener Listener, terminators ...string) (out string, has bool, err error) { 274 var hasPrompt, hasTerminator bool 275 if timeoutMs == 0 { 276 timeoutMs = defaultTimeoutMs 277 } 278 notification := newNotificationWindow(listener, stdoutFlashFrequencyMs) 279 defer notification.flush() 280 281 var done int32 282 defer atomic.StoreInt32(&done, 1) 283 var errOut string 284 var hasOutput bool 285 286 var waitTimeMs = 0 287 var tickFrequencyMs = defaultTickFrequency 288 if tickFrequencyMs > timeoutMs { 289 tickFrequencyMs = timeoutMs 290 } 291 var timeoutDuration = time.Duration(tickFrequencyMs) * time.Millisecond 292 293 outer: 294 for { 295 select { 296 case partialOutput := <-s.stdOutput: 297 waitTimeMs = 0 298 out += partialOutput 299 hasTerminator = s.hasTerminator(out, terminators...) 300 if len(partialOutput) > 0 { 301 if hasTerminator { 302 partialOutput = addLineBreakIfNeeded(partialOutput) 303 } 304 notification.notify(s.removePromptIfNeeded(partialOutput)) 305 } 306 hasPrompt = s.hasPrompt(out) 307 if (hasPrompt || hasTerminator) && len(s.stdOutput) == 0 { 308 break outer 309 } 310 case e := <-s.stdError: 311 errOut += e 312 notification.notify(s.removePromptIfNeeded(e)) 313 hasPrompt = s.hasPrompt(errOut) 314 hasTerminator = s.hasTerminator(errOut, terminators...) 315 if (hasPrompt || hasTerminator) && len(s.stdOutput) == 0 { 316 break outer 317 } 318 case <-time.After(timeoutDuration): 319 waitTimeMs += tickFrequencyMs 320 if waitTimeMs >= timeoutMs { 321 break outer 322 } 323 } 324 } 325 326 if hasTerminator { 327 s.drainStdout() 328 329 } 330 if errOut != "" { 331 err = errors.New(errOut) 332 } 333 334 if len(out) > 0 { 335 hasOutput = true 336 out = s.removePromptIfNeeded(out) 337 } 338 return out, hasOutput, err 339 } 340 func addLineBreakIfNeeded(text string) string { 341 index := strings.LastIndex(text, "\n") 342 if index == -1 { 343 return text + "\n" 344 } 345 lastFragment := string(text[index:]) 346 if strings.TrimSpace(lastFragment) != "" { 347 return text + "\n" 348 } 349 return text 350 } 351 352 func (s *multiCommandSession) drainStdout() { 353 //read any outstanding output 354 for { 355 _, has, _ := s.readResponse(drainTimeoutMs, nil, "") 356 if !has { 357 return 358 } 359 } 360 } 361 362 func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { 363 c := make(chan bool) 364 go func() { 365 defer close(c) 366 wg.Wait() 367 c <- true 368 }() 369 select { 370 case <-c: 371 return false // completed normally 372 case <-time.After(timeout): 373 return true // timed out 374 } 375 } 376 377 func (s *multiCommandSession) shellInit() (err error) { 378 if s.promptSequence != "" { 379 if _, err = s.Run(s.promptSequence, nil, initTimeoutMs); err != nil { 380 return err 381 } 382 } 383 384 var ts = toolbox.AsString(time.Now().UnixNano()) 385 var waitGroup = &sync.WaitGroup{} 386 waitGroup.Add(1) 387 388 if s.config.Shell == defaultShell { 389 s.promptSequence = "PS1=\"" + ts + "\\$\"" 390 s.shellPrompt = ts + "$" 391 s.escapedShellPrompt = escapeInput(s.shellPrompt) 392 } 393 394 var listener Listener 395 listener = func(stdout string, hasMore bool) { 396 if !hasMore { 397 waitGroup.Done() 398 } 399 } 400 401 _, err = s.Run("", listener, initTimeoutMs, "$") 402 403 waitTimeout(waitGroup, 60*time.Second) 404 s.drainStdout() 405 _, err = s.Run(s.promptSequence, nil, defaultTimeoutMs, "$") 406 if s.closeIfError(err) { 407 return err 408 } 409 for i := 0; i < 3; i++ { 410 s.system, err = s.Run("uname -s", nil, initTimeoutMs) 411 s.system = strings.ToLower(strings.TrimSpace(s.system)) 412 if strings.TrimSpace(s.system) != "" && !strings.Contains(s.system, "$") { 413 break 414 } 415 } 416 s.drainStdout() 417 return nil 418 } 419 420 func (s *multiCommandSession) init() (err error) { 421 s.session, err = s.service.client.NewSession() 422 defer func() { 423 if err != nil { 424 s.service.client.Close() 425 } 426 }() 427 s.stdOutput = make(chan string) 428 s.stdError = make(chan string) 429 for k, v := range s.config.EnvVariables { 430 err = s.session.Setenv(k, v) 431 if err != nil { 432 return err 433 } 434 } 435 modes := ssh.TerminalModes{ 436 ssh.ECHO: 0, // disable echoing 437 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 438 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 439 } 440 441 if err := s.session.RequestPty(s.config.Term, s.config.Rows, s.config.Columns, modes); err != nil { 442 return err 443 } 444 445 if s.stdInput, err = s.session.StdinPipe(); err != nil { 446 return err 447 } 448 449 stdout, err := s.start(s.config.Shell) 450 if s.closeIfError(err) { 451 return err 452 } 453 if err = checkNotFound(stdout); err != nil { 454 return fmt.Errorf("failed to open %v shell, %v", s.config.Shell, err) 455 } 456 return s.shellInit() 457 } 458 459 func checkNotFound(output string) error { 460 if strings.Contains(output, "not found") { 461 return fmt.Errorf("failed run %s", output) 462 } 463 return nil 464 } 465 466 func newMultiCommandSession(service *service, config *SessionConfig, replayCommands *ReplayCommands, recordSession bool) (MultiCommandSession, error) { 467 if config == nil { 468 config = &SessionConfig{} 469 } 470 config.applyDefault() 471 472 result := &multiCommandSession{ 473 service: service, 474 config: config, 475 running: 1, 476 recordSession: recordSession, 477 replayCommands: replayCommands, 478 } 479 return result, result.init() 480 } 481 482 type notificationWindow struct { 483 checkpoint *time.Time 484 listener Listener 485 elapsedMs int 486 stdout string 487 frequencyMs int 488 } 489 490 func (t *notificationWindow) flush() { 491 if t.listener == nil { 492 return 493 } 494 if t.stdout != "" { 495 t.listener(t.stdout, true) 496 } 497 498 t.listener("", false) 499 } 500 501 func (t *notificationWindow) notify(stdout string) { 502 var now = time.Now() 503 if t.listener == nil { 504 return 505 } 506 t.stdout += stdout 507 t.elapsedMs += int(now.Sub(*t.checkpoint) / time.Millisecond) 508 t.checkpoint = &now 509 if t.elapsedMs > t.frequencyMs { 510 t.listener(t.stdout, true) 511 t.stdout = "" 512 t.elapsedMs = 0 513 } 514 } 515 516 func newNotificationWindow(listener Listener, frequencyMs int) *notificationWindow { 517 var now = time.Now() 518 return ¬ificationWindow{ 519 checkpoint: &now, 520 listener: listener, 521 frequencyMs: frequencyMs, 522 } 523 }