github.com/rudderlabs/rudder-go-kit@v0.30.0/sftp/client.go (about) 1 //go:generate mockgen -destination=mock_sftp/mock_sftp_client.go -package mock_sftp github.com/rudderlabs/rudder-go-kit/sftp Client 2 package sftp 3 4 import ( 5 "errors" 6 "fmt" 7 "io" 8 "time" 9 10 "github.com/pkg/sftp" 11 "golang.org/x/crypto/ssh" 12 ) 13 14 // SSHConfig represents the configuration for SSH connection 15 type SSHConfig struct { 16 HostName string 17 Port int 18 User string 19 AuthMethod string 20 PrivateKey string 21 Password string // Password for password-based authentication 22 DialTimeout time.Duration 23 } 24 25 // sshClientConfig constructs an SSH client configuration based on the provided SSHConfig. 26 func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) { 27 if config == nil { 28 return nil, errors.New("config should not be nil") 29 } 30 31 if config.HostName == "" { 32 return nil, errors.New("hostname should not be empty") 33 } 34 35 if config.Port == 0 { 36 return nil, errors.New("port should not be empty") 37 } 38 39 if config.User == "" { 40 return nil, errors.New("user should not be empty") 41 } 42 43 var authMethods ssh.AuthMethod 44 45 switch config.AuthMethod { 46 case PasswordAuth: 47 authMethods = ssh.Password(config.Password) 48 case KeyAuth: 49 privateKey, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)) 50 if err != nil { 51 return nil, fmt.Errorf("cannot parse private key: %w", err) 52 } 53 authMethods = ssh.PublicKeys(privateKey) 54 default: 55 return nil, errors.New("unsupported authentication method") 56 } 57 58 sshConfig := &ssh.ClientConfig{ 59 User: config.User, 60 Auth: []ssh.AuthMethod{authMethods}, 61 Timeout: config.DialTimeout, 62 HostKeyCallback: ssh.InsecureIgnoreHostKey(), 63 } 64 65 return sshConfig, nil 66 } 67 68 // NewSSHClient establishes an SSH connection and returns an SSH client 69 func NewSSHClient(config *SSHConfig) (*ssh.Client, error) { 70 sshConfig, err := sshClientConfig(config) 71 if err != nil { 72 return nil, fmt.Errorf("cannot configure SSH client: %w", err) 73 } 74 75 sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", config.HostName, config.Port), sshConfig) 76 if err != nil { 77 return nil, fmt.Errorf("cannot dial SSH host %q:%d: %w", config.HostName, config.Port, err) 78 } 79 return sshClient, nil 80 } 81 82 type clientImpl struct { 83 client *sftp.Client 84 } 85 86 type Client interface { 87 OpenFile(path string, f int) (io.ReadWriteCloser, error) 88 Remove(path string) error 89 MkdirAll(path string) error 90 } 91 92 // newSFTPClient creates an SFTP client with existing SSH client 93 func newSFTPClient(client *ssh.Client) (Client, error) { 94 sftpClient, err := sftp.NewClient(client) 95 if err != nil { 96 return nil, fmt.Errorf("cannot create SFTP client: %w", err) 97 } 98 return &clientImpl{ 99 client: sftpClient, 100 }, nil 101 } 102 103 func (c *clientImpl) OpenFile(path string, f int) (io.ReadWriteCloser, error) { 104 return c.client.OpenFile(path, f) 105 } 106 107 func (c *clientImpl) Remove(path string) error { 108 return c.client.Remove(path) 109 } 110 111 func (c *clientImpl) MkdirAll(path string) error { 112 return c.client.MkdirAll(path) 113 }