github.com/rudderlabs/rudder-go-kit@v0.30.0/sftp/sftp_test.go (about)

     1  package sftp
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"os"
     9  	"path/filepath"
    10  	"strconv"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/golang/mock/gomock"
    15  	"github.com/ory/dockertest/v3"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	"github.com/rudderlabs/rudder-go-kit/sftp/mock_sftp"
    19  	"github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/sshserver"
    20  )
    21  
    22  type nopReadWriteCloser struct {
    23  	io.ReadWriter
    24  }
    25  
    26  func (nwc *nopReadWriteCloser) Close() error {
    27  	return nil
    28  }
    29  
    30  func TestSSHClientConfig(t *testing.T) {
    31  	// Read private key
    32  	privateKey, err := os.ReadFile("testdata/ssh/test_key")
    33  	require.NoError(t, err)
    34  
    35  	type testCase struct {
    36  		description   string
    37  		config        *SSHConfig
    38  		expectedError error
    39  	}
    40  
    41  	testCases := []testCase{
    42  		{
    43  			description:   "WithNilConfig",
    44  			config:        nil,
    45  			expectedError: fmt.Errorf("config should not be nil"),
    46  		},
    47  		{
    48  			description: "WithEmptyHostName",
    49  			config: &SSHConfig{
    50  				HostName:   "",
    51  				Port:       22,
    52  				User:       "someUser",
    53  				AuthMethod: "passwordAuth",
    54  				Password:   "somePassword",
    55  			},
    56  			expectedError: fmt.Errorf("hostname should not be empty"),
    57  		},
    58  		{
    59  			description: "WithEmptyPort",
    60  			config: &SSHConfig{
    61  				HostName:   "someHostName",
    62  				User:       "someUser",
    63  				AuthMethod: "passwordAuth",
    64  				Password:   "somePassword",
    65  			},
    66  			expectedError: fmt.Errorf("port should not be empty"),
    67  		},
    68  		{
    69  			description: "WithPassword",
    70  			config: &SSHConfig{
    71  				HostName:   "someHostName",
    72  				Port:       22,
    73  				User:       "someUser",
    74  				AuthMethod: "passwordAuth",
    75  				Password:   "somePassword",
    76  			},
    77  			expectedError: nil,
    78  		},
    79  		{
    80  			description: "WithPrivateKey",
    81  			config: &SSHConfig{
    82  				HostName:   "someHostName",
    83  				Port:       22,
    84  				User:       "someUser",
    85  				AuthMethod: "keyAuth",
    86  				PrivateKey: string(privateKey),
    87  			},
    88  			expectedError: nil,
    89  		},
    90  		{
    91  			description: "WithUnsupportedAuthMethod",
    92  			config: &SSHConfig{
    93  				HostName:   "HostName",
    94  				Port:       22,
    95  				User:       "someUser",
    96  				AuthMethod: "invalidAuth",
    97  				PrivateKey: "somePrivateKey",
    98  			},
    99  			expectedError: fmt.Errorf("unsupported authentication method"),
   100  		},
   101  	}
   102  
   103  	for _, tc := range testCases {
   104  		t.Run(tc.description, func(t *testing.T) {
   105  			sshConfig, err := sshClientConfig(tc.config)
   106  			if tc.expectedError != nil {
   107  
   108  				require.Error(t, tc.expectedError, err.Error())
   109  				require.Nil(t, sshConfig)
   110  			} else {
   111  				require.NoError(t, err)
   112  				require.NotNil(t, sshConfig)
   113  			}
   114  		})
   115  	}
   116  }
   117  
   118  func TestUpload(t *testing.T) {
   119  	ctrl := gomock.NewController(t)
   120  	defer ctrl.Finish()
   121  
   122  	// Create local directory within the temporary directory
   123  	localDir, err := os.MkdirTemp("", t.Name())
   124  	require.NoError(t, err)
   125  
   126  	// Set up local path within the directory
   127  	localFilePath := filepath.Join(localDir, "test_file.json")
   128  
   129  	// Create local file and write data to it
   130  	localFile, err := os.Create(localFilePath)
   131  	require.NoError(t, err)
   132  	defer func() { _ = localFile.Close() }()
   133  	data := []byte(`{"foo": "bar"}`)
   134  	err = os.WriteFile(localFilePath, data, 0o644)
   135  	require.NoError(t, err)
   136  
   137  	remoteBuf := bytes.NewBuffer(nil)
   138  
   139  	mockSFTPClient := mock_sftp.NewMockClient(ctrl)
   140  	mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)
   141  	mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).Return(nil)
   142  
   143  	fileManager := &fileManagerImpl{client: mockSFTPClient}
   144  
   145  	err = fileManager.Upload(localFilePath, "someRemotePath")
   146  	require.NoError(t, err)
   147  	require.Equal(t, data, remoteBuf.Bytes())
   148  }
   149  
   150  func TestDownload(t *testing.T) {
   151  	ctrl := gomock.NewController(t)
   152  	defer ctrl.Finish()
   153  
   154  	// Create local directory within the temporary directory
   155  	localDir, err := os.MkdirTemp("", t.Name())
   156  	require.NoError(t, err)
   157  
   158  	// Set up local file path within the directory
   159  	localFilePath := filepath.Join(localDir, "test_file.json")
   160  
   161  	data := []byte(`{"foo": "bar"}`)
   162  	remoteBuf := bytes.NewBuffer(data)
   163  
   164  	mockSFTPClient := mock_sftp.NewMockClient(ctrl)
   165  	mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)
   166  
   167  	fileManager := &fileManagerImpl{client: mockSFTPClient}
   168  
   169  	err = fileManager.Download(filepath.Join("someRemoteDir", "test_file.json"), localDir)
   170  	require.NoError(t, err)
   171  	localFileContents, err := os.ReadFile(localFilePath)
   172  	require.NoError(t, err)
   173  	require.Equal(t, data, localFileContents)
   174  }
   175  
   176  func TestDelete(t *testing.T) {
   177  	ctrl := gomock.NewController(t)
   178  	defer ctrl.Finish()
   179  
   180  	remoteFilePath := "someRemoteFilePath"
   181  	mockSFTPClient := mock_sftp.NewMockClient(ctrl)
   182  	mockSFTPClient.EXPECT().Remove(remoteFilePath).Return(nil)
   183  
   184  	fileManager := &fileManagerImpl{client: mockSFTPClient}
   185  
   186  	err := fileManager.Delete(remoteFilePath)
   187  	require.NoError(t, err)
   188  }
   189  
   190  func TestSFTP(t *testing.T) {
   191  	pool, err := dockertest.NewPool("")
   192  	require.NoError(t, err)
   193  
   194  	// Let's setup the SSH server
   195  	publicKeyPath, err := filepath.Abs("testdata/ssh/test_key.pub")
   196  	require.NoError(t, err)
   197  	sshServer, err := sshserver.Setup(pool, t,
   198  		sshserver.WithPublicKeyPath(publicKeyPath),
   199  		sshserver.WithCredentials("linuxserver.io", ""),
   200  	)
   201  	require.NoError(t, err)
   202  	sshServerHost := fmt.Sprintf("localhost:%d", sshServer.Port)
   203  	t.Logf("SSH server is listening on %s", sshServerHost)
   204  
   205  	// Read private key
   206  	privateKey, err := os.ReadFile("testdata/ssh/test_key")
   207  	require.NoError(t, err)
   208  
   209  	// Setup ssh client
   210  	hostname, portStr, err := net.SplitHostPort(sshServerHost)
   211  	require.NoError(t, err)
   212  	port, err := strconv.Atoi(portStr)
   213  	require.NoError(t, err)
   214  	sshClient, err := NewSSHClient(&SSHConfig{
   215  		User:        "linuxserver.io",
   216  		HostName:    hostname,
   217  		Port:        port,
   218  		AuthMethod:  "keyAuth",
   219  		PrivateKey:  string(privateKey),
   220  		DialTimeout: 10 * time.Second,
   221  	})
   222  	require.NoError(t, err)
   223  
   224  	// Create session
   225  	session, err := sshClient.NewSession()
   226  	require.NoError(t, err)
   227  	defer func() { _ = session.Close() }()
   228  
   229  	remoteDir := filepath.Join("/tmp", "remote", "data")
   230  	err = session.Run(fmt.Sprintf("mkdir -p %s", remoteDir))
   231  	require.NoError(t, err)
   232  
   233  	sftpManger, err := NewFileManager(sshClient)
   234  	require.NoError(t, err)
   235  
   236  	// Create local and remote directories within the temporary directory
   237  	baseDir := t.TempDir()
   238  	localDir := filepath.Join(baseDir, "local")
   239  
   240  	err = os.MkdirAll(localDir, 0o755)
   241  	require.NoError(t, err)
   242  
   243  	// Set up local and remote file paths within their respective directories
   244  	localFilePath := filepath.Join(localDir, "test_file.json")
   245  	remoteFilePath := filepath.Join(remoteDir, "test_file.json")
   246  
   247  	// Create local file and write data to it
   248  	localFile, err := os.Create(localFilePath)
   249  	require.NoError(t, err)
   250  	defer func() { _ = localFile.Close() }()
   251  	data := []byte(`{"foo": "bar"}`)
   252  	err = os.WriteFile(localFilePath, data, 0o644)
   253  	require.NoError(t, err)
   254  
   255  	err = sftpManger.Upload(localFilePath, remoteFilePath)
   256  	require.NoError(t, err)
   257  
   258  	err = sftpManger.Download(remoteFilePath, baseDir)
   259  	require.NoError(t, err)
   260  
   261  	localFileContents, err := os.ReadFile(localFilePath)
   262  	require.NoError(t, err)
   263  	downloadedFileContents, err := os.ReadFile(filepath.Join(baseDir, "test_file.json"))
   264  	require.NoError(t, err)
   265  	// Compare the contents of the local file and the downloaded file from the remote server
   266  	require.Equal(t, localFileContents, downloadedFileContents)
   267  
   268  	err = sftpManger.Delete(remoteFilePath)
   269  	require.NoError(t, err)
   270  
   271  	err = sftpManger.Download(remoteFilePath, baseDir)
   272  	require.Error(t, err, "cannot open remote file: file does not exist")
   273  }