github.com/viant/toolbox@v0.34.5/ssh/session.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"github.com/lunixbochs/vtclean"
     7  	"github.com/pkg/errors"
     8  	"github.com/viant/toolbox"
     9  	"golang.org/x/crypto/ssh"
    10  	"io"
    11  	"path"
    12  	"strings"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  )
    17  
    18  type TerminatedError struct {
    19  	Err error
    20  }
    21  
    22  func (t *TerminatedError) Error() string {
    23  	return fmt.Sprintf("terminated due to %v", t.Err)
    24  }
    25  
// ErrTerminated is the shared sentinel returned by Run when the command
// session is no longer running. closeIfError stores the most recent failure
// cause in its Err field.
var ErrTerminated = &TerminatedError{}
    28  
// defaultShell is the shell started when the caller does not specify one.
const defaultShell = "/bin/bash"
    30  
// Timing settings, all expressed in milliseconds.
const (
	// drainTimeoutMs is the readResponse timeout used by drainStdout.
	drainTimeoutMs = 10
	// defaultTimeoutMs is used by readResponse when the caller passes 0.
	defaultTimeoutMs = 20000
	// stdoutFlashFrequencyMs is the notification window frequency handed to
	// newNotificationWindow by readResponse.
	stdoutFlashFrequencyMs = 1000
	// initTimeoutMs is the timeout of the commands issued by shellInit.
	initTimeoutMs = 300
	// defaultTickFrequency is the polling tick of the readResponse loop.
	defaultTickFrequency = 100
)
    38  
// Listener represents a command listener: it receives stdout fragments as
// they become available. hasMore is false on the final call that signals the
// end of the output.
type Listener func(stdout string, hasMore bool)
    41  
// MultiCommandSession represents a multi command session.
// The notes on each method describe the multiCommandSession implementation
// found in this file.
type MultiCommandSession interface {
	// Run sends command to the shell and returns the output collected until
	// the prompt or one of the terminators is seen, or timeoutMs elapses
	// (0 selects the default timeout). listener, when not nil, receives
	// output fragments while the command runs.
	Run(command string, listener Listener, timeoutMs int, terminators ...string) (string, error)

	// ShellPrompt returns the shell prompt recorded for the session.
	ShellPrompt() string

	// System returns the system name recorded for the session.
	System() string

	// Reconnect re-establishes the underlying connection and session.
	Reconnect() error

	// Close closes the session with its resources.
	Close()
}
    54  
// multiCommandSession represents a multi command session;
// each new command is sent to the remote shell via stdin.
type multiCommandSession struct {
	service            *service        // owning service; provides the SSH client
	config             *SessionConfig  // session settings (shell, env variables, terminal)
	replayCommands     *ReplayCommands // receives stdin/output pairs when recordSession is set
	recordSession      bool            // when true, Run registers every exchange in replayCommands
	session            *ssh.Session    // current SSH session
	stdOutput          chan string     // stdout fragments produced by the copy goroutine
	stdError           chan string     // stderr fragments produced by the copy goroutine
	stdInput           io.WriteCloser  // stdin pipe of the SSH session
	promptSequence     string          // command sent by shellInit to set the shell prompt
	shellPrompt        string          // prompt text used to detect the end of command output
	escapedShellPrompt string          // escapeInput(shellPrompt), cached
	system             string          // lower-cased, trimmed output of "uname -s"
	running            int32           // 1 while the session is active, 0 once closed; accessed atomically
	stdin              string          // last text sent to the session (shell name or command)
}
    73  
    74  func (s *multiCommandSession) Run(command string, listener Listener, timeoutMs int, terminators ...string) (string, error) {
    75  	if atomic.LoadInt32(&s.running) == 0 {
    76  		return "", ErrTerminated
    77  	}
    78  	s.drainStdout()
    79  	if !strings.HasSuffix(command, "\n") {
    80  		command += "\n"
    81  	}
    82  	var stdin = command
    83  	s.stdin = stdin
    84  	_, err := s.stdInput.Write([]byte(stdin))
    85  	if err != nil {
    86  		return "", fmt.Errorf("failed to execute command: %v, err: %v", command, err)
    87  	}
    88  	var output string
    89  	output, _, err = s.readResponse(timeoutMs, listener, terminators...)
    90  	if s.recordSession {
    91  		s.replayCommands.Register(stdin, output)
    92  	}
    93  	return output, err
    94  }
    95  
// ShellPrompt returns the shell prompt currently recorded for this session.
func (s *multiCommandSession) ShellPrompt() string {
	return s.shellPrompt
}
   100  
// System returns the system name recorded for this session.
func (s *multiCommandSession) System() string {
	return s.system
}
   105  
   106  //Close closes the session with its resources
   107  func (s *multiCommandSession) Close() {
   108  	atomic.StoreInt32(&s.running, 0)
   109  	_ = s.stdInput.Close()
   110  	if s.session != nil {
   111  		_ = s.session.Close()
   112  	}
   113  
   114  }
   115  
   116  func (s *multiCommandSession) closeIfError(err error) bool {
   117  	if err != nil {
   118  		if err != ErrTerminated {
   119  			ErrTerminated.Err = err
   120  		}
   121  		s.Close()
   122  		return true
   123  	}
   124  	return false
   125  }
   126  
   127  func (s *multiCommandSession) start(shell string) (output string, err error) {
   128  	var reader, errReader io.Reader
   129  	reader, err = s.session.StdoutPipe()
   130  	if err != nil {
   131  		return "", err
   132  	}
   133  	errReader, err = s.session.StderrPipe()
   134  	if err != nil {
   135  		return "", err
   136  	}
   137  
   138  	waitGroup := sync.WaitGroup{}
   139  	waitGroup.Add(2)
   140  	go func() {
   141  		waitGroup.Done()
   142  		s.copy(reader, s.stdOutput)
   143  	}()
   144  	go func() {
   145  		waitGroup.Done()
   146  		s.copy(errReader, s.stdError)
   147  	}()
   148  
   149  	if shell == "" {
   150  		shell = defaultShell
   151  	}
   152  	waitGroup.Wait()
   153  	s.stdin = shell
   154  	err = s.session.Start(shell)
   155  	if err != nil {
   156  		return "", err
   157  	}
   158  	_, name := path.Split(shell)
   159  	output, _, err = s.readResponse(defaultTimeoutMs, nil, name)
   160  	return output, err
   161  }
   162  
   163  //copy copy data from reader to channel
   164  func (s *multiCommandSession) copy(reader io.Reader, out chan string) {
   165  	var written int64 = 0
   166  	buf := make([]byte, 128*1024)
   167  	var err error
   168  	var bytesRead int
   169  	for {
   170  		writer := new(bytes.Buffer)
   171  		if atomic.LoadInt32(&s.running) == 0 {
   172  			return
   173  		}
   174  		bytesRead, err = reader.Read(buf)
   175  		if bytesRead > 0 {
   176  			bytesWritten, writeError := writer.Write(buf[:bytesRead])
   177  			if s.closeIfError(writeError) {
   178  				return
   179  			}
   180  			if bytesWritten > 0 {
   181  				written += int64(bytesWritten)
   182  			}
   183  
   184  			if bytesRead != bytesWritten {
   185  				if s.closeIfError(io.ErrShortWrite) {
   186  					return
   187  				}
   188  			}
   189  			out <- string(writer.Bytes())
   190  		}
   191  
   192  		if s.closeIfError(err) {
   193  			return
   194  		}
   195  	}
   196  }
   197  
   198  func escapeInput(input string) string {
   199  	input = vtclean.Clean(input, false)
   200  	if input == "" {
   201  		return input
   202  	}
   203  	return strings.Trim(input, "\n\r\t ")
   204  }
   205  
   206  func (s *multiCommandSession) Reconnect() (err error) {
   207  	atomic.StoreInt32(&s.running, 1)
   208  	s.service.Reconnect()
   209  	s.session, err = s.service.client.NewSession()
   210  	defer func() {
   211  		if err != nil {
   212  			s.service.client.Close()
   213  		}
   214  	}()
   215  	if err != nil {
   216  		return err
   217  	}
   218  	return s.init()
   219  }
   220  
   221  func (s *multiCommandSession) hasPrompt(input string) bool {
   222  	escapedInput := escapeInput(input)
   223  	var shellPrompt = s.shellPrompt
   224  	if shellPrompt == "" {
   225  		shellPrompt = "$"
   226  	}
   227  	if s.escapedShellPrompt == "" && s.shellPrompt != "" {
   228  		s.escapedShellPrompt = escapeInput(s.shellPrompt)
   229  	}
   230  
   231  	if s.escapedShellPrompt != "" && strings.HasSuffix(escapedInput, s.escapedShellPrompt) || strings.HasSuffix(input, s.shellPrompt) {
   232  		return true
   233  	}
   234  	return false
   235  }
   236  
   237  func (s *multiCommandSession) hasTerminator(input string, terminators ...string) bool {
   238  	escapedInput := escapeInput(input)
   239  	input = escapedInput
   240  	for _, candidate := range terminators {
   241  		candidateLen := len(candidate)
   242  		if candidateLen == 0 {
   243  			continue
   244  		}
   245  		if candidate[0:1] == "^" && strings.HasPrefix(input, candidate[1:]) {
   246  			return true
   247  		}
   248  		if candidate[candidateLen-1:] == "$" && strings.HasSuffix(input, candidate[:candidateLen-1]) {
   249  			return true
   250  		}
   251  		if strings.Contains(input, candidate) {
   252  			return true
   253  		}
   254  	}
   255  	return false
   256  }
   257  
   258  func (s *multiCommandSession) removePromptIfNeeded(stdout string) string {
   259  	if strings.Contains(stdout, s.shellPrompt) {
   260  		stdout = strings.Replace(stdout, s.shellPrompt, "", 1)
   261  		var lines = []string{}
   262  		for _, line := range strings.Split(stdout, "\n") {
   263  			if strings.TrimSpace(line) == "" {
   264  				continue
   265  			}
   266  			lines = append(lines, line)
   267  		}
   268  		stdout = strings.Join(lines, "\n")
   269  	}
   270  	return stdout
   271  }
   272  
// readResponse collects output from the stdOutput and stdError channels until
// the shell prompt or one of the terminators is detected, or until no stdout
// fragment has arrived for timeoutMs milliseconds (0 selects
// defaultTimeoutMs).
//
// Fragments are forwarded to listener (when not nil) through a notification
// window; the window is flushed when the function returns.
//
// It returns the accumulated stdout with the prompt removed, a flag telling
// whether any stdout was received, and an error carrying the accumulated
// stderr text when there was any.
func (s *multiCommandSession) readResponse(timeoutMs int, listener Listener, terminators ...string) (out string, has bool, err error) {
	var hasPrompt, hasTerminator bool
	if timeoutMs == 0 {
		timeoutMs = defaultTimeoutMs
	}
	notification := newNotificationWindow(listener, stdoutFlashFrequencyMs)
	defer notification.flush()

	// done is only written (on return); nothing in this function reads it.
	var done int32
	defer atomic.StoreInt32(&done, 1)
	var errOut string
	var hasOutput bool

	// waitTimeMs accumulates idle time in tick-sized steps.
	var waitTimeMs = 0
	var tickFrequencyMs = defaultTickFrequency
	if tickFrequencyMs > timeoutMs {
		tickFrequencyMs = timeoutMs
	}
	var timeoutDuration = time.Duration(tickFrequencyMs) * time.Millisecond

outer:
	for {
		select {
		case partialOutput := <-s.stdOutput:
			// new stdout resets the idle timer
			waitTimeMs = 0
			out += partialOutput
			// the terminator is checked against everything read so far
			hasTerminator = s.hasTerminator(out, terminators...)
			if len(partialOutput) > 0 {
				if hasTerminator {
					partialOutput = addLineBreakIfNeeded(partialOutput)
				}
				notification.notify(s.removePromptIfNeeded(partialOutput))
			}
			hasPrompt = s.hasPrompt(out)
			if (hasPrompt || hasTerminator) && len(s.stdOutput) == 0 {
				break outer
			}
		case e := <-s.stdError:
			// stderr is accumulated separately and does not reset the idle timer
			errOut += e
			notification.notify(s.removePromptIfNeeded(e))
			hasPrompt = s.hasPrompt(errOut)
			hasTerminator = s.hasTerminator(errOut, terminators...)
			if (hasPrompt || hasTerminator) && len(s.stdOutput) == 0 {
				break outer
			}
		case <-time.After(timeoutDuration):
			// one tick elapsed with no output on either channel
			waitTimeMs += tickFrequencyMs
			if waitTimeMs >= timeoutMs {
				break outer
			}
		}
	}

	if hasTerminator {
		// discard whatever follows the terminator
		s.drainStdout()

	}
	if errOut != "" {
		err = errors.New(errOut)
	}

	if len(out) > 0 {
		hasOutput = true
		out = s.removePromptIfNeeded(out)
	}
	return out, hasOutput, err
}
   340  func addLineBreakIfNeeded(text string) string {
   341  	index := strings.LastIndex(text, "\n")
   342  	if index == -1 {
   343  		return text + "\n"
   344  	}
   345  	lastFragment := string(text[index:])
   346  	if strings.TrimSpace(lastFragment) != "" {
   347  		return text + "\n"
   348  	}
   349  	return text
   350  }
   351  
   352  func (s *multiCommandSession) drainStdout() {
   353  	//read any outstanding output
   354  	for {
   355  		_, has, _ := s.readResponse(drainTimeoutMs, nil, "")
   356  		if !has {
   357  			return
   358  		}
   359  	}
   360  }
   361  
   362  func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
   363  	c := make(chan bool)
   364  	go func() {
   365  		defer close(c)
   366  		wg.Wait()
   367  		c <- true
   368  	}()
   369  	select {
   370  	case <-c:
   371  		return false // completed normally
   372  	case <-time.After(timeout):
   373  		return true // timed out
   374  	}
   375  }
   376  
// shellInit prepares a freshly started shell: it installs the prompt used to
// detect the end of command output and records the system name reported by
// "uname -s".
func (s *multiCommandSession) shellInit() (err error) {
	// re-apply a prompt sequence that is already set on the session
	if s.promptSequence != "" {
		if _, err = s.Run(s.promptSequence, nil, initTimeoutMs); err != nil {
			return err
		}
	}

	// the current time in nanoseconds is used as the prompt text
	var ts = toolbox.AsString(time.Now().UnixNano())
	var waitGroup = &sync.WaitGroup{}
	waitGroup.Add(1)

	// the PS1 based prompt is only installed for the default shell
	if s.config.Shell == defaultShell {
		s.promptSequence = "PS1=\"" + ts + "\\$\""
		s.shellPrompt = ts + "$"
		s.escapedShellPrompt = escapeInput(s.shellPrompt)
	}

	// the listener releases the wait group once the final (hasMore == false)
	// notification arrives
	var listener Listener
	listener = func(stdout string, hasMore bool) {
		if !hasMore {
			waitGroup.Done()
		}
	}

	// send an empty command; the error is overwritten by the next Run below
	_, err = s.Run("", listener, initTimeoutMs, "$")

	waitTimeout(waitGroup, 60*time.Second)
	s.drainStdout()
	_, err = s.Run(s.promptSequence, nil, defaultTimeoutMs, "$")
	if s.closeIfError(err) {
		return err
	}
	// up to three attempts to get a system name that is neither empty nor
	// contains "$"; errors from these attempts are not returned
	for i := 0; i < 3; i++ {
		s.system, err = s.Run("uname -s", nil, initTimeoutMs)
		s.system = strings.ToLower(strings.TrimSpace(s.system))
		if strings.TrimSpace(s.system) != "" && !strings.Contains(s.system, "$") {
			break
		}
	}
	s.drainStdout()
	return nil
}
   419  
   420  func (s *multiCommandSession) init() (err error) {
   421  	s.session, err = s.service.client.NewSession()
   422  	defer func() {
   423  		if err != nil {
   424  			s.service.client.Close()
   425  		}
   426  	}()
   427  	s.stdOutput = make(chan string)
   428  	s.stdError = make(chan string)
   429  	for k, v := range s.config.EnvVariables {
   430  		err = s.session.Setenv(k, v)
   431  		if err != nil {
   432  			return err
   433  		}
   434  	}
   435  	modes := ssh.TerminalModes{
   436  		ssh.ECHO:          0,     // disable echoing
   437  		ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   438  		ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   439  	}
   440  
   441  	if err := s.session.RequestPty(s.config.Term, s.config.Rows, s.config.Columns, modes); err != nil {
   442  		return err
   443  	}
   444  
   445  	if s.stdInput, err = s.session.StdinPipe(); err != nil {
   446  		return err
   447  	}
   448  
   449  	stdout, err := s.start(s.config.Shell)
   450  	if s.closeIfError(err) {
   451  		return err
   452  	}
   453  	if err = checkNotFound(stdout); err != nil {
   454  		return fmt.Errorf("failed to open %v shell, %v", s.config.Shell, err)
   455  	}
   456  	return s.shellInit()
   457  }
   458  
   459  func checkNotFound(output string) error {
   460  	if strings.Contains(output, "not found") {
   461  		return fmt.Errorf("failed run %s", output)
   462  	}
   463  	return nil
   464  }
   465  
   466  func newMultiCommandSession(service *service, config *SessionConfig, replayCommands *ReplayCommands, recordSession bool) (MultiCommandSession, error) {
   467  	if config == nil {
   468  		config = &SessionConfig{}
   469  	}
   470  	config.applyDefault()
   471  
   472  	result := &multiCommandSession{
   473  		service:        service,
   474  		config:         config,
   475  		running:        1,
   476  		recordSession:  recordSession,
   477  		replayCommands: replayCommands,
   478  	}
   479  	return result, result.init()
   480  }
   481  
// notificationWindow batches stdout fragments for a Listener so that the
// listener is called at most about once per frequencyMs of accumulated time.
type notificationWindow struct {
	checkpoint  *time.Time // time of the previous notify call (or of creation)
	listener    Listener   // receiver of the batched output; may be nil
	elapsedMs   int        // time accumulated since the listener was last called by notify
	stdout      string     // output buffered for the next listener call
	frequencyMs int        // accumulated time that must be exceeded before notify calls the listener
}
   489  
   490  func (t *notificationWindow) flush() {
   491  	if t.listener == nil {
   492  		return
   493  	}
   494  	if t.stdout != "" {
   495  		t.listener(t.stdout, true)
   496  	}
   497  
   498  	t.listener("", false)
   499  }
   500  
   501  func (t *notificationWindow) notify(stdout string) {
   502  	var now = time.Now()
   503  	if t.listener == nil {
   504  		return
   505  	}
   506  	t.stdout += stdout
   507  	t.elapsedMs += int(now.Sub(*t.checkpoint) / time.Millisecond)
   508  	t.checkpoint = &now
   509  	if t.elapsedMs > t.frequencyMs {
   510  		t.listener(t.stdout, true)
   511  		t.stdout = ""
   512  		t.elapsedMs = 0
   513  	}
   514  }
   515  
   516  func newNotificationWindow(listener Listener, frequencyMs int) *notificationWindow {
   517  	var now = time.Now()
   518  	return &notificationWindow{
   519  		checkpoint:  &now,
   520  		listener:    listener,
   521  		frequencyMs: frequencyMs,
   522  	}
   523  }