github.com/viant/toolbox@v0.34.5/ssh/service.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"github.com/pkg/errors"
     7  	"github.com/viant/toolbox/cred"
     8  	"github.com/viant/toolbox/storage"
     9  	"golang.org/x/crypto/ssh"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"path"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  )
    18  
    19  type (
    20  	//Service represents ssh service
    21  	Service interface {
    22  		//Service returns a service wrapper
    23  		Client() *ssh.Client
    24  
    25  		//OpenMultiCommandSession opens multi command session
    26  		OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error)
    27  
    28  		//Run runs supplied command
    29  		Run(command string) error
    30  
    31  		//Upload uploads provided content to specified destination
    32  		//Deprecated: please consider using https://github.com/viant/afs/tree/master/scp
    33  		Upload(destination string, mode os.FileMode, content []byte) error
    34  
    35  		//Download downloads content from specified source.
    36  		//Deprecated: please consider using https://github.com/viant/afs/tree/master/scp
    37  		Download(source string) ([]byte, error)
    38  
    39  		//OpenTunnel opens a tunnel between local to remote for network traffic.
    40  		OpenTunnel(localAddress, remoteAddress string) error
    41  
    42  		NewSession() (*ssh.Session, error)
    43  
    44  		Close() error
    45  	}
    46  )
    47  
    48  //service represnt SSH service
    49  type service struct {
    50  	host           string
    51  	client         *ssh.Client
    52  	forwarding     []*Tunnel
    53  	replayCommands *ReplayCommands
    54  	recordSession  bool
    55  	config         *ssh.ClientConfig
    56  }
    57  
    58  //Service returns undelying ssh Service
    59  func (c *service) Client() *ssh.Client {
    60  	return c.client
    61  }
    62  
    63  //Service returns undelying ssh Service
    64  func (c *service) NewSession() (*ssh.Session, error) {
    65  	return c.client.NewSession()
    66  }
    67  
    68  //MultiCommandSession create a new MultiCommandSession
    69  func (c *service) OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error) {
    70  	return newMultiCommandSession(c, config, c.replayCommands, c.recordSession)
    71  }
    72  
    73  func (c *service) Run(command string) error {
    74  	session, err := c.client.NewSession()
    75  	if err != nil {
    76  		panic("failed to create session: " + err.Error())
    77  	}
    78  	defer session.Close()
    79  	return session.Run(command)
    80  }
    81  
    82  func (c *service) transferData(payload []byte, createFileCmd string, writer io.Writer, errors chan error, waitGroup *sync.WaitGroup) {
    83  	const endSequence = "\x00"
    84  	defer waitGroup.Done()
    85  	_, err := fmt.Fprint(writer, createFileCmd)
    86  	if err != nil {
    87  		errors <- err
    88  		return
    89  	}
    90  	_, err = io.Copy(writer, bytes.NewReader(payload))
    91  	if err != nil {
    92  		errors <- err
    93  		return
    94  	}
    95  	if _, err = fmt.Fprint(writer, endSequence); err != nil {
    96  		errors <- err
    97  		return
    98  	}
    99  }
   100  
   101  type Errors chan error
   102  
   103  func (e Errors) GetError() error {
   104  	select {
   105  	case err := <-e:
   106  		return err
   107  	case <-time.After(time.Millisecond):
   108  	}
   109  	return nil
   110  }
   111  
   112  const operationSuccessful = 0
   113  
   114  func checkOutput(reader io.Reader, errorChannel Errors) {
   115  	writer := new(bytes.Buffer)
   116  	io.Copy(writer, reader)
   117  	if writer.Len() > 1 {
   118  		data := writer.Bytes()
   119  		if data[1] == operationSuccessful {
   120  			return
   121  		} else if len(data) > 2 {
   122  			errorChannel <- errors.New(string(data[2:]))
   123  		}
   124  	}
   125  }
   126  
   127  //Upload uploads passed in content into remote destination
   128  func (c *service) Upload(destination string, mode os.FileMode, content []byte) (err error) {
   129  	err = c.upload(destination, mode, content)
   130  
   131  	if err != nil {
   132  		if strings.Contains(err.Error(), "No such file or directory") {
   133  			dir, _ := path.Split(destination)
   134  			c.Run("mkdir -p " + dir)
   135  			return c.upload(destination, mode, content)
   136  		} else if strings.Contains(err.Error(), "handshake") || strings.Contains(err.Error(), "connection") {
   137  
   138  			time.Sleep(500 * time.Millisecond)
   139  			fmt.Printf("got error %v\n", err)
   140  			c.Reconnect()
   141  			return c.upload(destination, mode, content)
   142  		}
   143  	}
   144  	return err
   145  }
   146  
   147  func (c *service) getSession() (*ssh.Session, error) {
   148  	return c.client.NewSession()
   149  }
   150  
   151  //Upload uploads passed in content into remote destination
   152  func (c *service) upload(destination string, mode os.FileMode, content []byte) (err error) {
   153  	dir, file := path.Split(destination)
   154  	if mode == 0 {
   155  		mode = 0644
   156  	}
   157  	waitGroup := &sync.WaitGroup{}
   158  	waitGroup.Add(1)
   159  	if strings.HasPrefix(file, "/") {
   160  		file = string(file[1:])
   161  	}
   162  	session, err := c.getSession()
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	writer, err := session.StdinPipe()
   168  	if err != nil {
   169  		return errors.Wrap(err, "failed to acquire stdin")
   170  	}
   171  	defer writer.Close()
   172  
   173  	var transferError Errors = make(chan error, 1)
   174  	defer close(transferError)
   175  	var sessionError Errors = make(chan error, 1)
   176  	defer close(sessionError)
   177  	output, err := session.StdoutPipe()
   178  	if err != nil {
   179  		return errors.Wrap(err, "failed to acquire stdout")
   180  	}
   181  	go checkOutput(output, sessionError)
   182  
   183  	if mode >= 01000 {
   184  		mode = storage.DefaultFileMode
   185  	}
   186  	fileMode := string(fmt.Sprintf("C%04o", mode)[:5])
   187  	createFileCmd := fmt.Sprintf("%v %d %s\n", fileMode, len(content), file)
   188  	go c.transferData(content, createFileCmd, writer, transferError, waitGroup)
   189  	scpCommand := "scp -qtr " + dir
   190  	err = session.Start(scpCommand)
   191  	if err != nil {
   192  		return err
   193  	}
   194  	waitGroup.Wait()
   195  	writerErr := writer.Close()
   196  	if err := sessionError.GetError(); err != nil {
   197  		return err
   198  	}
   199  	if err := transferError.GetError(); err != nil {
   200  		return err
   201  	}
   202  	if err = session.Wait(); err != nil {
   203  		if err := sessionError.GetError(); err != nil {
   204  			return err
   205  		}
   206  		return err
   207  	}
   208  	return writerErr
   209  }
   210  
   211  //Download download passed source file from remote host.
   212  func (c *service) Download(source string) ([]byte, error) {
   213  	session, err := c.client.NewSession()
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	defer session.Close()
   218  	return session.Output(fmt.Sprintf("cat %s", source))
   219  }
   220  
   221  //Host returns client host
   222  func (c *service) Host() string {
   223  	return c.host
   224  }
   225  
   226  //Close closes service
   227  func (c *service) Close() error {
   228  	if len(c.forwarding) > 0 {
   229  		for _, forwarding := range c.forwarding {
   230  			_ = forwarding.Close()
   231  		}
   232  	}
   233  	return c.client.Close()
   234  }
   235  
   236  //Reconnect client
   237  func (c *service) Reconnect() error {
   238  	return c.connect()
   239  }
   240  
   241  //OpenTunnel tunnels data between localAddress and remoteAddress on ssh connection
   242  func (c *service) OpenTunnel(localAddress, remoteAddress string) error {
   243  	local, err := net.Listen("tcp", localAddress)
   244  	if err != nil {
   245  		return errors.Wrap(err, fmt.Sprintf("failed to listen on local: %v %v", localAddress))
   246  	}
   247  	var forwarding = NewForwarding(c.client, remoteAddress, local)
   248  	if len(c.forwarding) == 0 {
   249  		c.forwarding = make([]*Tunnel, 0)
   250  	}
   251  	c.forwarding = append(c.forwarding, forwarding)
   252  	go forwarding.Handle()
   253  	return nil
   254  }
   255  
   256  func (c *service) connect() (err error) {
   257  	if c.client, err = ssh.Dial("tcp", c.host, c.config); err != nil {
   258  		return errors.Wrap(err, fmt.Sprintf("failed to dial %v: %s", c.host))
   259  	}
   260  	return nil
   261  }
   262  
   263  //NewService create a new ssh service, it takes host port and authentication config
   264  func NewService(host string, port int, authConfig *cred.Config) (Service, error) {
   265  	if authConfig == nil {
   266  		authConfig = &cred.Config{}
   267  	}
   268  	clientConfig, err := authConfig.ClientConfig()
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  	var result = &service{
   273  		host:   fmt.Sprintf("%s:%d", host, port),
   274  		config: clientConfig,
   275  	}
   276  	return result, result.connect()
   277  }