github.com/viant/toolbox@v0.34.5/ssh/service.go (about) 1 package ssh 2 3 import ( 4 "bytes" 5 "fmt" 6 "github.com/pkg/errors" 7 "github.com/viant/toolbox/cred" 8 "github.com/viant/toolbox/storage" 9 "golang.org/x/crypto/ssh" 10 "io" 11 "net" 12 "os" 13 "path" 14 "strings" 15 "sync" 16 "time" 17 ) 18 19 type ( 20 //Service represents ssh service 21 Service interface { 22 //Service returns a service wrapper 23 Client() *ssh.Client 24 25 //OpenMultiCommandSession opens multi command session 26 OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error) 27 28 //Run runs supplied command 29 Run(command string) error 30 31 //Upload uploads provided content to specified destination 32 //Deprecated: please consider using https://github.com/viant/afs/tree/master/scp 33 Upload(destination string, mode os.FileMode, content []byte) error 34 35 //Download downloads content from specified source. 36 //Deprecated: please consider using https://github.com/viant/afs/tree/master/scp 37 Download(source string) ([]byte, error) 38 39 //OpenTunnel opens a tunnel between local to remote for network traffic. 40 OpenTunnel(localAddress, remoteAddress string) error 41 42 NewSession() (*ssh.Session, error) 43 44 Close() error 45 } 46 ) 47 48 //service represnt SSH service 49 type service struct { 50 host string 51 client *ssh.Client 52 forwarding []*Tunnel 53 replayCommands *ReplayCommands 54 recordSession bool 55 config *ssh.ClientConfig 56 } 57 58 //Service returns undelying ssh Service 59 func (c *service) Client() *ssh.Client { 60 return c.client 61 } 62 63 //Service returns undelying ssh Service 64 func (c *service) NewSession() (*ssh.Session, error) { 65 return c.client.NewSession() 66 } 67 68 //MultiCommandSession create a new MultiCommandSession 69 func (c *service) OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error) { 70 return newMultiCommandSession(c, config, c.replayCommands, c.recordSession) 71 } 72 73 func (c *service) Run(command string) error { 74 session, err := c.client.NewSession() 75 if err != nil { 76 panic("failed to create session: " + err.Error()) 77 } 78 defer session.Close() 79 return session.Run(command) 80 } 81 82 func (c *service) transferData(payload []byte, createFileCmd string, writer io.Writer, errors chan error, waitGroup *sync.WaitGroup) { 83 const endSequence = "\x00" 84 defer waitGroup.Done() 85 _, err := fmt.Fprint(writer, createFileCmd) 86 if err != nil { 87 errors <- err 88 return 89 } 90 _, err = io.Copy(writer, bytes.NewReader(payload)) 91 if err != nil { 92 errors <- err 93 return 94 } 95 if _, err = fmt.Fprint(writer, endSequence); err != nil { 96 errors <- err 97 return 98 } 99 } 100 101 type Errors chan error 102 103 func (e Errors) GetError() error { 104 select { 105 case err := <-e: 106 return err 107 case <-time.After(time.Millisecond): 108 } 109 return nil 110 } 111 112 const operationSuccessful = 0 113 114 func checkOutput(reader io.Reader, errorChannel Errors) { 115 writer := new(bytes.Buffer) 116 io.Copy(writer, reader) 117 if writer.Len() > 1 { 118 data := writer.Bytes() 119 if data[1] == operationSuccessful { 120 return 121 } else if len(data) > 2 { 122 errorChannel <- errors.New(string(data[2:])) 123 } 124 } 125 } 126 127 //Upload uploads passed in content into remote destination 128 func (c *service) Upload(destination string, mode os.FileMode, content []byte) (err error) { 129 err = c.upload(destination, mode, content) 130 131 if err != nil { 132 if strings.Contains(err.Error(), "No such file or directory") { 133 dir, _ := path.Split(destination) 134 c.Run("mkdir -p " + dir) 135 return c.upload(destination, mode, content) 136 } else if strings.Contains(err.Error(), "handshake") || strings.Contains(err.Error(), "connection") { 137 138 time.Sleep(500 * time.Millisecond) 139 fmt.Printf("got error %v\n", err) 140 c.Reconnect() 141 return c.upload(destination, mode, content) 142 } 143 } 144 return err 145 } 146 147 func (c *service) getSession() (*ssh.Session, error) { 148 return c.client.NewSession() 149 } 150 151 //Upload uploads passed in content into remote destination 152 func (c *service) upload(destination string, mode os.FileMode, content []byte) (err error) { 153 dir, file := path.Split(destination) 154 if mode == 0 { 155 mode = 0644 156 } 157 waitGroup := &sync.WaitGroup{} 158 waitGroup.Add(1) 159 if strings.HasPrefix(file, "/") { 160 file = string(file[1:]) 161 } 162 session, err := c.getSession() 163 if err != nil { 164 return err 165 } 166 167 writer, err := session.StdinPipe() 168 if err != nil { 169 return errors.Wrap(err, "failed to acquire stdin") 170 } 171 defer writer.Close() 172 173 var transferError Errors = make(chan error, 1) 174 defer close(transferError) 175 var sessionError Errors = make(chan error, 1) 176 defer close(sessionError) 177 output, err := session.StdoutPipe() 178 if err != nil { 179 return errors.Wrap(err, "failed to acquire stdout") 180 } 181 go checkOutput(output, sessionError) 182 183 if mode >= 01000 { 184 mode = storage.DefaultFileMode 185 } 186 fileMode := string(fmt.Sprintf("C%04o", mode)[:5]) 187 createFileCmd := fmt.Sprintf("%v %d %s\n", fileMode, len(content), file) 188 go c.transferData(content, createFileCmd, writer, transferError, waitGroup) 189 scpCommand := "scp -qtr " + dir 190 err = session.Start(scpCommand) 191 if err != nil { 192 return err 193 } 194 waitGroup.Wait() 195 writerErr := writer.Close() 196 if err := sessionError.GetError(); err != nil { 197 return err 198 } 199 if err := transferError.GetError(); err != nil { 200 return err 201 } 202 if err = session.Wait(); err != nil { 203 if err := sessionError.GetError(); err != nil { 204 return err 205 } 206 return err 207 } 208 return writerErr 209 } 210 211 //Download download passed source file from remote host. 212 func (c *service) Download(source string) ([]byte, error) { 213 session, err := c.client.NewSession() 214 if err != nil { 215 return nil, err 216 } 217 defer session.Close() 218 return session.Output(fmt.Sprintf("cat %s", source)) 219 } 220 221 //Host returns client host 222 func (c *service) Host() string { 223 return c.host 224 } 225 226 //Close closes service 227 func (c *service) Close() error { 228 if len(c.forwarding) > 0 { 229 for _, forwarding := range c.forwarding { 230 _ = forwarding.Close() 231 } 232 } 233 return c.client.Close() 234 } 235 236 //Reconnect client 237 func (c *service) Reconnect() error { 238 return c.connect() 239 } 240 241 //OpenTunnel tunnels data between localAddress and remoteAddress on ssh connection 242 func (c *service) OpenTunnel(localAddress, remoteAddress string) error { 243 local, err := net.Listen("tcp", localAddress) 244 if err != nil { 245 return errors.Wrap(err, fmt.Sprintf("failed to listen on local: %v %v", localAddress)) 246 } 247 var forwarding = NewForwarding(c.client, remoteAddress, local) 248 if len(c.forwarding) == 0 { 249 c.forwarding = make([]*Tunnel, 0) 250 } 251 c.forwarding = append(c.forwarding, forwarding) 252 go forwarding.Handle() 253 return nil 254 } 255 256 func (c *service) connect() (err error) { 257 if c.client, err = ssh.Dial("tcp", c.host, c.config); err != nil { 258 return errors.Wrap(err, fmt.Sprintf("failed to dial %v: %s", c.host)) 259 } 260 return nil 261 } 262 263 //NewService create a new ssh service, it takes host port and authentication config 264 func NewService(host string, port int, authConfig *cred.Config) (Service, error) { 265 if authConfig == nil { 266 authConfig = &cred.Config{} 267 } 268 clientConfig, err := authConfig.ClientConfig() 269 if err != nil { 270 return nil, err 271 } 272 var result = &service{ 273 host: fmt.Sprintf("%s:%d", host, port), 274 config: clientConfig, 275 } 276 return result, result.connect() 277 }