github.com/rudderlabs/rudder-go-kit@v0.30.0/sftp/sftp_test.go (about) 1 package sftp 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "net" 8 "os" 9 "path/filepath" 10 "strconv" 11 "testing" 12 "time" 13 14 "github.com/golang/mock/gomock" 15 "github.com/ory/dockertest/v3" 16 "github.com/stretchr/testify/require" 17 18 "github.com/rudderlabs/rudder-go-kit/sftp/mock_sftp" 19 "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/sshserver" 20 ) 21 22 type nopReadWriteCloser struct { 23 io.ReadWriter 24 } 25 26 func (nwc *nopReadWriteCloser) Close() error { 27 return nil 28 } 29 30 func TestSSHClientConfig(t *testing.T) { 31 // Read private key 32 privateKey, err := os.ReadFile("testdata/ssh/test_key") 33 require.NoError(t, err) 34 35 type testCase struct { 36 description string 37 config *SSHConfig 38 expectedError error 39 } 40 41 testCases := []testCase{ 42 { 43 description: "WithNilConfig", 44 config: nil, 45 expectedError: fmt.Errorf("config should not be nil"), 46 }, 47 { 48 description: "WithEmptyHostName", 49 config: &SSHConfig{ 50 HostName: "", 51 Port: 22, 52 User: "someUser", 53 AuthMethod: "passwordAuth", 54 Password: "somePassword", 55 }, 56 expectedError: fmt.Errorf("hostname should not be empty"), 57 }, 58 { 59 description: "WithEmptyPort", 60 config: &SSHConfig{ 61 HostName: "someHostName", 62 User: "someUser", 63 AuthMethod: "passwordAuth", 64 Password: "somePassword", 65 }, 66 expectedError: fmt.Errorf("port should not be empty"), 67 }, 68 { 69 description: "WithPassword", 70 config: &SSHConfig{ 71 HostName: "someHostName", 72 Port: 22, 73 User: "someUser", 74 AuthMethod: "passwordAuth", 75 Password: "somePassword", 76 }, 77 expectedError: nil, 78 }, 79 { 80 description: "WithPrivateKey", 81 config: &SSHConfig{ 82 HostName: "someHostName", 83 Port: 22, 84 User: "someUser", 85 AuthMethod: "keyAuth", 86 PrivateKey: string(privateKey), 87 }, 88 expectedError: nil, 89 }, 90 { 91 description: "WithUnsupportedAuthMethod", 92 config: &SSHConfig{ 93 HostName: "HostName", 94 Port: 22, 95 User: "someUser", 96 AuthMethod: "invalidAuth", 97 PrivateKey: "somePrivateKey", 98 }, 99 expectedError: fmt.Errorf("unsupported authentication method"), 100 }, 101 } 102 103 for _, tc := range testCases { 104 t.Run(tc.description, func(t *testing.T) { 105 sshConfig, err := sshClientConfig(tc.config) 106 if tc.expectedError != nil { 107 108 require.Error(t, tc.expectedError, err.Error()) 109 require.Nil(t, sshConfig) 110 } else { 111 require.NoError(t, err) 112 require.NotNil(t, sshConfig) 113 } 114 }) 115 } 116 } 117 118 func TestUpload(t *testing.T) { 119 ctrl := gomock.NewController(t) 120 defer ctrl.Finish() 121 122 // Create local directory within the temporary directory 123 localDir, err := os.MkdirTemp("", t.Name()) 124 require.NoError(t, err) 125 126 // Set up local path within the directory 127 localFilePath := filepath.Join(localDir, "test_file.json") 128 129 // Create local file and write data to it 130 localFile, err := os.Create(localFilePath) 131 require.NoError(t, err) 132 defer func() { _ = localFile.Close() }() 133 data := []byte(`{"foo": "bar"}`) 134 err = os.WriteFile(localFilePath, data, 0o644) 135 require.NoError(t, err) 136 137 remoteBuf := bytes.NewBuffer(nil) 138 139 mockSFTPClient := mock_sftp.NewMockClient(ctrl) 140 mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil) 141 mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).Return(nil) 142 143 fileManager := &fileManagerImpl{client: mockSFTPClient} 144 145 err = fileManager.Upload(localFilePath, "someRemotePath") 146 require.NoError(t, err) 147 require.Equal(t, data, remoteBuf.Bytes()) 148 } 149 150 func TestDownload(t *testing.T) { 151 ctrl := gomock.NewController(t) 152 defer ctrl.Finish() 153 154 // Create local directory within the temporary directory 155 localDir, err := os.MkdirTemp("", t.Name()) 156 require.NoError(t, err) 157 158 // Set up local file path within the directory 159 localFilePath := filepath.Join(localDir, "test_file.json") 160 161 data := []byte(`{"foo": "bar"}`) 162 remoteBuf := bytes.NewBuffer(data) 163 164 mockSFTPClient := mock_sftp.NewMockClient(ctrl) 165 mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil) 166 167 fileManager := &fileManagerImpl{client: mockSFTPClient} 168 169 err = fileManager.Download(filepath.Join("someRemoteDir", "test_file.json"), localDir) 170 require.NoError(t, err) 171 localFileContents, err := os.ReadFile(localFilePath) 172 require.NoError(t, err) 173 require.Equal(t, data, localFileContents) 174 } 175 176 func TestDelete(t *testing.T) { 177 ctrl := gomock.NewController(t) 178 defer ctrl.Finish() 179 180 remoteFilePath := "someRemoteFilePath" 181 mockSFTPClient := mock_sftp.NewMockClient(ctrl) 182 mockSFTPClient.EXPECT().Remove(remoteFilePath).Return(nil) 183 184 fileManager := &fileManagerImpl{client: mockSFTPClient} 185 186 err := fileManager.Delete(remoteFilePath) 187 require.NoError(t, err) 188 } 189 190 func TestSFTP(t *testing.T) { 191 pool, err := dockertest.NewPool("") 192 require.NoError(t, err) 193 194 // Let's setup the SSH server 195 publicKeyPath, err := filepath.Abs("testdata/ssh/test_key.pub") 196 require.NoError(t, err) 197 sshServer, err := sshserver.Setup(pool, t, 198 sshserver.WithPublicKeyPath(publicKeyPath), 199 sshserver.WithCredentials("linuxserver.io", ""), 200 ) 201 require.NoError(t, err) 202 sshServerHost := fmt.Sprintf("localhost:%d", sshServer.Port) 203 t.Logf("SSH server is listening on %s", sshServerHost) 204 205 // Read private key 206 privateKey, err := os.ReadFile("testdata/ssh/test_key") 207 require.NoError(t, err) 208 209 // Setup ssh client 210 hostname, portStr, err := net.SplitHostPort(sshServerHost) 211 require.NoError(t, err) 212 port, err := strconv.Atoi(portStr) 213 require.NoError(t, err) 214 sshClient, err := NewSSHClient(&SSHConfig{ 215 User: "linuxserver.io", 216 HostName: hostname, 217 Port: port, 218 AuthMethod: "keyAuth", 219 PrivateKey: string(privateKey), 220 DialTimeout: 10 * time.Second, 221 }) 222 require.NoError(t, err) 223 224 // Create session 225 session, err := sshClient.NewSession() 226 require.NoError(t, err) 227 defer func() { _ = session.Close() }() 228 229 remoteDir := filepath.Join("/tmp", "remote", "data") 230 err = session.Run(fmt.Sprintf("mkdir -p %s", remoteDir)) 231 require.NoError(t, err) 232 233 sftpManger, err := NewFileManager(sshClient) 234 require.NoError(t, err) 235 236 // Create local and remote directories within the temporary directory 237 baseDir := t.TempDir() 238 localDir := filepath.Join(baseDir, "local") 239 240 err = os.MkdirAll(localDir, 0o755) 241 require.NoError(t, err) 242 243 // Set up local and remote file paths within their respective directories 244 localFilePath := filepath.Join(localDir, "test_file.json") 245 remoteFilePath := filepath.Join(remoteDir, "test_file.json") 246 247 // Create local file and write data to it 248 localFile, err := os.Create(localFilePath) 249 require.NoError(t, err) 250 defer func() { _ = localFile.Close() }() 251 data := []byte(`{"foo": "bar"}`) 252 err = os.WriteFile(localFilePath, data, 0o644) 253 require.NoError(t, err) 254 255 err = sftpManger.Upload(localFilePath, remoteFilePath) 256 require.NoError(t, err) 257 258 err = sftpManger.Download(remoteFilePath, baseDir) 259 require.NoError(t, err) 260 261 localFileContents, err := os.ReadFile(localFilePath) 262 require.NoError(t, err) 263 downloadedFileContents, err := os.ReadFile(filepath.Join(baseDir, "test_file.json")) 264 require.NoError(t, err) 265 // Compare the contents of the local file and the downloaded file from the remote server 266 require.Equal(t, localFileContents, downloadedFileContents) 267 268 err = sftpManger.Delete(remoteFilePath) 269 require.NoError(t, err) 270 271 err = sftpManger.Download(remoteFilePath, baseDir) 272 require.Error(t, err, "cannot open remote file: file does not exist") 273 }