github.com/pkg/sftp@v1.13.6/conn.go (about) 1 package sftp 2 3 import ( 4 "encoding" 5 "fmt" 6 "io" 7 "sync" 8 ) 9 10 // conn implements a bidirectional channel on which client and server 11 // connections are multiplexed. 12 type conn struct { 13 io.Reader 14 io.WriteCloser 15 // this is the same allocator used in packet manager 16 alloc *allocator 17 sync.Mutex // used to serialise writes to sendPacket 18 } 19 20 // the orderID is used in server mode if the allocator is enabled. 21 // For the client mode just pass 0. 22 // It returns io.EOF if the connection is closed and 23 // there are no more packets to read. 24 func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { 25 return recvPacket(c, c.alloc, orderID) 26 } 27 28 func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { 29 c.Lock() 30 defer c.Unlock() 31 32 return sendPacket(c, m) 33 } 34 35 func (c *conn) Close() error { 36 c.Lock() 37 defer c.Unlock() 38 return c.WriteCloser.Close() 39 } 40 41 type clientConn struct { 42 conn 43 wg sync.WaitGroup 44 45 sync.Mutex // protects inflight 46 inflight map[uint32]chan<- result // outstanding requests 47 48 closed chan struct{} 49 err error 50 } 51 52 // Wait blocks until the conn has shut down, and return the error 53 // causing the shutdown. It can be called concurrently from multiple 54 // goroutines. 55 func (c *clientConn) Wait() error { 56 <-c.closed 57 return c.err 58 } 59 60 // Close closes the SFTP session. 61 func (c *clientConn) Close() error { 62 defer c.wg.Wait() 63 return c.conn.Close() 64 } 65 66 // recv continuously reads from the server and forwards responses to the 67 // appropriate channel. 68 func (c *clientConn) recv() error { 69 defer c.conn.Close() 70 71 for { 72 typ, data, err := c.recvPacket(0) 73 if err != nil { 74 return err 75 } 76 sid, _, err := unmarshalUint32Safe(data) 77 if err != nil { 78 return err 79 } 80 81 ch, ok := c.getChannel(sid) 82 if !ok { 83 // This is an unexpected occurrence. Send the error 84 // back to all listeners so that they terminate 85 // gracefully. 86 return fmt.Errorf("sid not found: %d", sid) 87 } 88 89 ch <- result{typ: typ, data: data} 90 } 91 } 92 93 func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool { 94 c.Lock() 95 defer c.Unlock() 96 97 select { 98 case <-c.closed: 99 // already closed with broadcastErr, return error on chan. 100 ch <- result{err: ErrSSHFxConnectionLost} 101 return false 102 default: 103 } 104 105 c.inflight[sid] = ch 106 return true 107 } 108 109 func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { 110 c.Lock() 111 defer c.Unlock() 112 113 ch, ok := c.inflight[sid] 114 delete(c.inflight, sid) 115 116 return ch, ok 117 } 118 119 // result captures the result of receiving the a packet from the server 120 type result struct { 121 typ byte 122 data []byte 123 err error 124 } 125 126 type idmarshaler interface { 127 id() uint32 128 encoding.BinaryMarshaler 129 } 130 131 func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { 132 if cap(ch) < 1 { 133 ch = make(chan result, 1) 134 } 135 136 c.dispatchRequest(ch, p) 137 s := <-ch 138 return s.typ, s.data, s.err 139 } 140 141 // dispatchRequest should ideally only be called by race-detection tests outside of this file, 142 // where you have to ensure two packets are in flight sequentially after each other. 143 func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { 144 sid := p.id() 145 146 if !c.putChannel(ch, sid) { 147 // already closed. 148 return 149 } 150 151 if err := c.conn.sendPacket(p); err != nil { 152 if ch, ok := c.getChannel(sid); ok { 153 ch <- result{err: err} 154 } 155 } 156 } 157 158 // broadcastErr sends an error to all goroutines waiting for a response. 159 func (c *clientConn) broadcastErr(err error) { 160 c.Lock() 161 defer c.Unlock() 162 163 bcastRes := result{err: ErrSSHFxConnectionLost} 164 for sid, ch := range c.inflight { 165 ch <- bcastRes 166 167 // Replace the chan in inflight, 168 // we have hijacked this chan, 169 // and this guarantees always-only-once sending. 170 c.inflight[sid] = make(chan<- result, 1) 171 } 172 173 c.err = err 174 close(c.closed) 175 } 176 177 type serverConn struct { 178 conn 179 } 180 181 func (s *serverConn) sendError(id uint32, err error) error { 182 return s.sendPacket(statusFromError(id, err)) 183 }