github.com/pkg/sftp@v1.13.6/packet-manager.go (about)

     1  package sftp
     2  
     3  import (
     4  	"encoding"
     5  	"sort"
     6  	"sync"
     7  )
     8  
     9  // The goal of the packetManager is to keep the outgoing packets in the same
    10  // order as the incoming as is requires by section 7 of the RFC.
    11  
    12  type packetManager struct {
    13  	requests    chan orderedPacket
    14  	responses   chan orderedPacket
    15  	fini        chan struct{}
    16  	incoming    orderedPackets
    17  	outgoing    orderedPackets
    18  	sender      packetSender // connection object
    19  	working     *sync.WaitGroup
    20  	packetCount uint32
    21  	// it is not nil if the allocator is enabled
    22  	alloc *allocator
    23  }
    24  
    25  type packetSender interface {
    26  	sendPacket(encoding.BinaryMarshaler) error
    27  }
    28  
    29  func newPktMgr(sender packetSender) *packetManager {
    30  	s := &packetManager{
    31  		requests:  make(chan orderedPacket, SftpServerWorkerCount),
    32  		responses: make(chan orderedPacket, SftpServerWorkerCount),
    33  		fini:      make(chan struct{}),
    34  		incoming:  make([]orderedPacket, 0, SftpServerWorkerCount),
    35  		outgoing:  make([]orderedPacket, 0, SftpServerWorkerCount),
    36  		sender:    sender,
    37  		working:   &sync.WaitGroup{},
    38  	}
    39  	go s.controller()
    40  	return s
    41  }
    42  
    43  // // packet ordering
    44  func (s *packetManager) newOrderID() uint32 {
    45  	s.packetCount++
    46  	return s.packetCount
    47  }
    48  
    49  // returns the next orderID without incrementing it.
    50  // This is used before receiving a new packet, with the allocator enabled, to associate
    51  // the slice allocated for the received packet with the orderID that will be used to mark
    52  // the allocated slices for reuse once the request is served
    53  func (s *packetManager) getNextOrderID() uint32 {
    54  	return s.packetCount + 1
    55  }
    56  
    57  type orderedRequest struct {
    58  	requestPacket
    59  	orderid uint32
    60  }
    61  
    62  func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
    63  	return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
    64  }
    65  func (p orderedRequest) orderID() uint32       { return p.orderid }
    66  func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
    67  
    68  type orderedResponse struct {
    69  	responsePacket
    70  	orderid uint32
    71  }
    72  
    73  func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
    74  ) orderedResponse {
    75  	return orderedResponse{responsePacket: p, orderid: id}
    76  }
    77  func (p orderedResponse) orderID() uint32       { return p.orderid }
    78  func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
    79  
    80  type orderedPacket interface {
    81  	id() uint32
    82  	orderID() uint32
    83  }
    84  type orderedPackets []orderedPacket
    85  
    86  func (o orderedPackets) Sort() {
    87  	sort.Slice(o, func(i, j int) bool {
    88  		return o[i].orderID() < o[j].orderID()
    89  	})
    90  }
    91  
    92  // // packet registry
    93  // register incoming packets to be handled
    94  func (s *packetManager) incomingPacket(pkt orderedRequest) {
    95  	s.working.Add(1)
    96  	s.requests <- pkt
    97  }
    98  
    99  // register outgoing packets as being ready
   100  func (s *packetManager) readyPacket(pkt orderedResponse) {
   101  	s.responses <- pkt
   102  	s.working.Done()
   103  }
   104  
   105  // shut down packetManager controller
   106  func (s *packetManager) close() {
   107  	// pause until current packets are processed
   108  	s.working.Wait()
   109  	close(s.fini)
   110  }
   111  
   112  // Passed a worker function, returns a channel for incoming packets.
   113  // Keep process packet responses in the order they are received while
   114  // maximizing throughput of file transfers.
   115  func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
   116  ) chan orderedRequest {
   117  	// multiple workers for faster read/writes
   118  	rwChan := make(chan orderedRequest, SftpServerWorkerCount)
   119  	for i := 0; i < SftpServerWorkerCount; i++ {
   120  		runWorker(rwChan)
   121  	}
   122  
   123  	// single worker to enforce sequential processing of everything else
   124  	cmdChan := make(chan orderedRequest)
   125  	runWorker(cmdChan)
   126  
   127  	pktChan := make(chan orderedRequest, SftpServerWorkerCount)
   128  	go func() {
   129  		for pkt := range pktChan {
   130  			switch pkt.requestPacket.(type) {
   131  			case *sshFxpReadPacket, *sshFxpWritePacket:
   132  				s.incomingPacket(pkt)
   133  				rwChan <- pkt
   134  				continue
   135  			case *sshFxpClosePacket:
   136  				// wait for reads/writes to finish when file is closed
   137  				// incomingPacket() call must occur after this
   138  				s.working.Wait()
   139  			}
   140  			s.incomingPacket(pkt)
   141  			// all non-RW use sequential cmdChan
   142  			cmdChan <- pkt
   143  		}
   144  		close(rwChan)
   145  		close(cmdChan)
   146  		s.close()
   147  	}()
   148  
   149  	return pktChan
   150  }
   151  
   152  // process packets
   153  func (s *packetManager) controller() {
   154  	for {
   155  		select {
   156  		case pkt := <-s.requests:
   157  			debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
   158  			s.incoming = append(s.incoming, pkt)
   159  			s.incoming.Sort()
   160  		case pkt := <-s.responses:
   161  			debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
   162  			s.outgoing = append(s.outgoing, pkt)
   163  			s.outgoing.Sort()
   164  		case <-s.fini:
   165  			return
   166  		}
   167  		s.maybeSendPackets()
   168  	}
   169  }
   170  
   171  // send as many packets as are ready
   172  func (s *packetManager) maybeSendPackets() {
   173  	for {
   174  		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
   175  			debug("break! -- outgoing: %v; incoming: %v",
   176  				len(s.outgoing), len(s.incoming))
   177  			break
   178  		}
   179  		out := s.outgoing[0]
   180  		in := s.incoming[0]
   181  		// debug("incoming: %v", ids(s.incoming))
   182  		// debug("outgoing: %v", ids(s.outgoing))
   183  		if in.orderID() == out.orderID() {
   184  			debug("Sending packet: %v", out.id())
   185  			s.sender.sendPacket(out.(encoding.BinaryMarshaler))
   186  			if s.alloc != nil {
   187  				// mark for reuse the slices allocated for this request
   188  				s.alloc.ReleasePages(in.orderID())
   189  			}
   190  			// pop off heads
   191  			copy(s.incoming, s.incoming[1:])            // shift left
   192  			s.incoming[len(s.incoming)-1] = nil         // clear last
   193  			s.incoming = s.incoming[:len(s.incoming)-1] // remove last
   194  			copy(s.outgoing, s.outgoing[1:])            // shift left
   195  			s.outgoing[len(s.outgoing)-1] = nil         // clear last
   196  			s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
   197  		} else {
   198  			break
   199  		}
   200  	}
   201  }
   202  
   203  // func oids(o []orderedPacket) []uint32 {
   204  // 	res := make([]uint32, 0, len(o))
   205  // 	for _, v := range o {
   206  // 		res = append(res, v.orderId())
   207  // 	}
   208  // 	return res
   209  // }
   210  // func ids(o []orderedPacket) []uint32 {
   211  // 	res := make([]uint32, 0, len(o))
   212  // 	for _, v := range o {
   213  // 		res = append(res, v.id())
   214  // 	}
   215  // 	return res
   216  // }