github.com/pkg/sftp@v1.13.6/packet-manager_test.go (about)

     1  package sftp
     2  
     3  import (
     4  	"encoding"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  type _testSender struct {
    12  	sent chan encoding.BinaryMarshaler
    13  }
    14  
    15  func newTestSender() *_testSender {
    16  	return &_testSender{make(chan encoding.BinaryMarshaler)}
    17  }
    18  
    19  func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
    20  	s.sent <- p
    21  	return nil
    22  }
    23  
    24  type fakepacket struct {
    25  	reqid uint32
    26  	oid   uint32
    27  }
    28  
    29  func fake(rid, order uint32) fakepacket {
    30  	return fakepacket{reqid: rid, oid: order}
    31  }
    32  
    33  func (fakepacket) MarshalBinary() ([]byte, error) {
    34  	return make([]byte, 4), nil
    35  }
    36  
    37  func (fakepacket) UnmarshalBinary([]byte) error {
    38  	return nil
    39  }
    40  
    41  func (f fakepacket) id() uint32 {
    42  	return f.reqid
    43  }
    44  
    45  type pair struct {
    46  	in, out fakepacket
    47  }
    48  
    49  type orderedPair struct {
    50  	in  orderedRequest
    51  	out orderedResponse
    52  }
    53  
    54  // basic test
    55  var ttable1 = []pair{
    56  	{fake(0, 0), fake(0, 0)},
    57  	{fake(1, 1), fake(1, 1)},
    58  	{fake(2, 2), fake(2, 2)},
    59  	{fake(3, 3), fake(3, 3)},
    60  }
    61  
    62  // outgoing packets out of order
    63  var ttable2 = []pair{
    64  	{fake(10, 0), fake(12, 2)},
    65  	{fake(11, 1), fake(11, 1)},
    66  	{fake(12, 2), fake(13, 3)},
    67  	{fake(13, 3), fake(10, 0)},
    68  }
    69  
    70  // request ids are not incremental
    71  var ttable3 = []pair{
    72  	{fake(7, 0), fake(7, 0)},
    73  	{fake(1, 1), fake(1, 1)},
    74  	{fake(9, 2), fake(3, 3)},
    75  	{fake(3, 3), fake(9, 2)},
    76  }
    77  
    78  // request ids are all the same
    79  var ttable4 = []pair{
    80  	{fake(1, 0), fake(1, 0)},
    81  	{fake(1, 1), fake(1, 1)},
    82  	{fake(1, 2), fake(1, 3)},
    83  	{fake(1, 3), fake(1, 2)},
    84  }
    85  
    86  var tables = [][]pair{ttable1, ttable2, ttable3, ttable4}
    87  
    88  func TestPacketManager(t *testing.T) {
    89  	sender := newTestSender()
    90  	s := newPktMgr(sender)
    91  
    92  	for i := range tables {
    93  		table := tables[i]
    94  		orderedPairs := make([]orderedPair, 0, len(table))
    95  		for _, p := range table {
    96  			orderedPairs = append(orderedPairs, orderedPair{
    97  				in:  orderedRequest{p.in, p.in.oid},
    98  				out: orderedResponse{p.out, p.out.oid},
    99  			})
   100  		}
   101  		for _, p := range orderedPairs {
   102  			s.incomingPacket(p.in)
   103  		}
   104  		for _, p := range orderedPairs {
   105  			s.readyPacket(p.out)
   106  		}
   107  		for _, p := range table {
   108  			pkt := <-sender.sent
   109  			id := pkt.(orderedResponse).id()
   110  			assert.Equal(t, id, p.in.id())
   111  		}
   112  	}
   113  	s.close()
   114  }
   115  
   116  func (p sshFxpRemovePacket) String() string {
   117  	return fmt.Sprintf("RmPkt:%d", p.ID)
   118  }
   119  func (p sshFxpOpenPacket) String() string {
   120  	return fmt.Sprintf("OpPkt:%d", p.ID)
   121  }
   122  func (p sshFxpWritePacket) String() string {
   123  	return fmt.Sprintf("WrPkt:%d", p.ID)
   124  }
   125  func (p sshFxpClosePacket) String() string {
   126  	return fmt.Sprintf("ClPkt:%d", p.ID)
   127  }