github.com/pkg/sftp@v1.13.6/client_test.go (about) 1 package sftp 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "os" 8 "testing" 9 10 "github.com/kr/fs" 11 ) 12 13 // assert that *Client implements fs.FileSystem 14 var _ fs.FileSystem = new(Client) 15 16 // assert that *File implements io.ReadWriteCloser 17 var _ io.ReadWriteCloser = new(File) 18 19 func TestNormaliseError(t *testing.T) { 20 var ( 21 ok = &StatusError{Code: sshFxOk} 22 eof = &StatusError{Code: sshFxEOF} 23 fail = &StatusError{Code: sshFxFailure} 24 noSuchFile = &StatusError{Code: sshFxNoSuchFile} 25 foo = errors.New("foo") 26 ) 27 28 var tests = []struct { 29 desc string 30 err error 31 want error 32 }{ 33 { 34 desc: "nil error", 35 }, 36 { 37 desc: "not *StatusError", 38 err: foo, 39 want: foo, 40 }, 41 { 42 desc: "*StatusError with ssh_FX_EOF", 43 err: eof, 44 want: io.EOF, 45 }, 46 { 47 desc: "*StatusError with ssh_FX_NO_SUCH_FILE", 48 err: noSuchFile, 49 want: os.ErrNotExist, 50 }, 51 { 52 desc: "*StatusError with ssh_FX_OK", 53 err: ok, 54 }, 55 { 56 desc: "*StatusError with ssh_FX_FAILURE", 57 err: fail, 58 want: fail, 59 }, 60 } 61 62 for _, tt := range tests { 63 got := normaliseError(tt.err) 64 if got != tt.want { 65 t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v", 66 tt.err, tt.desc, tt.want, got) 67 } 68 } 69 } 70 71 var flagsTests = []struct { 72 flags int 73 want uint32 74 }{ 75 {os.O_RDONLY, sshFxfRead}, 76 {os.O_WRONLY, sshFxfWrite}, 77 {os.O_RDWR, sshFxfRead | sshFxfWrite}, 78 {os.O_RDWR | os.O_CREATE | os.O_TRUNC, sshFxfRead | sshFxfWrite | sshFxfCreat | sshFxfTrunc}, 79 {os.O_WRONLY | os.O_APPEND, sshFxfWrite | sshFxfAppend}, 80 } 81 82 func TestFlags(t *testing.T) { 83 for i, tt := range flagsTests { 84 got := flags(tt.flags) 85 if got != tt.want { 86 t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) 87 } 88 } 89 } 90 91 type packetSizeTest struct { 92 size int 93 valid bool 94 } 95 96 var maxPacketCheckedTests = []packetSizeTest{ 97 {size: 0, valid: false}, 98 {size: 1, valid: true}, 99 {size: 32768, valid: true}, 100 {size: 32769, valid: false}, 101 } 102 103 var maxPacketUncheckedTests = []packetSizeTest{ 104 {size: 0, valid: false}, 105 {size: 1, valid: true}, 106 {size: 32768, valid: true}, 107 {size: 32769, valid: true}, 108 } 109 110 func TestMaxPacketChecked(t *testing.T) { 111 for _, tt := range maxPacketCheckedTests { 112 testMaxPacketOption(t, MaxPacketChecked(tt.size), tt) 113 } 114 } 115 116 func TestMaxPacketUnchecked(t *testing.T) { 117 for _, tt := range maxPacketUncheckedTests { 118 testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt) 119 } 120 } 121 122 func TestMaxPacket(t *testing.T) { 123 for _, tt := range maxPacketCheckedTests { 124 testMaxPacketOption(t, MaxPacket(tt.size), tt) 125 } 126 } 127 128 func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) { 129 var c Client 130 131 err := o(&c) 132 if (err == nil) != tt.valid { 133 t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil) 134 } 135 if c.maxPacket != tt.size && tt.valid { 136 t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket) 137 } 138 } 139 140 func testFstatOption(t *testing.T, o ClientOption, value bool) { 141 var c Client 142 143 err := o(&c) 144 if err == nil && c.useFstat != value { 145 t.Errorf("UseFStat(%v)\n- want: %v\n- got: %v", value, value, c.useFstat) 146 } 147 } 148 149 func TestUseFstatChecked(t *testing.T) { 150 testFstatOption(t, UseFstat(true), true) 151 testFstatOption(t, UseFstat(false), false) 152 } 153 154 type sink struct{} 155 156 func (*sink) Close() error { return nil } 157 func (*sink) Write(p []byte) (int, error) { return len(p), nil } 158 159 func TestClientZeroLengthPacket(t *testing.T) { 160 // Packet length zero (never valid). This used to crash the client. 161 packet := []byte{0, 0, 0, 0} 162 163 r := bytes.NewReader(packet) 164 c, err := NewClientPipe(r, &sink{}) 165 if err == nil { 166 t.Error("expected an error, got nil") 167 } 168 if c != nil { 169 c.Close() 170 } 171 } 172 173 func TestClientShortPacket(t *testing.T) { 174 // init packet too short. 175 packet := []byte{0, 0, 0, 1, 2} 176 177 r := bytes.NewReader(packet) 178 _, err := NewClientPipe(r, &sink{}) 179 if !errors.Is(err, errShortPacket) { 180 t.Fatalf("expected error: %v, got: %v", errShortPacket, err) 181 } 182 } 183 184 // Issue #418: panic in clientConn.recv when the sid is incomplete. 185 func TestClientNoSid(t *testing.T) { 186 stream := new(bytes.Buffer) 187 sendPacket(stream, &sshFxVersionPacket{Version: sftpProtocolVersion}) 188 // Next packet has the sid cut short after two bytes. 189 stream.Write([]byte{0, 0, 0, 10, 0, 0}) 190 191 c, err := NewClientPipe(stream, &sink{}) 192 if err != nil { 193 t.Fatal(err) 194 } 195 196 _, err = c.Stat("anything") 197 if !errors.Is(err, ErrSSHFxConnectionLost) { 198 t.Fatal("expected ErrSSHFxConnectionLost, got", err) 199 } 200 }