github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/dnet/common.go (about)

     1  // Package dnet contains alternative net.Conn implementations.
     2  package dnet
     3  
     4  import (
     5  	"bytes"
     6  	"net"
     7  	"os"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  )
    12  
    13  // Conn is a net.Conn, plus a Wait() method.
    14  type Conn interface {
    15  	net.Conn
    16  
    17  	// Wait until either the connection is Close()d, or one of Read() or Write() encounters an
    18  	// error (*not* counting errors caused by deadlines).  If this returns because Close() was
    19  	// called, nil is returned; otherwise the triggering error is returned.
    20  	//
    21  	// Essentially: Wait until the connection is finished.
    22  	Wait() error
    23  }
    24  
    25  type Addr struct {
    26  	Net  string
    27  	Addr string
    28  }
    29  
    30  func (a Addr) Network() string { return a.Net }
    31  func (a Addr) String() string  { return a.Addr }
    32  
    33  // UnbufferedConn represents a reliable fully-synchronous stream with *no* internal buffering.  But
    34  // really, what it is is "everything that isn't generic enough to be in bufferedConn".
    35  type UnbufferedConn interface {
    36  	// Receive some data over the connection.
    37  	Recv() ([]byte, error)
    38  
    39  	// Send some data over the connection.  Because the connection is fully-synchronous and has
    40  	// no internal buffering, Send must not return until the remote end acknowledges the full
    41  	// transmission (or an error is encountered).
    42  	Send([]byte) error
    43  
    44  	// MTU returns the largest amount of data that is permissible to include in a single Send
    45  	// call.
    46  	MTU() int
    47  
    48  	// CloseOnce closes both the read-end and write-end of the connection.  Any blocked Recv or
    49  	// Send operations will be unblocked and return errors.  It is an error to call CloseOnce
    50  	// more than once.
    51  	CloseOnce() error
    52  
    53  	// LocalAddr returns the local network address.
    54  	LocalAddr() net.Addr
    55  
    56  	// RemoteAddr returns the remote network address.
    57  	RemoteAddr() net.Addr
    58  }
    59  
    60  // bufferedConn is a net.Conn implementation that uses a reliable fully-synchronous stream as the
    61  // underlying transport.
    62  //
    63  // This at first appears more complex than it needs to be: Why have buffers and pump goroutines,
    64  // instead of simply synchronously calling .Send and .Recv on the underlying stream?  stdlib
    65  // net.TCPConn gets away with that for synchronously reading and writing the underlying file
    66  // descriptor, so why can't we get away this the same simplicity?  Because the OS kernel is doing
    67  // that same buffering and pumping for TCP; Go stdlib doesn't have to do it because it's happening
    68  // in kernel space.  But since we don't have a raw FD that the kernel can do things with, we have to
    69  // do that those things in userspace.
    70  type bufferedConn struct {
    71  	// configuration
    72  
    73  	conn UnbufferedConn
    74  
    75  	// state
    76  
    77  	closeOnce sync.Once
    78  	closed    int32 // atomic
    79  	closeErr  error
    80  
    81  	readCond     sync.Cond
    82  	readBuff     bytes.Buffer // must hold readCond.L to access
    83  	readErr      error        // must hold readCond.L to access
    84  	readDone     chan struct{}
    85  	readDeadline atomicDeadline
    86  
    87  	writeCond     sync.Cond
    88  	writeBuff     bytes.Buffer // must hold writeCond.L to access
    89  	writeErr      error        // must hold writeCond.L to access
    90  	writeDone     chan struct{}
    91  	writeDeadline atomicDeadline
    92  }
    93  
    94  func WrapUnbufferedConn(inner UnbufferedConn) Conn {
    95  	c := &bufferedConn{
    96  		conn: inner,
    97  
    98  		readDone:  make(chan struct{}),
    99  		writeDone: make(chan struct{}),
   100  
   101  		readCond:  sync.Cond{L: &sync.Mutex{}},
   102  		writeCond: sync.Cond{L: &sync.Mutex{}},
   103  	}
   104  
   105  	c.readDeadline = atomicDeadline{
   106  		cbMu: c.readCond.L,
   107  		cb:   c.readReset,
   108  	}
   109  	go c.readPump()
   110  
   111  	c.writeDeadline = atomicDeadline{
   112  		cbMu: c.writeCond.L,
   113  		cb:   c.writeReset,
   114  	}
   115  	go c.writePump()
   116  
   117  	return c
   118  }
   119  
   120  func (c *bufferedConn) isClosed() bool {
   121  	return atomic.LoadInt32(&c.closed) != 0
   122  }
   123  
   124  func (c *bufferedConn) readPump() {
   125  	defer close(c.readDone)
   126  
   127  	keepGoing := true
   128  	// use isClosedPipe(c.writeDone) instead of c.isClosed() to keep the readPump running just a
   129  	// little longer, in case the other end acking our writes is blocking on us acking their
   130  	// writes.
   131  	for keepGoing && !isClosedChan(c.writeDone) {
   132  		data, err := c.conn.Recv()
   133  
   134  		c.readCond.L.Lock()
   135  		if len(data) > 0 {
   136  			c.readBuff.Write(data)
   137  		}
   138  		if err != nil {
   139  			c.readErr = err
   140  			keepGoing = false
   141  		}
   142  		c.readCond.L.Unlock()
   143  
   144  		if len(data) > 0 || err != nil {
   145  			// .Broadcast() instead of .Signal() in case there are multiple waiting
   146  			// readers that are each asking for less than len(chunk.Content) bytes.
   147  			c.readCond.Broadcast()
   148  		}
   149  	}
   150  }
   151  
   152  // Read implements net.Conn.
   153  func (c *bufferedConn) Read(b []byte) (int, error) {
   154  	if len(b) == 0 {
   155  		return 0, nil
   156  	}
   157  
   158  	c.readCond.L.Lock()
   159  	defer c.readCond.L.Unlock()
   160  
   161  	for c.readBuff.Len() == 0 {
   162  		switch {
   163  		case c.readErr != nil:
   164  			return 0, c.readErr
   165  		case c.isClosed():
   166  			return 0, os.ErrClosed
   167  		case c.readDeadline.isCanceled():
   168  			return 0, os.ErrDeadlineExceeded
   169  		}
   170  		c.readCond.Wait()
   171  	}
   172  
   173  	return c.readBuff.Read(b)
   174  }
   175  
   176  // must hold c.readCond.Mu to call readReset.
   177  func (c *bufferedConn) readReset() {
   178  	c.readBuff.Reset()
   179  	// This isn't so much to notify readers of the readBuff.Reset(), but of *whatever caused it*.
   180  	c.readCond.Broadcast()
   181  }
   182  
   183  // Write implements net.Conn.
   184  func (c *bufferedConn) Write(b []byte) (int, error) {
   185  	c.writeCond.L.Lock()
   186  	defer c.writeCond.L.Unlock()
   187  
   188  	switch {
   189  	case c.writeErr != nil:
   190  		return 0, c.writeErr
   191  	case c.isClosed():
   192  		return 0, os.ErrClosed
   193  	case c.writeDeadline.isCanceled():
   194  		return 0, os.ErrDeadlineExceeded
   195  	}
   196  
   197  	n, err := c.writeBuff.Write(b)
   198  	if n > 0 {
   199  		// The only reader is the singular writePump, so don't bother with .Broadcast() when
   200  		// .Signal() is fine.
   201  		c.writeCond.Signal()
   202  	}
   203  	return n, err
   204  }
   205  
   206  // must hold c.writeCond.Mu to call writeReset.
   207  func (c *bufferedConn) writeReset() {
   208  	c.writeBuff.Reset()
   209  	// Don't bother with c.writeCond.Broadcast() because this will only ever make the condition
   210  	// false.
   211  }
   212  
   213  func (c *bufferedConn) writePump() {
   214  	defer close(c.writeDone)
   215  
   216  	var buff []byte
   217  
   218  	for {
   219  		// Get the data to write.
   220  		c.writeCond.L.Lock()
   221  		for c.writeBuff.Len() == 0 && !c.isClosed() {
   222  			c.writeCond.Wait()
   223  		}
   224  		if c.writeBuff.Len() == 0 {
   225  			// closed
   226  			c.writeCond.L.Unlock()
   227  			return
   228  		}
   229  		tu := c.writeBuff.Len() // "transmission unit", as in "MTU"
   230  		if mtu := c.conn.MTU(); mtu > 0 && tu > mtu {
   231  			tu = mtu
   232  		}
   233  		if len(buff) < tu {
   234  			buff = make([]byte, tu)
   235  		}
   236  		n, _ := c.writeBuff.Read(buff[:tu])
   237  		c.writeCond.L.Unlock()
   238  
   239  		// Write the data.
   240  		if err := c.conn.Send(buff[:n]); err != nil {
   241  			c.writeCond.L.Lock()
   242  			c.writeErr = err
   243  			c.writeCond.L.Unlock()
   244  			return
   245  		}
   246  	}
   247  }
   248  
   249  // Close implements net.Conn.  Both the read-end and the write-end are closed.  Any blocked Read or
   250  // Write operations will be unblocked and return errors.
   251  func (c *bufferedConn) Close() error {
   252  	c.closeOnce.Do(func() {
   253  		atomic.StoreInt32(&c.closed, 1)
   254  
   255  		// Don't c.writeReset(), let the write buffer drain normally; otherwise the user has
   256  		// no way of ensuring that the write went through.  This is consistent with close(2)
   257  		// semantics on most unixes.
   258  		c.writeCond.Signal() // if writePump is blocked, notify it of the change to c.closed
   259  
   260  		// OTOH: c.readReset().  Unlike the write buffer, we do forcefully reset this
   261  		// instead of letting it drain normally.  If you close something you're reading
   262  		// before you've received EOF, you are liable to lose data; this is no different.
   263  		// It's only happenstance that we even have that data in our buffer; it might as
   264  		// well still be in transit on a slow wire.
   265  		c.readCond.L.Lock()
   266  		c.readReset()
   267  		c.readCond.L.Unlock()
   268  
   269  		// Wait for writePump to drain (triggered above).
   270  		<-c.writeDone
   271  
   272  		// readPump might be blocked on c.conn.Recv(), so we might need to force it closed
   273  		// to interrupt that.
   274  		c.writeCond.L.Lock()
   275  		c.closeErr = c.conn.CloseOnce()
   276  		c.writeCond.L.Unlock()
   277  
   278  		// Wait for readPump to drain (triggered above).
   279  		<-c.readDone
   280  	})
   281  	return c.closeErr
   282  }
   283  
   284  // LocalAddr implements net.Conn.
   285  func (c *bufferedConn) LocalAddr() net.Addr {
   286  	return c.conn.LocalAddr()
   287  }
   288  
   289  // RemoteAddr implements net.Conn.
   290  func (c *bufferedConn) RemoteAddr() net.Addr {
   291  	return c.conn.RemoteAddr()
   292  }
   293  
   294  // SetDeadline implements net.Conn.
   295  func (c *bufferedConn) SetDeadline(t time.Time) error {
   296  	if c.isClosed() {
   297  		return os.ErrClosed
   298  	}
   299  	c.readDeadline.set(t)
   300  	c.writeDeadline.set(t)
   301  	return nil
   302  }
   303  
   304  // SetReadDeadline implements net.Conn.
   305  func (c *bufferedConn) SetReadDeadline(t time.Time) error {
   306  	if c.isClosed() {
   307  		return os.ErrClosed
   308  	}
   309  	c.readDeadline.set(t)
   310  	return nil
   311  }
   312  
   313  // SetWriteDeadline implements net.Conn.
   314  func (c *bufferedConn) SetWriteDeadline(t time.Time) error {
   315  	if c.isClosed() {
   316  		return os.ErrClosed
   317  	}
   318  	c.writeDeadline.set(t)
   319  	return nil
   320  }
   321  
   322  // Wait implements dnet.Conn.
   323  func (c *bufferedConn) Wait() error {
   324  	<-c.readDone
   325  	if c.readErr != nil {
   326  		return c.readErr
   327  	}
   328  	if c.writeErr != nil {
   329  		return c.writeErr
   330  	}
   331  	if c.closeErr != nil {
   332  		return c.closeErr
   333  	}
   334  	return nil
   335  }
   336  
   337  func isClosedChan(c <-chan struct{}) bool {
   338  	select {
   339  	case <-c:
   340  		return true
   341  	default:
   342  		return false
   343  	}
   344  }