github.com/secure-build/gitlab-runner@v12.5.0+incompatible/helpers/ssh/ssh_command.go (about)

     1  package ssh
     2  
import (
	"bytes"
	"context"
	"errors"
	"io"
	"io/ioutil"
	"net"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"

	"gitlab.com/gitlab-org/gitlab-runner/helpers"
)
    16  
// Client runs commands on a remote host over SSH. Connection parameters
// (Host, Port, User, Password, IdentityFile) come from the embedded Config.
// Connect must succeed before Exec or Run can be used, and Cleanup closes the
// connection.
type Client struct {
	Config

	// Stdout and Stderr receive the output of commands started by Exec and Run.
	Stdout io.Writer
	Stderr io.Writer

	// ConnectRetries is the number of dial attempts Connect makes before
	// giving up; zero selects the default of 3.
	ConnectRetries int

	// client is the live connection stored by a successful Connect; it is nil
	// before that, which Exec and Run report as "Not connected".
	client *ssh.Client
}
    26  
// Command describes one remote invocation for Client.Run.
type Command struct {
	// Environment entries are each shell-escaped and sent as an
	// "export <entry>" line on the remote stdin, ahead of Stdin.
	// NOTE(review): entries are presumably "KEY=value" — confirm with callers.
	Environment []string

	// Command is the argument vector; every element is shell-escaped and the
	// results are joined with single spaces to form the remote command line.
	Command []string

	// Stdin is written to the remote command's stdin after the export lines.
	Stdin string
}
    32  
    33  type ExitError struct {
    34  	Inner error
    35  }
    36  
    37  func (e *ExitError) Error() string {
    38  	if e.Inner == nil {
    39  		return "error"
    40  	}
    41  	return e.Inner.Error()
    42  }
    43  
    44  func (s *Client) getSSHKey(identityFile string) (key ssh.Signer, err error) {
    45  	buf, err := ioutil.ReadFile(identityFile)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	key, err = ssh.ParsePrivateKey(buf)
    50  	return key, err
    51  }
    52  
    53  func (s *Client) getSSHAuthMethods() ([]ssh.AuthMethod, error) {
    54  	var methods []ssh.AuthMethod
    55  	methods = append(methods, ssh.Password(s.Password))
    56  
    57  	if s.IdentityFile != "" {
    58  		key, err := s.getSSHKey(s.IdentityFile)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		methods = append(methods, ssh.PublicKeys(key))
    63  	}
    64  
    65  	return methods, nil
    66  }
    67  
    68  func (s *Client) Connect() error {
    69  	if s.Host == "" {
    70  		s.Host = "localhost"
    71  	}
    72  	if s.User == "" {
    73  		s.User = "root"
    74  	}
    75  	if s.Port == "" {
    76  		s.Port = "22"
    77  	}
    78  
    79  	methods, err := s.getSSHAuthMethods()
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	config := &ssh.ClientConfig{
    85  		User: s.User,
    86  		Auth: methods,
    87  	}
    88  
    89  	connectRetries := s.ConnectRetries
    90  	if connectRetries == 0 {
    91  		connectRetries = 3
    92  	}
    93  
    94  	var finalError error
    95  
    96  	for i := 0; i < connectRetries; i++ {
    97  		client, err := ssh.Dial("tcp", s.Host+":"+s.Port, config)
    98  		if err == nil {
    99  			s.client = client
   100  			return nil
   101  		}
   102  		time.Sleep(sshRetryInterval * time.Second)
   103  		finalError = err
   104  	}
   105  
   106  	return finalError
   107  }
   108  
   109  func (s *Client) Exec(cmd string) error {
   110  	if s.client == nil {
   111  		return errors.New("Not connected")
   112  	}
   113  
   114  	session, err := s.client.NewSession()
   115  	if err != nil {
   116  		return err
   117  	}
   118  	session.Stdout = s.Stdout
   119  	session.Stderr = s.Stderr
   120  	err = session.Run(cmd)
   121  	session.Close()
   122  	return err
   123  }
   124  
   125  func (s *Command) fullCommand() string {
   126  	var arguments []string
   127  	// TODO: This method is compatible only with Bjourne compatible shells
   128  	for _, part := range s.Command {
   129  		arguments = append(arguments, helpers.ShellEscape(part))
   130  	}
   131  	return strings.Join(arguments, " ")
   132  }
   133  
   134  func (s *Client) Run(ctx context.Context, cmd Command) error {
   135  	if s.client == nil {
   136  		return errors.New("Not connected")
   137  	}
   138  
   139  	session, err := s.client.NewSession()
   140  	if err != nil {
   141  		return err
   142  	}
   143  	defer session.Close()
   144  
   145  	var envVariables bytes.Buffer
   146  	for _, keyValue := range cmd.Environment {
   147  		envVariables.WriteString("export " + helpers.ShellEscape(keyValue) + "\n")
   148  	}
   149  
   150  	session.Stdin = io.MultiReader(
   151  		&envVariables,
   152  		bytes.NewBufferString(cmd.Stdin),
   153  	)
   154  	session.Stdout = s.Stdout
   155  	session.Stderr = s.Stderr
   156  	err = session.Start(cmd.fullCommand())
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	waitCh := make(chan error)
   162  	go func() {
   163  		err := session.Wait()
   164  		if _, ok := err.(*ssh.ExitError); ok {
   165  			err = &ExitError{Inner: err}
   166  		}
   167  		waitCh <- err
   168  	}()
   169  
   170  	select {
   171  	case <-ctx.Done():
   172  		session.Signal(ssh.SIGKILL)
   173  		session.Close()
   174  		return <-waitCh
   175  
   176  	case err := <-waitCh:
   177  		return err
   178  	}
   179  }
   180  
   181  func (s *Client) Cleanup() {
   182  	if s.client != nil {
   183  		s.client.Close()
   184  	}
   185  }