github.com/rudderlabs/rudder-go-kit@v0.30.0/sftp/client.go (about)

     1  //go:generate mockgen -destination=mock_sftp/mock_sftp_client.go -package mock_sftp github.com/rudderlabs/rudder-go-kit/sftp Client
     2  package sftp
     3  
import (
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)
    13  
    14  // SSHConfig represents the configuration for SSH connection
    15  type SSHConfig struct {
    16  	HostName    string
    17  	Port        int
    18  	User        string
    19  	AuthMethod  string
    20  	PrivateKey  string
    21  	Password    string // Password for password-based authentication
    22  	DialTimeout time.Duration
    23  }
    24  
    25  // sshClientConfig constructs an SSH client configuration based on the provided SSHConfig.
    26  func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) {
    27  	if config == nil {
    28  		return nil, errors.New("config should not be nil")
    29  	}
    30  
    31  	if config.HostName == "" {
    32  		return nil, errors.New("hostname should not be empty")
    33  	}
    34  
    35  	if config.Port == 0 {
    36  		return nil, errors.New("port should not be empty")
    37  	}
    38  
    39  	if config.User == "" {
    40  		return nil, errors.New("user should not be empty")
    41  	}
    42  
    43  	var authMethods ssh.AuthMethod
    44  
    45  	switch config.AuthMethod {
    46  	case PasswordAuth:
    47  		authMethods = ssh.Password(config.Password)
    48  	case KeyAuth:
    49  		privateKey, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
    50  		if err != nil {
    51  			return nil, fmt.Errorf("cannot parse private key: %w", err)
    52  		}
    53  		authMethods = ssh.PublicKeys(privateKey)
    54  	default:
    55  		return nil, errors.New("unsupported authentication method")
    56  	}
    57  
    58  	sshConfig := &ssh.ClientConfig{
    59  		User:            config.User,
    60  		Auth:            []ssh.AuthMethod{authMethods},
    61  		Timeout:         config.DialTimeout,
    62  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    63  	}
    64  
    65  	return sshConfig, nil
    66  }
    67  
    68  // NewSSHClient establishes an SSH connection and returns an SSH client
    69  func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
    70  	sshConfig, err := sshClientConfig(config)
    71  	if err != nil {
    72  		return nil, fmt.Errorf("cannot configure SSH client: %w", err)
    73  	}
    74  
    75  	sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", config.HostName, config.Port), sshConfig)
    76  	if err != nil {
    77  		return nil, fmt.Errorf("cannot dial SSH host %q:%d: %w", config.HostName, config.Port, err)
    78  	}
    79  	return sshClient, nil
    80  }
    81  
    82  type clientImpl struct {
    83  	client *sftp.Client
    84  }
    85  
    86  type Client interface {
    87  	OpenFile(path string, f int) (io.ReadWriteCloser, error)
    88  	Remove(path string) error
    89  	MkdirAll(path string) error
    90  }
    91  
    92  // newSFTPClient creates an SFTP client with existing SSH client
    93  func newSFTPClient(client *ssh.Client) (Client, error) {
    94  	sftpClient, err := sftp.NewClient(client)
    95  	if err != nil {
    96  		return nil, fmt.Errorf("cannot create SFTP client: %w", err)
    97  	}
    98  	return &clientImpl{
    99  		client: sftpClient,
   100  	}, nil
   101  }
   102  
   103  func (c *clientImpl) OpenFile(path string, f int) (io.ReadWriteCloser, error) {
   104  	return c.client.OpenFile(path, f)
   105  }
   106  
   107  func (c *clientImpl) Remove(path string) error {
   108  	return c.client.Remove(path)
   109  }
   110  
   111  func (c *clientImpl) MkdirAll(path string) error {
   112  	return c.client.MkdirAll(path)
   113  }