github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/net_test.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package testutils
    12  
    13  import (
    14  	"bufio"
    15  	"fmt"
    16  	"io"
    17  	"net"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/util"
    22  	"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
    23  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    24  	"github.com/cockroachdb/cockroach/pkg/util/netutil"
    25  	"github.com/cockroachdb/errors"
    26  )
    27  
    28  // RunEchoServer runs a network server that accepts one connection from ln and
    29  // echos the data sent on it.
    30  //
    31  // If serverSideCh != nil, every slice of data received by the server is also
    32  // sent on this channel before being echoed back on the connection it came on.
    33  // Useful to observe what the server has received when this server is used with
    34  // partitioned connections.
    35  func RunEchoServer(ln net.Listener, serverSideCh chan<- []byte) error {
    36  	conn, err := ln.Accept()
    37  	if err != nil {
    38  		if grpcutil.IsClosedConnection(err) {
    39  			return nil
    40  		}
    41  		return err
    42  	}
    43  	if _, err := copyWithSideChan(conn, conn, serverSideCh); err != nil {
    44  		return err
    45  	}
    46  	return nil
    47  }
    48  
    49  // copyWithSideChan is like io.Copy(), but also takes a channel on which data
    50  // read from src is sent before being written to dst.
    51  func copyWithSideChan(dst io.Writer, src io.Reader, ch chan<- []byte) (written int64, err error) {
    52  	buf := make([]byte, 32*1024)
    53  	for {
    54  		nr, er := src.Read(buf)
    55  		if nr > 0 {
    56  			if ch != nil {
    57  				ch <- buf[:nr]
    58  			}
    59  
    60  			nw, ew := dst.Write(buf[0:nr])
    61  			if nw > 0 {
    62  				written += int64(nw)
    63  			}
    64  			if ew != nil {
    65  				err = ew
    66  				break
    67  			}
    68  			if nr != nw {
    69  				err = io.ErrShortWrite
    70  				break
    71  			}
    72  		}
    73  		if er != nil {
    74  			if er != io.EOF {
    75  				err = er
    76  			}
    77  			break
    78  		}
    79  	}
    80  	return written, err
    81  }
    82  
    83  func TestPartitionableConnBasic(t *testing.T) {
    84  	defer leaktest.AfterTest(t)()
    85  	addr := util.TestAddr
    86  	ln, err := net.Listen(addr.Network(), addr.String())
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	go func() {
    91  		if err := RunEchoServer(ln, nil); err != nil {
    92  			t.Error(err)
    93  		}
    94  	}()
    95  	defer func() {
    96  		netutil.FatalIfUnexpected(ln.Close())
    97  	}()
    98  
    99  	serverConn, err := net.Dial("tcp", ln.Addr().String())
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  
   104  	pConn := NewPartitionableConn(serverConn)
   105  	defer pConn.Close()
   106  
   107  	exp := "let's see if this value comes back\n"
   108  	fmt.Fprint(pConn, exp)
   109  	got, err := bufio.NewReader(pConn).ReadString('\n')
   110  	if err != nil {
   111  		t.Fatal(err)
   112  	}
   113  	if got != exp {
   114  		t.Fatalf("expecting: %q , got %q", exp, got)
   115  	}
   116  }
   117  
   118  func TestPartitionableConnPartitionC2S(t *testing.T) {
   119  	defer leaktest.AfterTest(t)()
   120  
   121  	addr := util.TestAddr
   122  	ln, err := net.Listen(addr.Network(), addr.String())
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  	serverSideCh := make(chan []byte)
   127  	go func() {
   128  		if err := RunEchoServer(ln, serverSideCh); err != nil {
   129  			t.Error(err)
   130  		}
   131  	}()
   132  	defer func() {
   133  		netutil.FatalIfUnexpected(ln.Close())
   134  	}()
   135  
   136  	serverConn, err := net.Dial("tcp", ln.Addr().String())
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  
   141  	pConn := NewPartitionableConn(serverConn)
   142  	defer pConn.Close()
   143  
   144  	// Partition the client->server connection. Afterwards, we're going to send
   145  	// something and assert that the server doesn't get it (within a timeout) by
   146  	// snooping on the server's side channel. Then we'll resolve the partition and
   147  	// expect that the server gets the message that was pending and echoes it
   148  	// back.
   149  
   150  	pConn.PartitionC2S()
   151  
   152  	// Client sends data.
   153  	exp := "let's see when this value comes back\n"
   154  	fmt.Fprint(pConn, exp)
   155  
   156  	// In the background, the client waits on a read.
   157  	clientDoneCh := make(chan error)
   158  	go func() {
   159  		clientDoneCh <- func() error {
   160  			got, err := bufio.NewReader(pConn).ReadString('\n')
   161  			if err != nil {
   162  				return err
   163  			}
   164  			if got != exp {
   165  				return errors.Errorf("expecting: %q , got %q", exp, got)
   166  			}
   167  			return nil
   168  		}()
   169  	}()
   170  
   171  	timerDoneCh := make(chan error)
   172  	time.AfterFunc(3*time.Millisecond, func() {
   173  		var err error
   174  		select {
   175  		case err = <-clientDoneCh:
   176  			err = errors.Errorf("unexpected reply while partitioned: %v", err)
   177  		case buf := <-serverSideCh:
   178  			err = errors.Errorf("server was not supposed to have received data while partitioned: %q", buf)
   179  		default:
   180  		}
   181  		timerDoneCh <- err
   182  	})
   183  
   184  	if err := <-timerDoneCh; err != nil {
   185  		t.Fatal(err)
   186  	}
   187  
   188  	// Now unpartition and expect the pending data to be sent and a reply to be
   189  	// received.
   190  
   191  	pConn.UnpartitionC2S()
   192  
   193  	// Expect the server to receive the data.
   194  	<-serverSideCh
   195  
   196  	if err := <-clientDoneCh; err != nil {
   197  		t.Fatal(err)
   198  	}
   199  }
   200  
   201  func TestPartitionableConnPartitionS2C(t *testing.T) {
   202  	defer leaktest.AfterTest(t)()
   203  
   204  	addr := util.TestAddr
   205  	ln, err := net.Listen(addr.Network(), addr.String())
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  	serverSideCh := make(chan []byte)
   210  	go func() {
   211  		if err := RunEchoServer(ln, serverSideCh); err != nil {
   212  			t.Error(err)
   213  		}
   214  	}()
   215  	defer func() {
   216  		netutil.FatalIfUnexpected(ln.Close())
   217  	}()
   218  
   219  	serverConn, err := net.Dial("tcp", ln.Addr().String())
   220  	if err != nil {
   221  		t.Fatal(err)
   222  	}
   223  
   224  	pConn := NewPartitionableConn(serverConn)
   225  	defer pConn.Close()
   226  
   227  	// We're going to partition the server->client connection. Then we'll send
   228  	// some data and assert that the server gets it (by snooping on the server's
   229  	// side-channel). Then we'll assert that the client doesn't get the reply
   230  	// (with a timeout). Then we resolve the partition and assert that the client
   231  	// gets the reply.
   232  
   233  	pConn.PartitionS2C()
   234  
   235  	// Client sends data.
   236  	exp := "let's see when this value comes back\n"
   237  	fmt.Fprint(pConn, exp)
   238  
   239  	if s := <-serverSideCh; string(s) != exp {
   240  		t.Fatalf("expected server to receive %q, got %q", exp, s)
   241  	}
   242  
   243  	// In the background, the client waits on a read.
   244  	clientDoneCh := make(chan error)
   245  	go func() {
   246  		clientDoneCh <- func() error {
   247  			got, err := bufio.NewReader(pConn).ReadString('\n')
   248  			if err != nil {
   249  				return err
   250  			}
   251  			if got != exp {
   252  				return errors.Errorf("expecting: %q , got %q", exp, got)
   253  			}
   254  			return nil
   255  		}()
   256  	}()
   257  
   258  	// Check that the client does not get the server's response.
   259  	time.AfterFunc(3*time.Millisecond, func() {
   260  		select {
   261  		case err := <-clientDoneCh:
   262  			t.Errorf("unexpected reply while partitioned: %v", err)
   263  		default:
   264  		}
   265  	})
   266  
   267  	// Now unpartition and expect the pending data to be sent and a reply to be
   268  	// received.
   269  
   270  	pConn.UnpartitionS2C()
   271  
   272  	if err := <-clientDoneCh; err != nil {
   273  		t.Fatal(err)
   274  	}
   275  }
   276  
   277  // Test that, while partitioned, a sender doesn't block while the internal
   278  // buffer is not full.
   279  func TestPartitionableConnBuffering(t *testing.T) {
   280  	defer leaktest.AfterTest(t)()
   281  
   282  	addr := util.TestAddr
   283  	ln, err := net.Listen(addr.Network(), addr.String())
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  
   288  	// In the background, the server reads everything.
   289  	exp := 5 * (bufferSize / 10)
   290  	serverDoneCh := make(chan error)
   291  	go func() {
   292  		serverDoneCh <- func() error {
   293  			conn, err := ln.Accept()
   294  			if err != nil {
   295  				return err
   296  			}
   297  			received := 0
   298  			for {
   299  				data := make([]byte, 1024*1024)
   300  				nr, err := conn.Read(data)
   301  				if err != nil {
   302  					if err == io.EOF {
   303  						break
   304  					}
   305  					return err
   306  				}
   307  				received += nr
   308  			}
   309  			if received != exp {
   310  				return errors.Errorf("server expecting: %d , got %d", exp, received)
   311  			}
   312  			return nil
   313  		}()
   314  	}()
   315  
   316  	serverConn, err := net.Dial("tcp", ln.Addr().String())
   317  	if err != nil {
   318  		t.Fatal(err)
   319  	}
   320  
   321  	pConn := NewPartitionableConn(serverConn)
   322  	defer pConn.Close()
   323  
   324  	pConn.PartitionC2S()
   325  	defer pConn.Finish()
   326  
   327  	// Send chunks such that they don't add up to the buffer size exactly.
   328  	data := make([]byte, bufferSize/10)
   329  	for i := 0; i < 5; i++ {
   330  		nw, err := pConn.Write(data)
   331  		if err != nil {
   332  			t.Fatal(err)
   333  		}
   334  		if nw != len(data) {
   335  			t.Fatal("unexpected partial write; PartitionableConn always writes fully")
   336  		}
   337  	}
   338  	pConn.UnpartitionC2S()
   339  	pConn.Close()
   340  
   341  	if err := <-serverDoneCh; err != nil {
   342  		t.Fatal(err)
   343  	}
   344  }
   345  
   346  // Test that, while partitioned, a party can close the connection and the other
   347  // party will not observe this until after the partition is lifted.
   348  func TestPartitionableConnCloseDeliveredAfterPartition(t *testing.T) {
   349  	defer leaktest.AfterTest(t)()
   350  
   351  	addr := util.TestAddr
   352  	ln, err := net.Listen(addr.Network(), addr.String())
   353  	if err != nil {
   354  		t.Fatal(err)
   355  	}
   356  
   357  	// In the background, the server reads everything.
   358  	serverDoneCh := make(chan error)
   359  	go func() {
   360  		serverDoneCh <- func() error {
   361  			conn, err := ln.Accept()
   362  			if err != nil {
   363  				return err
   364  			}
   365  			received := 0
   366  			for {
   367  				data := make([]byte, 1<<20 /* 1 MiB */)
   368  				nr, err := conn.Read(data)
   369  				if err != nil {
   370  					if err == io.EOF {
   371  						return nil
   372  					}
   373  					return err
   374  				}
   375  				received += nr
   376  			}
   377  		}()
   378  	}()
   379  
   380  	serverConn, err := net.Dial("tcp", ln.Addr().String())
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  
   385  	pConn := NewPartitionableConn(serverConn)
   386  	defer pConn.Close()
   387  
   388  	pConn.PartitionC2S()
   389  	defer pConn.Finish()
   390  
   391  	pConn.Close()
   392  
   393  	timerDoneCh := make(chan error)
   394  	time.AfterFunc(3*time.Millisecond, func() {
   395  		var err error
   396  		select {
   397  		case err = <-serverDoneCh:
   398  			err = errors.Wrap(err, "server was not supposed to see the closing while partitioned")
   399  		default:
   400  		}
   401  		timerDoneCh <- err
   402  	})
   403  
   404  	if err := <-timerDoneCh; err != nil {
   405  		t.Fatal(err)
   406  	}
   407  
   408  	pConn.UnpartitionC2S()
   409  
   410  	if err := <-serverDoneCh; err != nil {
   411  		t.Fatal(err)
   412  	}
   413  }