github.com/pkg/sftp@v1.13.6/packet_test.go (about) 1 package sftp 2 3 import ( 4 "bytes" 5 "encoding" 6 "errors" 7 "io/ioutil" 8 "os" 9 "reflect" 10 "testing" 11 ) 12 13 func TestMarshalUint32(t *testing.T) { 14 var tests = []struct { 15 v uint32 16 want []byte 17 }{ 18 {0, []byte{0, 0, 0, 0}}, 19 {42, []byte{0, 0, 0, 42}}, 20 {42 << 8, []byte{0, 0, 42, 0}}, 21 {42 << 16, []byte{0, 42, 0, 0}}, 22 {42 << 24, []byte{42, 0, 0, 0}}, 23 {^uint32(0), []byte{255, 255, 255, 255}}, 24 } 25 26 for _, tt := range tests { 27 got := marshalUint32(nil, tt.v) 28 if !bytes.Equal(tt.want, got) { 29 t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want) 30 } 31 } 32 } 33 34 func TestMarshalUint64(t *testing.T) { 35 var tests = []struct { 36 v uint64 37 want []byte 38 }{ 39 {0, []byte{0, 0, 0, 0, 0, 0, 0, 0}}, 40 {42, []byte{0, 0, 0, 0, 0, 0, 0, 42}}, 41 {42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}}, 42 {42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}}, 43 {42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}}, 44 {42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}}, 45 {42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}}, 46 {42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}}, 47 {42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}}, 48 {^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}}, 49 } 50 51 for _, tt := range tests { 52 got := marshalUint64(nil, tt.v) 53 if !bytes.Equal(tt.want, got) { 54 t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want) 55 } 56 } 57 } 58 59 func TestMarshalString(t *testing.T) { 60 var tests = []struct { 61 v string 62 want []byte 63 }{ 64 {"", []byte{0, 0, 0, 0}}, 65 {"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}}, 66 {"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}}, 67 {"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}}, 68 {"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}}, 69 {"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}}, 70 {"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}}, 71 } 72 73 for _, tt := range tests { 74 got := marshalString(nil, tt.v) 75 if !bytes.Equal(tt.want, got) { 76 t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want) 77 } 78 } 79 } 80 81 func TestMarshal(t *testing.T) { 82 type Struct struct { 83 X, Y, Z uint32 84 } 85 86 var tests = []struct { 87 v interface{} 88 want []byte 89 }{ 90 {uint8(42), []byte{42}}, 91 {uint32(42 << 8), []byte{0, 0, 42, 0}}, 92 {uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}}, 93 {"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}}, 94 {Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, 95 {[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, 96 } 97 98 for _, tt := range tests { 99 got := marshal(nil, tt.v) 100 if !bytes.Equal(tt.want, got) { 101 t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want) 102 } 103 } 104 } 105 106 func TestUnmarshalUint32(t *testing.T) { 107 testBuffer := []byte{ 108 0, 0, 0, 0, 109 0, 0, 0, 42, 110 0, 0, 42, 0, 111 0, 42, 0, 0, 112 42, 0, 0, 0, 113 255, 0, 0, 254, 114 } 115 116 var wants = []uint32{ 117 0, 118 42, 119 42 << 8, 120 42 << 16, 121 42 << 24, 122 255<<24 | 254, 123 } 124 125 var i int 126 for len(testBuffer) > 0 { 127 got, rest := unmarshalUint32(testBuffer) 128 129 if got != wants[i] { 130 t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i]) 131 } 132 133 i++ 134 testBuffer = rest 135 } 136 } 137 138 func TestUnmarshalUint64(t *testing.T) { 139 testBuffer := []byte{ 140 0, 0, 0, 0, 0, 0, 0, 0, 141 0, 0, 0, 0, 0, 0, 0, 42, 142 0, 0, 0, 0, 0, 0, 42, 0, 143 0, 0, 0, 0, 0, 42, 0, 0, 144 0, 0, 0, 0, 42, 0, 0, 0, 145 0, 0, 0, 42, 0, 0, 0, 0, 146 0, 0, 42, 0, 0, 0, 0, 0, 147 0, 42, 0, 0, 0, 0, 0, 0, 148 42, 0, 0, 0, 0, 0, 0, 0, 149 255, 0, 0, 0, 0, 0, 0, 254, 150 } 151 152 var wants = []uint64{ 153 0, 154 42, 155 42 << 8, 156 42 << 16, 157 42 << 24, 158 42 << 32, 159 42 << 40, 160 42 << 48, 161 42 << 56, 162 255<<56 | 254, 163 } 164 165 var i int 166 for len(testBuffer) > 0 { 167 got, rest := unmarshalUint64(testBuffer) 168 169 if got != wants[i] { 170 t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i]) 171 } 172 173 i++ 174 testBuffer = rest 175 } 176 } 177 178 var unmarshalStringTests = []struct { 179 b []byte 180 want string 181 rest []byte 182 }{ 183 {marshalString(nil, ""), "", nil}, 184 {marshalString(nil, "blah"), "blah", nil}, 185 } 186 187 func TestUnmarshalString(t *testing.T) { 188 testBuffer := []byte{ 189 0, 0, 0, 0, 190 0, 0, 0, 1, '/', 191 0, 0, 0, 4, '/', 'f', 'o', 'o', 192 0, 0, 0, 4, 0, 'b', 'a', 'r', 193 0, 0, 0, 4, 'b', 0, 'a', 'r', 194 0, 0, 0, 4, 'b', 'a', 0, 'r', 195 0, 0, 0, 4, 'b', 'a', 'r', 0, 196 } 197 198 var wants = []string{ 199 "", 200 "/", 201 "/foo", 202 "\x00bar", 203 "b\x00ar", 204 "ba\x00r", 205 "bar\x00", 206 } 207 208 var i int 209 for len(testBuffer) > 0 { 210 got, rest := unmarshalString(testBuffer) 211 212 if got != wants[i] { 213 t.Fatalf("unmarshalUint64(%#v...) = %q, want %q", testBuffer[:4], got, wants[i]) 214 } 215 216 i++ 217 testBuffer = rest 218 } 219 } 220 221 func TestUnmarshalAttrs(t *testing.T) { 222 var tests = []struct { 223 b []byte 224 want *FileStat 225 }{ 226 { 227 b: []byte{0x00, 0x00, 0x00, 0x00}, 228 want: &FileStat{}, 229 }, 230 { 231 b: []byte{ 232 0x00, 0x00, 0x00, byte(sshFileXferAttrSize), 233 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, 234 }, 235 want: &FileStat{ 236 Size: 20, 237 }, 238 }, 239 { 240 b: []byte{ 241 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions), 242 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, 243 0x00, 0x00, 0x01, 0xA4, 244 }, 245 want: &FileStat{ 246 Size: 20, 247 Mode: 0644, 248 }, 249 }, 250 { 251 b: []byte{ 252 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID), 253 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, 254 0x00, 0x00, 0x03, 0xE8, 255 0x00, 0x00, 0x03, 0xE9, 256 0x00, 0x00, 0x01, 0xA4, 257 }, 258 want: &FileStat{ 259 Size: 20, 260 Mode: 0644, 261 UID: 1000, 262 GID: 1001, 263 }, 264 }, 265 { 266 b: []byte{ 267 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime), 268 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, 269 0x00, 0x00, 0x03, 0xE8, 270 0x00, 0x00, 0x03, 0xE9, 271 0x00, 0x00, 0x01, 0xA4, 272 0x00, 0x00, 0x00, 42, 273 0x00, 0x00, 0x00, 13, 274 }, 275 want: &FileStat{ 276 Size: 20, 277 Mode: 0644, 278 UID: 1000, 279 GID: 1001, 280 Atime: 42, 281 Mtime: 13, 282 }, 283 }, 284 } 285 286 for _, tt := range tests { 287 got, _ := unmarshalAttrs(tt.b) 288 if !reflect.DeepEqual(got, tt.want) { 289 t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want) 290 } 291 } 292 } 293 294 func TestUnmarshalStatus(t *testing.T) { 295 var requestID uint32 = 1 296 297 id := marshalUint32(nil, requestID) 298 idCode := marshalUint32(id, sshFxFailure) 299 idCodeMsg := marshalString(idCode, "err msg") 300 idCodeMsgLang := marshalString(idCodeMsg, "lang tag") 301 302 var tests = []struct { 303 desc string 304 reqID uint32 305 status []byte 306 want error 307 }{ 308 { 309 desc: "well-formed status", 310 status: idCodeMsgLang, 311 want: &StatusError{ 312 Code: sshFxFailure, 313 msg: "err msg", 314 lang: "lang tag", 315 }, 316 }, 317 { 318 desc: "missing language tag", 319 status: idCodeMsg, 320 want: &StatusError{ 321 Code: sshFxFailure, 322 msg: "err msg", 323 }, 324 }, 325 { 326 desc: "missing error message and language tag", 327 status: idCode, 328 want: &StatusError{ 329 Code: sshFxFailure, 330 }, 331 }, 332 } 333 334 for _, tt := range tests { 335 t.Run(tt.desc, func(t *testing.T) { 336 got := unmarshalStatus(1, tt.status) 337 if !reflect.DeepEqual(got, tt.want) { 338 t.Errorf("unmarshalStatus(1, % X):\n- got: %#v\n- want: %#v", tt.status, got, tt.want) 339 } 340 }) 341 } 342 343 got := unmarshalStatus(2, idCodeMsgLang) 344 want := &unexpectedIDErr{ 345 want: 2, 346 got: 1, 347 } 348 if !reflect.DeepEqual(got, want) { 349 t.Errorf("unmarshalStatus(2, % X):\n- got: %#v\n- want: %#v", idCodeMsgLang, got, want) 350 } 351 } 352 353 func TestSendPacket(t *testing.T) { 354 var tests = []struct { 355 packet encoding.BinaryMarshaler 356 want []byte 357 }{ 358 { 359 packet: &sshFxInitPacket{ 360 Version: 3, 361 Extensions: []extensionPair{ 362 {"posix-rename@openssh.com", "1"}, 363 }, 364 }, 365 want: []byte{ 366 0x0, 0x0, 0x0, 0x26, 367 0x1, 368 0x0, 0x0, 0x0, 0x3, 369 0x0, 0x0, 0x0, 0x18, 370 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', 371 0x0, 0x0, 0x0, 0x1, 372 '1', 373 }, 374 }, 375 { 376 packet: &sshFxpOpenPacket{ 377 ID: 1, 378 Path: "/foo", 379 Pflags: flags(os.O_RDONLY), 380 }, 381 want: []byte{ 382 0x0, 0x0, 0x0, 0x15, 383 0x3, 384 0x0, 0x0, 0x0, 0x1, 385 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', 386 0x0, 0x0, 0x0, 0x1, 387 0x0, 0x0, 0x0, 0x0, 388 }, 389 }, 390 { 391 packet: &sshFxpWritePacket{ 392 ID: 124, 393 Handle: "foo", 394 Offset: 13, 395 Length: uint32(len("bar")), 396 Data: []byte("bar"), 397 }, 398 want: []byte{ 399 0x0, 0x0, 0x0, 0x1b, 400 0x6, 401 0x0, 0x0, 0x0, 0x7c, 402 0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o', 403 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 404 0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r', 405 }, 406 }, 407 { 408 packet: &sshFxpSetstatPacket{ 409 ID: 31, 410 Path: "/bar", 411 Flags: sshFileXferAttrUIDGID, 412 Attrs: struct { 413 UID uint32 414 GID uint32 415 }{ 416 UID: 1000, 417 GID: 100, 418 }, 419 }, 420 want: []byte{ 421 0x0, 0x0, 0x0, 0x19, 422 0x9, 423 0x0, 0x0, 0x0, 0x1f, 424 0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r', 425 0x0, 0x0, 0x0, 0x2, 426 0x0, 0x0, 0x3, 0xe8, 427 0x0, 0x0, 0x0, 0x64, 428 }, 429 }, 430 } 431 432 for _, tt := range tests { 433 b := new(bytes.Buffer) 434 sendPacket(b, tt.packet) 435 if got := b.Bytes(); !bytes.Equal(tt.want, got) { 436 t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got) 437 } 438 } 439 } 440 441 func sp(data encoding.BinaryMarshaler) []byte { 442 b := new(bytes.Buffer) 443 sendPacket(b, data) 444 return b.Bytes() 445 } 446 447 func TestRecvPacket(t *testing.T) { 448 var recvPacketTests = []struct { 449 b []byte 450 451 want uint8 452 body []byte 453 wantErr error 454 }{ 455 { 456 b: sp(&sshFxInitPacket{ 457 Version: 3, 458 Extensions: []extensionPair{ 459 {"posix-rename@openssh.com", "1"}, 460 }, 461 }), 462 want: sshFxpInit, 463 body: []byte{ 464 0x0, 0x0, 0x0, 0x3, 465 0x0, 0x0, 0x0, 0x18, 466 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', 467 0x0, 0x0, 0x0, 0x01, 468 '1', 469 }, 470 }, 471 { 472 b: []byte{ 473 0x0, 0x0, 0x0, 0x0, 474 }, 475 wantErr: errShortPacket, 476 }, 477 { 478 b: []byte{ 479 0xff, 0xff, 0xff, 0xff, 480 }, 481 wantErr: errLongPacket, 482 }, 483 } 484 485 for _, tt := range recvPacketTests { 486 r := bytes.NewReader(tt.b) 487 488 got, body, err := recvPacket(r, nil, 0) 489 if tt.wantErr == nil { 490 if err != nil { 491 t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err) 492 } 493 } else { 494 if !errors.Is(err, tt.wantErr) { 495 t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr) 496 } 497 } 498 499 if got != tt.want { 500 t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want) 501 } 502 503 if !bytes.Equal(body, tt.body) { 504 t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body) 505 } 506 } 507 } 508 509 func TestSSHFxpOpenPacketreadonly(t *testing.T) { 510 var tests = []struct { 511 pflags uint32 512 ok bool 513 }{ 514 { 515 pflags: sshFxfRead, 516 ok: true, 517 }, 518 { 519 pflags: sshFxfWrite, 520 ok: false, 521 }, 522 { 523 pflags: sshFxfRead | sshFxfWrite, 524 ok: false, 525 }, 526 } 527 528 for _, tt := range tests { 529 p := &sshFxpOpenPacket{ 530 Pflags: tt.pflags, 531 } 532 533 if want, got := tt.ok, p.readonly(); want != got { 534 t.Errorf("unexpected value for p.readonly(): want: %v, got: %v", 535 want, got) 536 } 537 } 538 } 539 540 func TestSSHFxpOpenPackethasPflags(t *testing.T) { 541 var tests = []struct { 542 desc string 543 haveFlags uint32 544 testFlags []uint32 545 ok bool 546 }{ 547 { 548 desc: "have read, test against write", 549 haveFlags: sshFxfRead, 550 testFlags: []uint32{sshFxfWrite}, 551 ok: false, 552 }, 553 { 554 desc: "have write, test against read", 555 haveFlags: sshFxfWrite, 556 testFlags: []uint32{sshFxfRead}, 557 ok: false, 558 }, 559 { 560 desc: "have read+write, test against read", 561 haveFlags: sshFxfRead | sshFxfWrite, 562 testFlags: []uint32{sshFxfRead}, 563 ok: true, 564 }, 565 { 566 desc: "have read+write, test against write", 567 haveFlags: sshFxfRead | sshFxfWrite, 568 testFlags: []uint32{sshFxfWrite}, 569 ok: true, 570 }, 571 { 572 desc: "have read+write, test against read+write", 573 haveFlags: sshFxfRead | sshFxfWrite, 574 testFlags: []uint32{sshFxfRead, sshFxfWrite}, 575 ok: true, 576 }, 577 } 578 579 for _, tt := range tests { 580 t.Log(tt.desc) 581 582 p := &sshFxpOpenPacket{ 583 Pflags: tt.haveFlags, 584 } 585 586 if want, got := tt.ok, p.hasPflags(tt.testFlags...); want != got { 587 t.Errorf("unexpected value for p.hasPflags(%#v): want: %v, got: %v", 588 tt.testFlags, want, got) 589 } 590 } 591 } 592 593 func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) { 594 b.ResetTimer() 595 596 for i := 0; i < b.N; i++ { 597 sendPacket(ioutil.Discard, packet) 598 } 599 } 600 601 func BenchmarkMarshalInit(b *testing.B) { 602 benchMarshal(b, &sshFxInitPacket{ 603 Version: 3, 604 Extensions: []extensionPair{ 605 {"posix-rename@openssh.com", "1"}, 606 }, 607 }) 608 } 609 610 func BenchmarkMarshalOpen(b *testing.B) { 611 benchMarshal(b, &sshFxpOpenPacket{ 612 ID: 1, 613 Path: "/home/test/some/random/path", 614 Pflags: flags(os.O_RDONLY), 615 }) 616 } 617 618 func BenchmarkMarshalWriteWorstCase(b *testing.B) { 619 data := make([]byte, 32*1024) 620 621 benchMarshal(b, &sshFxpWritePacket{ 622 ID: 1, 623 Handle: "someopaquehandle", 624 Offset: 0, 625 Length: uint32(len(data)), 626 Data: data, 627 }) 628 } 629 630 func BenchmarkMarshalWrite1k(b *testing.B) { 631 data := make([]byte, 1025) 632 633 benchMarshal(b, &sshFxpWritePacket{ 634 ID: 1, 635 Handle: "someopaquehandle", 636 Offset: 0, 637 Length: uint32(len(data)), 638 Data: data, 639 }) 640 }