github.com/pkg/sftp@v1.13.6/conn.go (about)

     1  package sftp
     2  
     3  import (
     4  	"encoding"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  )
     9  
// conn implements a bidirectional channel on which client and server
// connections are multiplexed.
//
// Packets are read through the embedded io.Reader and written/closed through
// the embedded io.WriteCloser. conn.sendPacket and conn.Close both hold the
// embedded sync.Mutex, so a Close cannot interleave with a packet write and
// two writers cannot interleave their packets.
type conn struct {
	io.Reader
	io.WriteCloser
	// this is the same allocator used in packet manager; it is handed to
	// the package-level recvPacket by conn.recvPacket.
	alloc      *allocator
	sync.Mutex // used to serialise writes to sendPacket
}
    19  
    20  // the orderID is used in server mode if the allocator is enabled.
    21  // For the client mode just pass 0.
    22  // It returns io.EOF if the connection is closed and
    23  // there are no more packets to read.
    24  func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
    25  	return recvPacket(c, c.alloc, orderID)
    26  }
    27  
    28  func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
    29  	c.Lock()
    30  	defer c.Unlock()
    31  
    32  	return sendPacket(c, m)
    33  }
    34  
    35  func (c *conn) Close() error {
    36  	c.Lock()
    37  	defer c.Unlock()
    38  	return c.WriteCloser.Close()
    39  }
    40  
// clientConn is the client side of an SFTP connection. Each outstanding
// request is registered in inflight under its request ID, and responses read
// by recv are routed to the matching channel.
type clientConn struct {
	conn
	// wg is waited on by clientConn.Close before it returns.
	wg sync.WaitGroup

	// This Mutex is embedded at a shallower depth than the one inside conn,
	// so c.Lock()/c.Unlock() on a *clientConn refer to THIS lock (guarding
	// inflight). Packet writes are still serialised by conn's own mutex via
	// c.conn.sendPacket.
	sync.Mutex                          // protects inflight
	inflight   map[uint32]chan<- result // outstanding requests

	// closed is closed by broadcastErr when the connection shuts down; err
	// holds the cause recorded there and is what Wait returns.
	closed chan struct{}
	err    error
}
    51  
    52  // Wait blocks until the conn has shut down, and return the error
    53  // causing the shutdown. It can be called concurrently from multiple
    54  // goroutines.
    55  func (c *clientConn) Wait() error {
    56  	<-c.closed
    57  	return c.err
    58  }
    59  
    60  // Close closes the SFTP session.
    61  func (c *clientConn) Close() error {
    62  	defer c.wg.Wait()
    63  	return c.conn.Close()
    64  }
    65  
    66  // recv continuously reads from the server and forwards responses to the
    67  // appropriate channel.
    68  func (c *clientConn) recv() error {
    69  	defer c.conn.Close()
    70  
    71  	for {
    72  		typ, data, err := c.recvPacket(0)
    73  		if err != nil {
    74  			return err
    75  		}
    76  		sid, _, err := unmarshalUint32Safe(data)
    77  		if err != nil {
    78  			return err
    79  		}
    80  
    81  		ch, ok := c.getChannel(sid)
    82  		if !ok {
    83  			// This is an unexpected occurrence. Send the error
    84  			// back to all listeners so that they terminate
    85  			// gracefully.
    86  			return fmt.Errorf("sid not found: %d", sid)
    87  		}
    88  
    89  		ch <- result{typ: typ, data: data}
    90  	}
    91  }
    92  
    93  func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
    94  	c.Lock()
    95  	defer c.Unlock()
    96  
    97  	select {
    98  	case <-c.closed:
    99  		// already closed with broadcastErr, return error on chan.
   100  		ch <- result{err: ErrSSHFxConnectionLost}
   101  		return false
   102  	default:
   103  	}
   104  
   105  	c.inflight[sid] = ch
   106  	return true
   107  }
   108  
   109  func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
   110  	c.Lock()
   111  	defer c.Unlock()
   112  
   113  	ch, ok := c.inflight[sid]
   114  	delete(c.inflight, sid)
   115  
   116  	return ch, ok
   117  }
   118  
// result captures the result of receiving a packet from the server.
// In this file it is filled in one of two ways: recv sets typ and data from
// a received packet, while putChannel, dispatchRequest and broadcastErr set
// only err.
type result struct {
	typ  byte   // packet type as returned by recvPacket
	data []byte // packet payload as returned by recvPacket
	err  error
}
   125  
// idmarshaler is a request packet that can be sent by dispatchRequest:
// id() supplies the request ID under which the response channel is
// registered, and MarshalBinary encodes the packet for sending.
type idmarshaler interface {
	id() uint32
	encoding.BinaryMarshaler
}
   130  
   131  func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
   132  	if cap(ch) < 1 {
   133  		ch = make(chan result, 1)
   134  	}
   135  
   136  	c.dispatchRequest(ch, p)
   137  	s := <-ch
   138  	return s.typ, s.data, s.err
   139  }
   140  
   141  // dispatchRequest should ideally only be called by race-detection tests outside of this file,
   142  // where you have to ensure two packets are in flight sequentially after each other.
   143  func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
   144  	sid := p.id()
   145  
   146  	if !c.putChannel(ch, sid) {
   147  		// already closed.
   148  		return
   149  	}
   150  
   151  	if err := c.conn.sendPacket(p); err != nil {
   152  		if ch, ok := c.getChannel(sid); ok {
   153  			ch <- result{err: err}
   154  		}
   155  	}
   156  }
   157  
   158  // broadcastErr sends an error to all goroutines waiting for a response.
   159  func (c *clientConn) broadcastErr(err error) {
   160  	c.Lock()
   161  	defer c.Unlock()
   162  
   163  	bcastRes := result{err: ErrSSHFxConnectionLost}
   164  	for sid, ch := range c.inflight {
   165  		ch <- bcastRes
   166  
   167  		// Replace the chan in inflight,
   168  		// we have hijacked this chan,
   169  		// and this guarantees always-only-once sending.
   170  		c.inflight[sid] = make(chan<- result, 1)
   171  	}
   172  
   173  	c.err = err
   174  	close(c.closed)
   175  }
   176  
// serverConn is the server side of an SFTP connection. It adds only
// sendError on top of the embedded conn's packet reading and writing.
type serverConn struct {
	conn
}
   180  
   181  func (s *serverConn) sendError(id uint32, err error) error {
   182  	return s.sendPacket(statusFromError(id, err))
   183  }