github.com/pkg/sftp@v1.13.6/packet-manager_test.go (about) 1 package sftp 2 3 import ( 4 "encoding" 5 "fmt" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 type _testSender struct { 12 sent chan encoding.BinaryMarshaler 13 } 14 15 func newTestSender() *_testSender { 16 return &_testSender{make(chan encoding.BinaryMarshaler)} 17 } 18 19 func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error { 20 s.sent <- p 21 return nil 22 } 23 24 type fakepacket struct { 25 reqid uint32 26 oid uint32 27 } 28 29 func fake(rid, order uint32) fakepacket { 30 return fakepacket{reqid: rid, oid: order} 31 } 32 33 func (fakepacket) MarshalBinary() ([]byte, error) { 34 return make([]byte, 4), nil 35 } 36 37 func (fakepacket) UnmarshalBinary([]byte) error { 38 return nil 39 } 40 41 func (f fakepacket) id() uint32 { 42 return f.reqid 43 } 44 45 type pair struct { 46 in, out fakepacket 47 } 48 49 type orderedPair struct { 50 in orderedRequest 51 out orderedResponse 52 } 53 54 // basic test 55 var ttable1 = []pair{ 56 {fake(0, 0), fake(0, 0)}, 57 {fake(1, 1), fake(1, 1)}, 58 {fake(2, 2), fake(2, 2)}, 59 {fake(3, 3), fake(3, 3)}, 60 } 61 62 // outgoing packets out of order 63 var ttable2 = []pair{ 64 {fake(10, 0), fake(12, 2)}, 65 {fake(11, 1), fake(11, 1)}, 66 {fake(12, 2), fake(13, 3)}, 67 {fake(13, 3), fake(10, 0)}, 68 } 69 70 // request ids are not incremental 71 var ttable3 = []pair{ 72 {fake(7, 0), fake(7, 0)}, 73 {fake(1, 1), fake(1, 1)}, 74 {fake(9, 2), fake(3, 3)}, 75 {fake(3, 3), fake(9, 2)}, 76 } 77 78 // request ids are all the same 79 var ttable4 = []pair{ 80 {fake(1, 0), fake(1, 0)}, 81 {fake(1, 1), fake(1, 1)}, 82 {fake(1, 2), fake(1, 3)}, 83 {fake(1, 3), fake(1, 2)}, 84 } 85 86 var tables = [][]pair{ttable1, ttable2, ttable3, ttable4} 87 88 func TestPacketManager(t *testing.T) { 89 sender := newTestSender() 90 s := newPktMgr(sender) 91 92 for i := range tables { 93 table := tables[i] 94 orderedPairs := make([]orderedPair, 0, len(table)) 95 for _, p := range table { 96 orderedPairs = append(orderedPairs, orderedPair{ 97 in: orderedRequest{p.in, p.in.oid}, 98 out: orderedResponse{p.out, p.out.oid}, 99 }) 100 } 101 for _, p := range orderedPairs { 102 s.incomingPacket(p.in) 103 } 104 for _, p := range orderedPairs { 105 s.readyPacket(p.out) 106 } 107 for _, p := range table { 108 pkt := <-sender.sent 109 id := pkt.(orderedResponse).id() 110 assert.Equal(t, id, p.in.id()) 111 } 112 } 113 s.close() 114 } 115 116 func (p sshFxpRemovePacket) String() string { 117 return fmt.Sprintf("RmPkt:%d", p.ID) 118 } 119 func (p sshFxpOpenPacket) String() string { 120 return fmt.Sprintf("OpPkt:%d", p.ID) 121 } 122 func (p sshFxpWritePacket) String() string { 123 return fmt.Sprintf("WrPkt:%d", p.ID) 124 } 125 func (p sshFxpClosePacket) String() string { 126 return fmt.Sprintf("ClPkt:%d", p.ID) 127 }