github.com/pkg/sftp@v1.13.6/packet-manager.go (about) 1 package sftp 2 3 import ( 4 "encoding" 5 "sort" 6 "sync" 7 ) 8 9 // The goal of the packetManager is to keep the outgoing packets in the same 10 // order as the incoming as is requires by section 7 of the RFC. 11 12 type packetManager struct { 13 requests chan orderedPacket 14 responses chan orderedPacket 15 fini chan struct{} 16 incoming orderedPackets 17 outgoing orderedPackets 18 sender packetSender // connection object 19 working *sync.WaitGroup 20 packetCount uint32 21 // it is not nil if the allocator is enabled 22 alloc *allocator 23 } 24 25 type packetSender interface { 26 sendPacket(encoding.BinaryMarshaler) error 27 } 28 29 func newPktMgr(sender packetSender) *packetManager { 30 s := &packetManager{ 31 requests: make(chan orderedPacket, SftpServerWorkerCount), 32 responses: make(chan orderedPacket, SftpServerWorkerCount), 33 fini: make(chan struct{}), 34 incoming: make([]orderedPacket, 0, SftpServerWorkerCount), 35 outgoing: make([]orderedPacket, 0, SftpServerWorkerCount), 36 sender: sender, 37 working: &sync.WaitGroup{}, 38 } 39 go s.controller() 40 return s 41 } 42 43 // // packet ordering 44 func (s *packetManager) newOrderID() uint32 { 45 s.packetCount++ 46 return s.packetCount 47 } 48 49 // returns the next orderID without incrementing it. 50 // This is used before receiving a new packet, with the allocator enabled, to associate 51 // the slice allocated for the received packet with the orderID that will be used to mark 52 // the allocated slices for reuse once the request is served 53 func (s *packetManager) getNextOrderID() uint32 { 54 return s.packetCount + 1 55 } 56 57 type orderedRequest struct { 58 requestPacket 59 orderid uint32 60 } 61 62 func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest { 63 return orderedRequest{requestPacket: p, orderid: s.newOrderID()} 64 } 65 func (p orderedRequest) orderID() uint32 { return p.orderid } 66 func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid } 67 68 type orderedResponse struct { 69 responsePacket 70 orderid uint32 71 } 72 73 func (s *packetManager) newOrderedResponse(p responsePacket, id uint32, 74 ) orderedResponse { 75 return orderedResponse{responsePacket: p, orderid: id} 76 } 77 func (p orderedResponse) orderID() uint32 { return p.orderid } 78 func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid } 79 80 type orderedPacket interface { 81 id() uint32 82 orderID() uint32 83 } 84 type orderedPackets []orderedPacket 85 86 func (o orderedPackets) Sort() { 87 sort.Slice(o, func(i, j int) bool { 88 return o[i].orderID() < o[j].orderID() 89 }) 90 } 91 92 // // packet registry 93 // register incoming packets to be handled 94 func (s *packetManager) incomingPacket(pkt orderedRequest) { 95 s.working.Add(1) 96 s.requests <- pkt 97 } 98 99 // register outgoing packets as being ready 100 func (s *packetManager) readyPacket(pkt orderedResponse) { 101 s.responses <- pkt 102 s.working.Done() 103 } 104 105 // shut down packetManager controller 106 func (s *packetManager) close() { 107 // pause until current packets are processed 108 s.working.Wait() 109 close(s.fini) 110 } 111 112 // Passed a worker function, returns a channel for incoming packets. 113 // Keep process packet responses in the order they are received while 114 // maximizing throughput of file transfers. 115 func (s *packetManager) workerChan(runWorker func(chan orderedRequest), 116 ) chan orderedRequest { 117 // multiple workers for faster read/writes 118 rwChan := make(chan orderedRequest, SftpServerWorkerCount) 119 for i := 0; i < SftpServerWorkerCount; i++ { 120 runWorker(rwChan) 121 } 122 123 // single worker to enforce sequential processing of everything else 124 cmdChan := make(chan orderedRequest) 125 runWorker(cmdChan) 126 127 pktChan := make(chan orderedRequest, SftpServerWorkerCount) 128 go func() { 129 for pkt := range pktChan { 130 switch pkt.requestPacket.(type) { 131 case *sshFxpReadPacket, *sshFxpWritePacket: 132 s.incomingPacket(pkt) 133 rwChan <- pkt 134 continue 135 case *sshFxpClosePacket: 136 // wait for reads/writes to finish when file is closed 137 // incomingPacket() call must occur after this 138 s.working.Wait() 139 } 140 s.incomingPacket(pkt) 141 // all non-RW use sequential cmdChan 142 cmdChan <- pkt 143 } 144 close(rwChan) 145 close(cmdChan) 146 s.close() 147 }() 148 149 return pktChan 150 } 151 152 // process packets 153 func (s *packetManager) controller() { 154 for { 155 select { 156 case pkt := <-s.requests: 157 debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID()) 158 s.incoming = append(s.incoming, pkt) 159 s.incoming.Sort() 160 case pkt := <-s.responses: 161 debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID()) 162 s.outgoing = append(s.outgoing, pkt) 163 s.outgoing.Sort() 164 case <-s.fini: 165 return 166 } 167 s.maybeSendPackets() 168 } 169 } 170 171 // send as many packets as are ready 172 func (s *packetManager) maybeSendPackets() { 173 for { 174 if len(s.outgoing) == 0 || len(s.incoming) == 0 { 175 debug("break! -- outgoing: %v; incoming: %v", 176 len(s.outgoing), len(s.incoming)) 177 break 178 } 179 out := s.outgoing[0] 180 in := s.incoming[0] 181 // debug("incoming: %v", ids(s.incoming)) 182 // debug("outgoing: %v", ids(s.outgoing)) 183 if in.orderID() == out.orderID() { 184 debug("Sending packet: %v", out.id()) 185 s.sender.sendPacket(out.(encoding.BinaryMarshaler)) 186 if s.alloc != nil { 187 // mark for reuse the slices allocated for this request 188 s.alloc.ReleasePages(in.orderID()) 189 } 190 // pop off heads 191 copy(s.incoming, s.incoming[1:]) // shift left 192 s.incoming[len(s.incoming)-1] = nil // clear last 193 s.incoming = s.incoming[:len(s.incoming)-1] // remove last 194 copy(s.outgoing, s.outgoing[1:]) // shift left 195 s.outgoing[len(s.outgoing)-1] = nil // clear last 196 s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last 197 } else { 198 break 199 } 200 } 201 } 202 203 // func oids(o []orderedPacket) []uint32 { 204 // res := make([]uint32, 0, len(o)) 205 // for _, v := range o { 206 // res = append(res, v.orderId()) 207 // } 208 // return res 209 // } 210 // func ids(o []orderedPacket) []uint32 { 211 // res := make([]uint32, 0, len(o)) 212 // for _, v := range o { 213 // res = append(res, v.id()) 214 // } 215 // return res 216 // }