github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/net.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package testutils
    12  
    13  import (
    14  	"context"
    15  	"io"
    16  	"net"
    17  	"sync"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/util/log"
    20  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    21  	"github.com/cockroachdb/errors"
    22  )
    23  
    24  // bufferSize is the size of the buffer used by PartitionableConn. Writes to a
    25  // partitioned connection will block after the buffer gets filled.
    26  const bufferSize = 16 << 10 // 16 KB
    27  
    28  // PartitionableConn is an implementation of net.Conn that allows the
    29  // client->server and/or the server->client directions to be temporarily
    30  // partitioned.
    31  //
    32  // A PartitionableConn wraps a provided net.Conn (the serverConn member) and
    33  // forwards every read and write to it. It interposes an arbiter in front of it
    34  // that's used to block reads/writes while the PartitionableConn is in the
    35  // partitioned mode.
    36  //
    37  // While a direction is partitioned, data sent in that direction doesn't flow. A
    38  // write while partitioned will block after an internal buffer gets filled. Data
    39  // written to the conn after the partition has been established is not delivered
    40  // to the remote party until the partition is lifted. At that time, all the
    41  // buffered data is delivered. Since data is delivered async, data written
    42  // before the partition is established may or may not be blocked by the
    43  // partition; use application-level ACKs if that's important.
    44  type PartitionableConn struct {
    45  	// We embed a net.Conn so that we inherit the interface. Note that we override
    46  	// Read() and Write().
    47  	//
    48  	// This embedded Conn is half of a net.Pipe(). The other half is clientConn.
    49  	net.Conn
    50  
    51  	clientConn net.Conn
    52  	serverConn net.Conn
    53  
    54  	mu struct {
    55  		syncutil.Mutex
    56  
    57  		// err, if set, is returned by any subsequent call to Read or Write.
    58  		err error
    59  
    60  		// Are any of the two direction (client-to-server, server-to-client)
    61  		// currently partitioned?
    62  		c2sPartitioned bool
    63  		s2cPartitioned bool
    64  
    65  		c2sBuffer buf
    66  		s2cBuffer buf
    67  
    68  		// Conds to be signaled when the corresponding partition is lifted.
    69  		c2sWaiter *sync.Cond
    70  		s2cWaiter *sync.Cond
    71  	}
    72  }
    73  
    74  type buf struct {
    75  	// A mutex used to synchronize access to all the fields. It will be set to the
    76  	// parent PartitionableConn's mutex.
    77  	*syncutil.Mutex
    78  
    79  	data     []byte
    80  	capacity int
    81  	closed   bool
    82  	// The error that was passed to Close(err). See Close() for more info.
    83  	closedErr error
    84  	name      string // A human-readable name, useful for debugging.
    85  
    86  	// readerWait is signaled when the reader should wake up and check the
    87  	// buffer's state: when new data is put in the buffer, when the buffer is
    88  	// closed, and whenever the PartitionableConn wants to unblock all reads (i.e.
    89  	// on partition).
    90  	readerWait *sync.Cond
    91  
    92  	// capacityWait is signaled when a blocked writer should wake up because data
    93  	// is taken out of the buffer and there's now some capacity. It's also
    94  	// signaled when the buffer is closed.
    95  	capacityWait *sync.Cond
    96  }
    97  
    98  func makeBuf(name string, capacity int, mu *syncutil.Mutex) buf {
    99  	return buf{
   100  		Mutex:        mu,
   101  		name:         name,
   102  		capacity:     capacity,
   103  		readerWait:   sync.NewCond(mu),
   104  		capacityWait: sync.NewCond(mu),
   105  	}
   106  }
   107  
   108  // Write adds data to the buffer. If there's zero free capacity, it will block
   109  // until there's some capacity available or the buffer is closed. If there's
   110  // non-zero insufficient capacity, it will perform a partial write.
   111  //
   112  // The number of bytes written is returned.
   113  func (b *buf) Write(data []byte) (int, error) {
   114  	b.Lock()
   115  	defer b.Unlock()
   116  	for b.capacity == len(b.data) && !b.closed {
   117  		// Block for capacity.
   118  		b.capacityWait.Wait()
   119  	}
   120  	if b.closed {
   121  		return 0, b.closedErr
   122  	}
   123  	available := b.capacity - len(b.data)
   124  	toCopy := available
   125  	if len(data) < available {
   126  		toCopy = len(data)
   127  	}
   128  	b.data = append(b.data, data[:toCopy]...)
   129  	b.wakeReaderLocked()
   130  	return toCopy, nil
   131  }
   132  
   133  // errEAgain is returned by buf.readLocked() when the read was blocked at the
   134  // time when buf.readerWait was signaled (in particular, after the
   135  // PartitionableConn interrupted the read because of a partition). The caller is
   136  // expected to try the read again after the partition is gone.
   137  var errEAgain = errors.New("try read again")
   138  
   139  // readLocked returns data from buf, up to "size" bytes. If there's no data in
   140  // the buffer, it blocks until either some data becomes available or the buffer
   141  // is closed.
   142  func (b *buf) readLocked(size int) ([]byte, error) {
   143  	if len(b.data) == 0 && !b.closed {
   144  		b.readerWait.Wait()
   145  		// We were unblocked either by data arrving, or by a partition, or by
   146  		// another uninteresting reason. Return to the caller, in case it's because
   147  		// of a partition.
   148  		return nil, errEAgain
   149  	}
   150  	if b.closed && len(b.data) == 0 {
   151  		return nil, b.closedErr
   152  	}
   153  	var ret []byte
   154  	if len(b.data) < size {
   155  		ret = b.data
   156  		b.data = nil
   157  	} else {
   158  		ret = b.data[:size]
   159  		b.data = b.data[size:]
   160  	}
   161  	b.capacityWait.Signal()
   162  	return ret, nil
   163  }
   164  
   165  // Close closes the buffer. All reads and writes that are currently blocked will
   166  // be woken and they'll all return err.
   167  func (b *buf) Close(err error) {
   168  	b.Lock()
   169  	defer b.Unlock()
   170  	b.closed = true
   171  	b.closedErr = err
   172  	b.readerWait.Signal()
   173  	b.capacityWait.Signal()
   174  }
   175  
   176  // wakeReaderLocked wakes the reader in case it's blocked.
   177  // See comments on readerWait.
   178  //
   179  // This needs to be called while holding the buffer's mutex.
   180  func (b *buf) wakeReaderLocked() {
   181  	b.readerWait.Signal()
   182  }
   183  
   184  // NewPartitionableConn wraps serverConn in a PartitionableConn.
   185  func NewPartitionableConn(serverConn net.Conn) *PartitionableConn {
   186  	clientEnd, clientConn := net.Pipe()
   187  	c := &PartitionableConn{
   188  		Conn:       clientEnd,
   189  		clientConn: clientConn,
   190  		serverConn: serverConn,
   191  	}
   192  	c.mu.c2sWaiter = sync.NewCond(&c.mu.Mutex)
   193  	c.mu.s2cWaiter = sync.NewCond(&c.mu.Mutex)
   194  	c.mu.c2sBuffer = makeBuf("c2sBuf", bufferSize, &c.mu.Mutex)
   195  	c.mu.s2cBuffer = makeBuf("s2cBuf", bufferSize, &c.mu.Mutex)
   196  
   197  	// Start copying from client to server.
   198  	go func() {
   199  		err := c.copy(
   200  			c.clientConn, // src
   201  			c.serverConn, // dst
   202  			&c.mu.c2sBuffer,
   203  			func() { // waitForNoPartitionLocked
   204  				for c.mu.c2sPartitioned {
   205  					c.mu.c2sWaiter.Wait()
   206  				}
   207  			})
   208  		c.mu.Lock()
   209  		c.mu.err = err
   210  		c.mu.Unlock()
   211  		if err := c.clientConn.Close(); err != nil {
   212  			log.Errorf(context.TODO(), "unexpected error closing internal pipe: %s", err)
   213  		}
   214  		if err := c.serverConn.Close(); err != nil {
   215  			log.Errorf(context.TODO(), "error closing server conn: %s", err)
   216  		}
   217  	}()
   218  
   219  	// Start copying from server to client.
   220  	go func() {
   221  		err := c.copy(
   222  			c.serverConn, // src
   223  			c.clientConn, // dst
   224  			&c.mu.s2cBuffer,
   225  			func() { // waitForNoPartitionLocked
   226  				for c.mu.s2cPartitioned {
   227  					c.mu.s2cWaiter.Wait()
   228  				}
   229  			})
   230  		c.mu.Lock()
   231  		c.mu.err = err
   232  		c.mu.Unlock()
   233  		if err := c.clientConn.Close(); err != nil {
   234  			log.Fatalf(context.TODO(), "unexpected error closing internal pipe: %s", err)
   235  		}
   236  		if err := c.serverConn.Close(); err != nil {
   237  			log.Errorf(context.TODO(), "error closing server conn: %s", err)
   238  		}
   239  	}()
   240  
   241  	return c
   242  }
   243  
   244  // Finish removes any partitions that may exist so that blocked goroutines can
   245  // finish.
   246  // Finish() must be called if a connection may have been left in a partitioned
   247  // state.
   248  func (c *PartitionableConn) Finish() {
   249  	c.mu.Lock()
   250  	defer c.mu.Unlock()
   251  	c.mu.c2sPartitioned = false
   252  	c.mu.c2sWaiter.Signal()
   253  	c.mu.s2cPartitioned = false
   254  	c.mu.s2cWaiter.Signal()
   255  }
   256  
   257  // PartitionC2S partitions the client-to-server direction.
   258  // If UnpartitionC2S() is not called, Finish() must be called.
   259  func (c *PartitionableConn) PartitionC2S() {
   260  	c.mu.Lock()
   261  	defer c.mu.Unlock()
   262  	if c.mu.c2sPartitioned {
   263  		panic("already partitioned")
   264  	}
   265  	c.mu.c2sPartitioned = true
   266  	c.mu.c2sBuffer.wakeReaderLocked()
   267  }
   268  
   269  // UnpartitionC2S lifts an existing client-to-server partition.
   270  func (c *PartitionableConn) UnpartitionC2S() {
   271  	c.mu.Lock()
   272  	defer c.mu.Unlock()
   273  	if !c.mu.c2sPartitioned {
   274  		panic("not partitioned")
   275  	}
   276  	c.mu.c2sPartitioned = false
   277  	c.mu.c2sWaiter.Signal()
   278  }
   279  
   280  // PartitionS2C partitions the server-to-client direction.
   281  // If UnpartitionS2C() is not called, Finish() must be called.
   282  func (c *PartitionableConn) PartitionS2C() {
   283  	c.mu.Lock()
   284  	defer c.mu.Unlock()
   285  	if c.mu.s2cPartitioned {
   286  		panic("already partitioned")
   287  	}
   288  	c.mu.s2cPartitioned = true
   289  	c.mu.s2cBuffer.wakeReaderLocked()
   290  }
   291  
   292  // UnpartitionS2C lifts an existing server-to-client partition.
   293  func (c *PartitionableConn) UnpartitionS2C() {
   294  	c.mu.Lock()
   295  	defer c.mu.Unlock()
   296  	if !c.mu.s2cPartitioned {
   297  		panic("not partitioned")
   298  	}
   299  	c.mu.s2cPartitioned = false
   300  	c.mu.s2cWaiter.Signal()
   301  }
   302  
   303  // Read is part of the net.Conn interface.
   304  func (c *PartitionableConn) Read(b []byte) (n int, err error) {
   305  	c.mu.Lock()
   306  	err = c.mu.err
   307  	c.mu.Unlock()
   308  	if err != nil {
   309  		return 0, err
   310  	}
   311  
   312  	// Forward to the embedded connection.
   313  	return c.Conn.Read(b)
   314  }
   315  
   316  // Write is part of the net.Conn interface.
   317  func (c *PartitionableConn) Write(b []byte) (n int, err error) {
   318  	c.mu.Lock()
   319  	err = c.mu.err
   320  	c.mu.Unlock()
   321  	if err != nil {
   322  		return 0, err
   323  	}
   324  
   325  	// Forward to the embedded connection.
   326  	return c.Conn.Write(b)
   327  }
   328  
   329  // readFrom copies data from src into the buffer until src.Read() returns an
   330  // error (e.g. io.EOF). That error is returned.
   331  //
   332  // readFrom is written in the spirit of interface io.ReaderFrom, except it
   333  // returns the io.EOF error, and also doesn't guarantee that every byte that has
   334  // been read from src is put into the buffer (as the buffer allows concurrent
   335  // access and buf.Write can return an error).
   336  func (b *buf) readFrom(src io.Reader) error {
   337  	data := make([]byte, 1024)
   338  	for {
   339  		nr, err := src.Read(data)
   340  		if err != nil {
   341  			return err
   342  		}
   343  		toSend := data[:nr]
   344  		for {
   345  			nw, ew := b.Write(toSend)
   346  			if ew != nil {
   347  				return ew
   348  			}
   349  			if nw == len(toSend) {
   350  				break
   351  			}
   352  			toSend = toSend[nw:]
   353  		}
   354  	}
   355  }
   356  
   357  // copyFromBuffer copies data from src to dst until src.Read() returns EOF.
   358  // The EOF is returned (i.e. the return value is always != nil). This is because
   359  // the PartitionableConn wants to hold on to any error, including EOF.
   360  //
   361  // waitForNoPartitionLocked is a function to be called before consuming data
   362  // from src, in order to make sure that we only consume data when we're not
   363  // partitioned. It needs to be called under src.Mutex, as the check needs to be
   364  // done atomically with consuming the buffer's data.
   365  func (c *PartitionableConn) copyFromBuffer(
   366  	src *buf, dst net.Conn, waitForNoPartitionLocked func(),
   367  ) error {
   368  	for {
   369  		// Don't read from the buffer while we're partitioned.
   370  		src.Mutex.Lock()
   371  		waitForNoPartitionLocked()
   372  		data, err := src.readLocked(1024 * 1024)
   373  		src.Mutex.Unlock()
   374  
   375  		if len(data) > 0 {
   376  			nw, ew := dst.Write(data)
   377  			if ew != nil {
   378  				err = ew
   379  			}
   380  			if len(data) != nw {
   381  				err = io.ErrShortWrite
   382  			}
   383  		} else if err == nil {
   384  			err = io.EOF
   385  		} else if errors.Is(err, errEAgain) {
   386  			continue
   387  		}
   388  		if err != nil {
   389  			return err
   390  		}
   391  	}
   392  }
   393  
   394  // copy copies data from src to dst while we're not partitioned and stops doing
   395  // so while partitioned.
   396  //
   397  // It runs two goroutines internally: one copying from src to an internal buffer
   398  // and one copying from the buffer to dst. The 2nd one deals with partitions.
   399  func (c *PartitionableConn) copy(
   400  	src net.Conn, dst net.Conn, buf *buf, waitForNoPartitionLocked func(),
   401  ) error {
   402  	tasks := make(chan error)
   403  	go func() {
   404  		err := buf.readFrom(src)
   405  		buf.Close(err)
   406  		tasks <- err
   407  	}()
   408  	go func() {
   409  		err := c.copyFromBuffer(buf, dst, waitForNoPartitionLocked)
   410  		buf.Close(err)
   411  		tasks <- err
   412  	}()
   413  	err := <-tasks
   414  	err2 := <-tasks
   415  	if err == nil {
   416  		err = err2
   417  	}
   418  	return err
   419  }