github.com/pkg/sftp@v1.13.6/packet_test.go (about)

     1  package sftp
     2  
     3  import (
     4  	"bytes"
     5  	"encoding"
     6  	"errors"
     7  	"io/ioutil"
     8  	"os"
     9  	"reflect"
    10  	"testing"
    11  )
    12  
    13  func TestMarshalUint32(t *testing.T) {
    14  	var tests = []struct {
    15  		v    uint32
    16  		want []byte
    17  	}{
    18  		{0, []byte{0, 0, 0, 0}},
    19  		{42, []byte{0, 0, 0, 42}},
    20  		{42 << 8, []byte{0, 0, 42, 0}},
    21  		{42 << 16, []byte{0, 42, 0, 0}},
    22  		{42 << 24, []byte{42, 0, 0, 0}},
    23  		{^uint32(0), []byte{255, 255, 255, 255}},
    24  	}
    25  
    26  	for _, tt := range tests {
    27  		got := marshalUint32(nil, tt.v)
    28  		if !bytes.Equal(tt.want, got) {
    29  			t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want)
    30  		}
    31  	}
    32  }
    33  
    34  func TestMarshalUint64(t *testing.T) {
    35  	var tests = []struct {
    36  		v    uint64
    37  		want []byte
    38  	}{
    39  		{0, []byte{0, 0, 0, 0, 0, 0, 0, 0}},
    40  		{42, []byte{0, 0, 0, 0, 0, 0, 0, 42}},
    41  		{42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}},
    42  		{42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}},
    43  		{42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}},
    44  		{42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}},
    45  		{42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}},
    46  		{42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}},
    47  		{42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}},
    48  		{^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}},
    49  	}
    50  
    51  	for _, tt := range tests {
    52  		got := marshalUint64(nil, tt.v)
    53  		if !bytes.Equal(tt.want, got) {
    54  			t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want)
    55  		}
    56  	}
    57  }
    58  
    59  func TestMarshalString(t *testing.T) {
    60  	var tests = []struct {
    61  		v    string
    62  		want []byte
    63  	}{
    64  		{"", []byte{0, 0, 0, 0}},
    65  		{"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}},
    66  		{"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}},
    67  		{"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}},
    68  		{"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}},
    69  		{"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}},
    70  		{"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}},
    71  	}
    72  
    73  	for _, tt := range tests {
    74  		got := marshalString(nil, tt.v)
    75  		if !bytes.Equal(tt.want, got) {
    76  			t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want)
    77  		}
    78  	}
    79  }
    80  
    81  func TestMarshal(t *testing.T) {
    82  	type Struct struct {
    83  		X, Y, Z uint32
    84  	}
    85  
    86  	var tests = []struct {
    87  		v    interface{}
    88  		want []byte
    89  	}{
    90  		{uint8(42), []byte{42}},
    91  		{uint32(42 << 8), []byte{0, 0, 42, 0}},
    92  		{uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}},
    93  		{"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}},
    94  		{Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
    95  		{[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
    96  	}
    97  
    98  	for _, tt := range tests {
    99  		got := marshal(nil, tt.v)
   100  		if !bytes.Equal(tt.want, got) {
   101  			t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want)
   102  		}
   103  	}
   104  }
   105  
   106  func TestUnmarshalUint32(t *testing.T) {
   107  	testBuffer := []byte{
   108  		0, 0, 0, 0,
   109  		0, 0, 0, 42,
   110  		0, 0, 42, 0,
   111  		0, 42, 0, 0,
   112  		42, 0, 0, 0,
   113  		255, 0, 0, 254,
   114  	}
   115  
   116  	var wants = []uint32{
   117  		0,
   118  		42,
   119  		42 << 8,
   120  		42 << 16,
   121  		42 << 24,
   122  		255<<24 | 254,
   123  	}
   124  
   125  	var i int
   126  	for len(testBuffer) > 0 {
   127  		got, rest := unmarshalUint32(testBuffer)
   128  
   129  		if got != wants[i] {
   130  			t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i])
   131  		}
   132  
   133  		i++
   134  		testBuffer = rest
   135  	}
   136  }
   137  
   138  func TestUnmarshalUint64(t *testing.T) {
   139  	testBuffer := []byte{
   140  		0, 0, 0, 0, 0, 0, 0, 0,
   141  		0, 0, 0, 0, 0, 0, 0, 42,
   142  		0, 0, 0, 0, 0, 0, 42, 0,
   143  		0, 0, 0, 0, 0, 42, 0, 0,
   144  		0, 0, 0, 0, 42, 0, 0, 0,
   145  		0, 0, 0, 42, 0, 0, 0, 0,
   146  		0, 0, 42, 0, 0, 0, 0, 0,
   147  		0, 42, 0, 0, 0, 0, 0, 0,
   148  		42, 0, 0, 0, 0, 0, 0, 0,
   149  		255, 0, 0, 0, 0, 0, 0, 254,
   150  	}
   151  
   152  	var wants = []uint64{
   153  		0,
   154  		42,
   155  		42 << 8,
   156  		42 << 16,
   157  		42 << 24,
   158  		42 << 32,
   159  		42 << 40,
   160  		42 << 48,
   161  		42 << 56,
   162  		255<<56 | 254,
   163  	}
   164  
   165  	var i int
   166  	for len(testBuffer) > 0 {
   167  		got, rest := unmarshalUint64(testBuffer)
   168  
   169  		if got != wants[i] {
   170  			t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i])
   171  		}
   172  
   173  		i++
   174  		testBuffer = rest
   175  	}
   176  }
   177  
   178  var unmarshalStringTests = []struct {
   179  	b    []byte
   180  	want string
   181  	rest []byte
   182  }{
   183  	{marshalString(nil, ""), "", nil},
   184  	{marshalString(nil, "blah"), "blah", nil},
   185  }
   186  
   187  func TestUnmarshalString(t *testing.T) {
   188  	testBuffer := []byte{
   189  		0, 0, 0, 0,
   190  		0, 0, 0, 1, '/',
   191  		0, 0, 0, 4, '/', 'f', 'o', 'o',
   192  		0, 0, 0, 4, 0, 'b', 'a', 'r',
   193  		0, 0, 0, 4, 'b', 0, 'a', 'r',
   194  		0, 0, 0, 4, 'b', 'a', 0, 'r',
   195  		0, 0, 0, 4, 'b', 'a', 'r', 0,
   196  	}
   197  
   198  	var wants = []string{
   199  		"",
   200  		"/",
   201  		"/foo",
   202  		"\x00bar",
   203  		"b\x00ar",
   204  		"ba\x00r",
   205  		"bar\x00",
   206  	}
   207  
   208  	var i int
   209  	for len(testBuffer) > 0 {
   210  		got, rest := unmarshalString(testBuffer)
   211  
   212  		if got != wants[i] {
   213  			t.Fatalf("unmarshalUint64(%#v...) = %q, want %q", testBuffer[:4], got, wants[i])
   214  		}
   215  
   216  		i++
   217  		testBuffer = rest
   218  	}
   219  }
   220  
   221  func TestUnmarshalAttrs(t *testing.T) {
   222  	var tests = []struct {
   223  		b    []byte
   224  		want *FileStat
   225  	}{
   226  		{
   227  			b:    []byte{0x00, 0x00, 0x00, 0x00},
   228  			want: &FileStat{},
   229  		},
   230  		{
   231  			b: []byte{
   232  				0x00, 0x00, 0x00, byte(sshFileXferAttrSize),
   233  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
   234  			},
   235  			want: &FileStat{
   236  				Size: 20,
   237  			},
   238  		},
   239  		{
   240  			b: []byte{
   241  				0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions),
   242  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
   243  				0x00, 0x00, 0x01, 0xA4,
   244  			},
   245  			want: &FileStat{
   246  				Size: 20,
   247  				Mode: 0644,
   248  			},
   249  		},
   250  		{
   251  			b: []byte{
   252  				0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID),
   253  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
   254  				0x00, 0x00, 0x03, 0xE8,
   255  				0x00, 0x00, 0x03, 0xE9,
   256  				0x00, 0x00, 0x01, 0xA4,
   257  			},
   258  			want: &FileStat{
   259  				Size: 20,
   260  				Mode: 0644,
   261  				UID:  1000,
   262  				GID:  1001,
   263  			},
   264  		},
   265  		{
   266  			b: []byte{
   267  				0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime),
   268  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
   269  				0x00, 0x00, 0x03, 0xE8,
   270  				0x00, 0x00, 0x03, 0xE9,
   271  				0x00, 0x00, 0x01, 0xA4,
   272  				0x00, 0x00, 0x00, 42,
   273  				0x00, 0x00, 0x00, 13,
   274  			},
   275  			want: &FileStat{
   276  				Size:  20,
   277  				Mode:  0644,
   278  				UID:   1000,
   279  				GID:   1001,
   280  				Atime: 42,
   281  				Mtime: 13,
   282  			},
   283  		},
   284  	}
   285  
   286  	for _, tt := range tests {
   287  		got, _ := unmarshalAttrs(tt.b)
   288  		if !reflect.DeepEqual(got, tt.want) {
   289  			t.Errorf("unmarshalAttrs(% X):\n-  got: %#v\n- want: %#v", tt.b, got, tt.want)
   290  		}
   291  	}
   292  }
   293  
   294  func TestUnmarshalStatus(t *testing.T) {
   295  	var requestID uint32 = 1
   296  
   297  	id := marshalUint32(nil, requestID)
   298  	idCode := marshalUint32(id, sshFxFailure)
   299  	idCodeMsg := marshalString(idCode, "err msg")
   300  	idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
   301  
   302  	var tests = []struct {
   303  		desc   string
   304  		reqID  uint32
   305  		status []byte
   306  		want   error
   307  	}{
   308  		{
   309  			desc:   "well-formed status",
   310  			status: idCodeMsgLang,
   311  			want: &StatusError{
   312  				Code: sshFxFailure,
   313  				msg:  "err msg",
   314  				lang: "lang tag",
   315  			},
   316  		},
   317  		{
   318  			desc:   "missing language tag",
   319  			status: idCodeMsg,
   320  			want: &StatusError{
   321  				Code: sshFxFailure,
   322  				msg:  "err msg",
   323  			},
   324  		},
   325  		{
   326  			desc:   "missing error message and language tag",
   327  			status: idCode,
   328  			want: &StatusError{
   329  				Code: sshFxFailure,
   330  			},
   331  		},
   332  	}
   333  
   334  	for _, tt := range tests {
   335  		t.Run(tt.desc, func(t *testing.T) {
   336  			got := unmarshalStatus(1, tt.status)
   337  			if !reflect.DeepEqual(got, tt.want) {
   338  				t.Errorf("unmarshalStatus(1, % X):\n-  got: %#v\n- want: %#v", tt.status, got, tt.want)
   339  			}
   340  		})
   341  	}
   342  
   343  	got := unmarshalStatus(2, idCodeMsgLang)
   344  	want := &unexpectedIDErr{
   345  		want: 2,
   346  		got:  1,
   347  	}
   348  	if !reflect.DeepEqual(got, want) {
   349  		t.Errorf("unmarshalStatus(2, % X):\n-  got: %#v\n- want: %#v", idCodeMsgLang, got, want)
   350  	}
   351  }
   352  
   353  func TestSendPacket(t *testing.T) {
   354  	var tests = []struct {
   355  		packet encoding.BinaryMarshaler
   356  		want   []byte
   357  	}{
   358  		{
   359  			packet: &sshFxInitPacket{
   360  				Version: 3,
   361  				Extensions: []extensionPair{
   362  					{"posix-rename@openssh.com", "1"},
   363  				},
   364  			},
   365  			want: []byte{
   366  				0x0, 0x0, 0x0, 0x26,
   367  				0x1,
   368  				0x0, 0x0, 0x0, 0x3,
   369  				0x0, 0x0, 0x0, 0x18,
   370  				'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
   371  				0x0, 0x0, 0x0, 0x1,
   372  				'1',
   373  			},
   374  		},
   375  		{
   376  			packet: &sshFxpOpenPacket{
   377  				ID:     1,
   378  				Path:   "/foo",
   379  				Pflags: flags(os.O_RDONLY),
   380  			},
   381  			want: []byte{
   382  				0x0, 0x0, 0x0, 0x15,
   383  				0x3,
   384  				0x0, 0x0, 0x0, 0x1,
   385  				0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
   386  				0x0, 0x0, 0x0, 0x1,
   387  				0x0, 0x0, 0x0, 0x0,
   388  			},
   389  		},
   390  		{
   391  			packet: &sshFxpWritePacket{
   392  				ID:     124,
   393  				Handle: "foo",
   394  				Offset: 13,
   395  				Length: uint32(len("bar")),
   396  				Data:   []byte("bar"),
   397  			},
   398  			want: []byte{
   399  				0x0, 0x0, 0x0, 0x1b,
   400  				0x6,
   401  				0x0, 0x0, 0x0, 0x7c,
   402  				0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o',
   403  				0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd,
   404  				0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r',
   405  			},
   406  		},
   407  		{
   408  			packet: &sshFxpSetstatPacket{
   409  				ID:    31,
   410  				Path:  "/bar",
   411  				Flags: sshFileXferAttrUIDGID,
   412  				Attrs: struct {
   413  					UID uint32
   414  					GID uint32
   415  				}{
   416  					UID: 1000,
   417  					GID: 100,
   418  				},
   419  			},
   420  			want: []byte{
   421  				0x0, 0x0, 0x0, 0x19,
   422  				0x9,
   423  				0x0, 0x0, 0x0, 0x1f,
   424  				0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r',
   425  				0x0, 0x0, 0x0, 0x2,
   426  				0x0, 0x0, 0x3, 0xe8,
   427  				0x0, 0x0, 0x0, 0x64,
   428  			},
   429  		},
   430  	}
   431  
   432  	for _, tt := range tests {
   433  		b := new(bytes.Buffer)
   434  		sendPacket(b, tt.packet)
   435  		if got := b.Bytes(); !bytes.Equal(tt.want, got) {
   436  			t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got)
   437  		}
   438  	}
   439  }
   440  
   441  func sp(data encoding.BinaryMarshaler) []byte {
   442  	b := new(bytes.Buffer)
   443  	sendPacket(b, data)
   444  	return b.Bytes()
   445  }
   446  
   447  func TestRecvPacket(t *testing.T) {
   448  	var recvPacketTests = []struct {
   449  		b []byte
   450  
   451  		want    uint8
   452  		body    []byte
   453  		wantErr error
   454  	}{
   455  		{
   456  			b: sp(&sshFxInitPacket{
   457  				Version: 3,
   458  				Extensions: []extensionPair{
   459  					{"posix-rename@openssh.com", "1"},
   460  				},
   461  			}),
   462  			want: sshFxpInit,
   463  			body: []byte{
   464  				0x0, 0x0, 0x0, 0x3,
   465  				0x0, 0x0, 0x0, 0x18,
   466  				'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
   467  				0x0, 0x0, 0x0, 0x01,
   468  				'1',
   469  			},
   470  		},
   471  		{
   472  			b: []byte{
   473  				0x0, 0x0, 0x0, 0x0,
   474  			},
   475  			wantErr: errShortPacket,
   476  		},
   477  		{
   478  			b: []byte{
   479  				0xff, 0xff, 0xff, 0xff,
   480  			},
   481  			wantErr: errLongPacket,
   482  		},
   483  	}
   484  
   485  	for _, tt := range recvPacketTests {
   486  		r := bytes.NewReader(tt.b)
   487  
   488  		got, body, err := recvPacket(r, nil, 0)
   489  		if tt.wantErr == nil {
   490  			if err != nil {
   491  				t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err)
   492  			}
   493  		} else {
   494  			if !errors.Is(err, tt.wantErr) {
   495  				t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr)
   496  			}
   497  		}
   498  
   499  		if got != tt.want {
   500  			t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want)
   501  		}
   502  
   503  		if !bytes.Equal(body, tt.body) {
   504  			t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body)
   505  		}
   506  	}
   507  }
   508  
   509  func TestSSHFxpOpenPacketreadonly(t *testing.T) {
   510  	var tests = []struct {
   511  		pflags uint32
   512  		ok     bool
   513  	}{
   514  		{
   515  			pflags: sshFxfRead,
   516  			ok:     true,
   517  		},
   518  		{
   519  			pflags: sshFxfWrite,
   520  			ok:     false,
   521  		},
   522  		{
   523  			pflags: sshFxfRead | sshFxfWrite,
   524  			ok:     false,
   525  		},
   526  	}
   527  
   528  	for _, tt := range tests {
   529  		p := &sshFxpOpenPacket{
   530  			Pflags: tt.pflags,
   531  		}
   532  
   533  		if want, got := tt.ok, p.readonly(); want != got {
   534  			t.Errorf("unexpected value for p.readonly(): want: %v, got: %v",
   535  				want, got)
   536  		}
   537  	}
   538  }
   539  
   540  func TestSSHFxpOpenPackethasPflags(t *testing.T) {
   541  	var tests = []struct {
   542  		desc      string
   543  		haveFlags uint32
   544  		testFlags []uint32
   545  		ok        bool
   546  	}{
   547  		{
   548  			desc:      "have read, test against write",
   549  			haveFlags: sshFxfRead,
   550  			testFlags: []uint32{sshFxfWrite},
   551  			ok:        false,
   552  		},
   553  		{
   554  			desc:      "have write, test against read",
   555  			haveFlags: sshFxfWrite,
   556  			testFlags: []uint32{sshFxfRead},
   557  			ok:        false,
   558  		},
   559  		{
   560  			desc:      "have read+write, test against read",
   561  			haveFlags: sshFxfRead | sshFxfWrite,
   562  			testFlags: []uint32{sshFxfRead},
   563  			ok:        true,
   564  		},
   565  		{
   566  			desc:      "have read+write, test against write",
   567  			haveFlags: sshFxfRead | sshFxfWrite,
   568  			testFlags: []uint32{sshFxfWrite},
   569  			ok:        true,
   570  		},
   571  		{
   572  			desc:      "have read+write, test against read+write",
   573  			haveFlags: sshFxfRead | sshFxfWrite,
   574  			testFlags: []uint32{sshFxfRead, sshFxfWrite},
   575  			ok:        true,
   576  		},
   577  	}
   578  
   579  	for _, tt := range tests {
   580  		t.Log(tt.desc)
   581  
   582  		p := &sshFxpOpenPacket{
   583  			Pflags: tt.haveFlags,
   584  		}
   585  
   586  		if want, got := tt.ok, p.hasPflags(tt.testFlags...); want != got {
   587  			t.Errorf("unexpected value for p.hasPflags(%#v): want: %v, got: %v",
   588  				tt.testFlags, want, got)
   589  		}
   590  	}
   591  }
   592  
   593  func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) {
   594  	b.ResetTimer()
   595  
   596  	for i := 0; i < b.N; i++ {
   597  		sendPacket(ioutil.Discard, packet)
   598  	}
   599  }
   600  
   601  func BenchmarkMarshalInit(b *testing.B) {
   602  	benchMarshal(b, &sshFxInitPacket{
   603  		Version: 3,
   604  		Extensions: []extensionPair{
   605  			{"posix-rename@openssh.com", "1"},
   606  		},
   607  	})
   608  }
   609  
   610  func BenchmarkMarshalOpen(b *testing.B) {
   611  	benchMarshal(b, &sshFxpOpenPacket{
   612  		ID:     1,
   613  		Path:   "/home/test/some/random/path",
   614  		Pflags: flags(os.O_RDONLY),
   615  	})
   616  }
   617  
   618  func BenchmarkMarshalWriteWorstCase(b *testing.B) {
   619  	data := make([]byte, 32*1024)
   620  
   621  	benchMarshal(b, &sshFxpWritePacket{
   622  		ID:     1,
   623  		Handle: "someopaquehandle",
   624  		Offset: 0,
   625  		Length: uint32(len(data)),
   626  		Data:   data,
   627  	})
   628  }
   629  
   630  func BenchmarkMarshalWrite1k(b *testing.B) {
   631  	data := make([]byte, 1025)
   632  
   633  	benchMarshal(b, &sshFxpWritePacket{
   634  		ID:     1,
   635  		Handle: "someopaquehandle",
   636  		Offset: 0,
   637  		Length: uint32(len(data)),
   638  		Data:   data,
   639  	})
   640  }