github.com/pkg/sftp@v1.13.6/client_test.go (about)

     1  package sftp
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/kr/fs"
    11  )
    12  
    13  // assert that *Client implements fs.FileSystem
    14  var _ fs.FileSystem = new(Client)
    15  
    16  // assert that *File implements io.ReadWriteCloser
    17  var _ io.ReadWriteCloser = new(File)
    18  
    19  func TestNormaliseError(t *testing.T) {
    20  	var (
    21  		ok         = &StatusError{Code: sshFxOk}
    22  		eof        = &StatusError{Code: sshFxEOF}
    23  		fail       = &StatusError{Code: sshFxFailure}
    24  		noSuchFile = &StatusError{Code: sshFxNoSuchFile}
    25  		foo        = errors.New("foo")
    26  	)
    27  
    28  	var tests = []struct {
    29  		desc string
    30  		err  error
    31  		want error
    32  	}{
    33  		{
    34  			desc: "nil error",
    35  		},
    36  		{
    37  			desc: "not *StatusError",
    38  			err:  foo,
    39  			want: foo,
    40  		},
    41  		{
    42  			desc: "*StatusError with ssh_FX_EOF",
    43  			err:  eof,
    44  			want: io.EOF,
    45  		},
    46  		{
    47  			desc: "*StatusError with ssh_FX_NO_SUCH_FILE",
    48  			err:  noSuchFile,
    49  			want: os.ErrNotExist,
    50  		},
    51  		{
    52  			desc: "*StatusError with ssh_FX_OK",
    53  			err:  ok,
    54  		},
    55  		{
    56  			desc: "*StatusError with ssh_FX_FAILURE",
    57  			err:  fail,
    58  			want: fail,
    59  		},
    60  	}
    61  
    62  	for _, tt := range tests {
    63  		got := normaliseError(tt.err)
    64  		if got != tt.want {
    65  			t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n-  got: %#v",
    66  				tt.err, tt.desc, tt.want, got)
    67  		}
    68  	}
    69  }
    70  
    71  var flagsTests = []struct {
    72  	flags int
    73  	want  uint32
    74  }{
    75  	{os.O_RDONLY, sshFxfRead},
    76  	{os.O_WRONLY, sshFxfWrite},
    77  	{os.O_RDWR, sshFxfRead | sshFxfWrite},
    78  	{os.O_RDWR | os.O_CREATE | os.O_TRUNC, sshFxfRead | sshFxfWrite | sshFxfCreat | sshFxfTrunc},
    79  	{os.O_WRONLY | os.O_APPEND, sshFxfWrite | sshFxfAppend},
    80  }
    81  
    82  func TestFlags(t *testing.T) {
    83  	for i, tt := range flagsTests {
    84  		got := flags(tt.flags)
    85  		if got != tt.want {
    86  			t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
    87  		}
    88  	}
    89  }
    90  
    91  type packetSizeTest struct {
    92  	size  int
    93  	valid bool
    94  }
    95  
    96  var maxPacketCheckedTests = []packetSizeTest{
    97  	{size: 0, valid: false},
    98  	{size: 1, valid: true},
    99  	{size: 32768, valid: true},
   100  	{size: 32769, valid: false},
   101  }
   102  
   103  var maxPacketUncheckedTests = []packetSizeTest{
   104  	{size: 0, valid: false},
   105  	{size: 1, valid: true},
   106  	{size: 32768, valid: true},
   107  	{size: 32769, valid: true},
   108  }
   109  
   110  func TestMaxPacketChecked(t *testing.T) {
   111  	for _, tt := range maxPacketCheckedTests {
   112  		testMaxPacketOption(t, MaxPacketChecked(tt.size), tt)
   113  	}
   114  }
   115  
   116  func TestMaxPacketUnchecked(t *testing.T) {
   117  	for _, tt := range maxPacketUncheckedTests {
   118  		testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt)
   119  	}
   120  }
   121  
   122  func TestMaxPacket(t *testing.T) {
   123  	for _, tt := range maxPacketCheckedTests {
   124  		testMaxPacketOption(t, MaxPacket(tt.size), tt)
   125  	}
   126  }
   127  
   128  func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) {
   129  	var c Client
   130  
   131  	err := o(&c)
   132  	if (err == nil) != tt.valid {
   133  		t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil)
   134  	}
   135  	if c.maxPacket != tt.size && tt.valid {
   136  		t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket)
   137  	}
   138  }
   139  
   140  func testFstatOption(t *testing.T, o ClientOption, value bool) {
   141  	var c Client
   142  
   143  	err := o(&c)
   144  	if err == nil && c.useFstat != value {
   145  		t.Errorf("UseFStat(%v)\n- want: %v\n- got: %v", value, value, c.useFstat)
   146  	}
   147  }
   148  
   149  func TestUseFstatChecked(t *testing.T) {
   150  	testFstatOption(t, UseFstat(true), true)
   151  	testFstatOption(t, UseFstat(false), false)
   152  }
   153  
   154  type sink struct{}
   155  
   156  func (*sink) Close() error                { return nil }
   157  func (*sink) Write(p []byte) (int, error) { return len(p), nil }
   158  
   159  func TestClientZeroLengthPacket(t *testing.T) {
   160  	// Packet length zero (never valid). This used to crash the client.
   161  	packet := []byte{0, 0, 0, 0}
   162  
   163  	r := bytes.NewReader(packet)
   164  	c, err := NewClientPipe(r, &sink{})
   165  	if err == nil {
   166  		t.Error("expected an error, got nil")
   167  	}
   168  	if c != nil {
   169  		c.Close()
   170  	}
   171  }
   172  
   173  func TestClientShortPacket(t *testing.T) {
   174  	// init packet too short.
   175  	packet := []byte{0, 0, 0, 1, 2}
   176  
   177  	r := bytes.NewReader(packet)
   178  	_, err := NewClientPipe(r, &sink{})
   179  	if !errors.Is(err, errShortPacket) {
   180  		t.Fatalf("expected error: %v, got: %v", errShortPacket, err)
   181  	}
   182  }
   183  
   184  // Issue #418: panic in clientConn.recv when the sid is incomplete.
   185  func TestClientNoSid(t *testing.T) {
   186  	stream := new(bytes.Buffer)
   187  	sendPacket(stream, &sshFxVersionPacket{Version: sftpProtocolVersion})
   188  	// Next packet has the sid cut short after two bytes.
   189  	stream.Write([]byte{0, 0, 0, 10, 0, 0})
   190  
   191  	c, err := NewClientPipe(stream, &sink{})
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	_, err = c.Stat("anything")
   197  	if !errors.Is(err, ErrSSHFxConnectionLost) {
   198  		t.Fatal("expected ErrSSHFxConnectionLost, got", err)
   199  	}
   200  }