github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/flipcall/flipcall_test.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package flipcall
    16  
    17  import (
    18  	"runtime"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/sync"
    23  )
    24  
    25  var testPacketWindowSize = pageSize
    26  
    27  type testConnection struct {
    28  	pwa      PacketWindowAllocator
    29  	clientEP Endpoint
    30  	serverEP Endpoint
    31  }
    32  
    33  func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection {
    34  	c := &testConnection{}
    35  	if err := c.pwa.Init(); err != nil {
    36  		tb.Fatalf("failed to create PacketWindowAllocator: %v", err)
    37  	}
    38  	pwd, err := c.pwa.Allocate(testPacketWindowSize)
    39  	if err != nil {
    40  		c.pwa.Destroy()
    41  		tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
    42  	}
    43  	if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
    44  		c.pwa.Destroy()
    45  		tb.Fatalf("failed to create client Endpoint: %v", err)
    46  	}
    47  	if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
    48  		c.pwa.Destroy()
    49  		c.clientEP.Destroy()
    50  		tb.Fatalf("failed to create server Endpoint: %v", err)
    51  	}
    52  	return c
    53  }
    54  
    55  func newTestConnection(tb testing.TB) *testConnection {
    56  	return newTestConnectionWithOptions(tb, nil, nil)
    57  }
    58  
    59  func (c *testConnection) destroy() {
    60  	c.pwa.Destroy()
    61  	c.clientEP.Destroy()
    62  	c.serverEP.Destroy()
    63  }
    64  
    65  func testSendRecv(t *testing.T, c *testConnection) {
    66  	// This shared variable is used to confirm that synchronization between
    67  	// flipcall endpoints is visible to the Go race detector.
    68  	state := 0
    69  	var serverRun sync.WaitGroup
    70  	serverRun.Add(1)
    71  	go func() {
    72  		defer serverRun.Done()
    73  		t.Logf("server Endpoint waiting for packet 1")
    74  		if _, err := c.serverEP.RecvFirst(); err != nil {
    75  			t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
    76  			return
    77  		}
    78  		state++
    79  		if state != 2 {
    80  			t.Errorf("shared state counter: got %d, wanted 2", state)
    81  		}
    82  		t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
    83  		if _, err := c.serverEP.SendRecv(0); err != nil {
    84  			t.Errorf("server Endpoint.SendRecv() failed: %v", err)
    85  			return
    86  		}
    87  		state++
    88  		if state != 4 {
    89  			t.Errorf("shared state counter: got %d, wanted 4", state)
    90  		}
    91  		t.Logf("server Endpoint got packet 3")
    92  	}()
    93  	defer func() {
    94  		// Ensure that the server goroutine is cleaned up before
    95  		// c.serverEP.Destroy(), even if the test fails.
    96  		c.serverEP.Shutdown()
    97  		serverRun.Wait()
    98  	}()
    99  
   100  	t.Logf("client Endpoint establishing connection")
   101  	if err := c.clientEP.Connect(); err != nil {
   102  		t.Fatalf("client Endpoint.Connect() failed: %v", err)
   103  	}
   104  	state++
   105  	if state != 1 {
   106  		t.Errorf("shared state counter: got %d, wanted 1", state)
   107  	}
   108  	t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
   109  	if _, err := c.clientEP.SendRecv(0); err != nil {
   110  		t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
   111  	}
   112  	state++
   113  	if state != 3 {
   114  		t.Errorf("shared state counter: got %d, wanted 3", state)
   115  	}
   116  	t.Logf("client Endpoint got packet 2, sending packet 3")
   117  	if err := c.clientEP.SendLast(0); err != nil {
   118  		t.Fatalf("client Endpoint.SendLast() failed: %v", err)
   119  	}
   120  	t.Logf("waiting for server goroutine to complete")
   121  	serverRun.Wait()
   122  }
   123  
   124  func TestSendRecv(t *testing.T) {
   125  	c := newTestConnection(t)
   126  	defer c.destroy()
   127  	testSendRecv(t, c)
   128  }
   129  
   130  func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
   131  	if remoteShutdown {
   132  		c.serverEP.Shutdown()
   133  	} else {
   134  		c.clientEP.Shutdown()
   135  	}
   136  	if err := c.clientEP.Connect(); err == nil {
   137  		t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
   138  	}
   139  }
   140  
   141  func TestShutdownBeforeConnectLocal(t *testing.T) {
   142  	c := newTestConnection(t)
   143  	defer c.destroy()
   144  	testShutdownBeforeConnect(t, c, false)
   145  }
   146  
   147  func TestShutdownBeforeConnectRemote(t *testing.T) {
   148  	c := newTestConnection(t)
   149  	defer c.destroy()
   150  	testShutdownBeforeConnect(t, c, true)
   151  }
   152  
   153  func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
   154  	var clientRun sync.WaitGroup
   155  	clientRun.Add(1)
   156  	go func() {
   157  		defer clientRun.Done()
   158  		if err := c.clientEP.Connect(); err == nil {
   159  			t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
   160  		}
   161  	}()
   162  	time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
   163  	if remoteShutdown {
   164  		c.serverEP.Shutdown()
   165  	} else {
   166  		c.clientEP.Shutdown()
   167  	}
   168  	clientRun.Wait()
   169  }
   170  
   171  func TestShutdownDuringConnectLocal(t *testing.T) {
   172  	c := newTestConnection(t)
   173  	defer c.destroy()
   174  	testShutdownDuringConnect(t, c, false)
   175  }
   176  
   177  func TestShutdownDuringConnectRemote(t *testing.T) {
   178  	c := newTestConnection(t)
   179  	defer c.destroy()
   180  	testShutdownDuringConnect(t, c, true)
   181  }
   182  
   183  func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
   184  	if remoteShutdown {
   185  		c.clientEP.Shutdown()
   186  	} else {
   187  		c.serverEP.Shutdown()
   188  	}
   189  	if _, err := c.serverEP.RecvFirst(); err == nil {
   190  		t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
   191  	}
   192  }
   193  
   194  func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
   195  	c := newTestConnection(t)
   196  	defer c.destroy()
   197  	testShutdownBeforeRecvFirst(t, c, false)
   198  }
   199  
   200  func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
   201  	c := newTestConnection(t)
   202  	defer c.destroy()
   203  	testShutdownBeforeRecvFirst(t, c, true)
   204  }
   205  
   206  func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
   207  	var serverRun sync.WaitGroup
   208  	serverRun.Add(1)
   209  	go func() {
   210  		defer serverRun.Done()
   211  		if _, err := c.serverEP.RecvFirst(); err == nil {
   212  			t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
   213  		}
   214  	}()
   215  	time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
   216  	if remoteShutdown {
   217  		c.clientEP.Shutdown()
   218  	} else {
   219  		c.serverEP.Shutdown()
   220  	}
   221  	serverRun.Wait()
   222  }
   223  
   224  func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
   225  	c := newTestConnection(t)
   226  	defer c.destroy()
   227  	testShutdownDuringRecvFirstBeforeConnect(t, c, false)
   228  }
   229  
   230  func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
   231  	c := newTestConnection(t)
   232  	defer c.destroy()
   233  	testShutdownDuringRecvFirstBeforeConnect(t, c, true)
   234  }
   235  
   236  func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
   237  	var serverRun sync.WaitGroup
   238  	serverRun.Add(1)
   239  	go func() {
   240  		defer serverRun.Done()
   241  		if _, err := c.serverEP.RecvFirst(); err == nil {
   242  			t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
   243  		}
   244  	}()
   245  	defer func() {
   246  		// Ensure that the server goroutine is cleaned up before
   247  		// c.serverEP.Destroy(), even if the test fails.
   248  		c.serverEP.Shutdown()
   249  		serverRun.Wait()
   250  	}()
   251  	if err := c.clientEP.Connect(); err != nil {
   252  		t.Fatalf("client Endpoint.Connect() failed: %v", err)
   253  	}
   254  	if remoteShutdown {
   255  		c.clientEP.Shutdown()
   256  	} else {
   257  		c.serverEP.Shutdown()
   258  	}
   259  	serverRun.Wait()
   260  }
   261  
   262  func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
   263  	c := newTestConnection(t)
   264  	defer c.destroy()
   265  	testShutdownDuringRecvFirstAfterConnect(t, c, false)
   266  }
   267  
   268  func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
   269  	c := newTestConnection(t)
   270  	defer c.destroy()
   271  	testShutdownDuringRecvFirstAfterConnect(t, c, true)
   272  }
   273  
   274  func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
   275  	var serverRun sync.WaitGroup
   276  	serverRun.Add(1)
   277  	go func() {
   278  		defer serverRun.Done()
   279  		if _, err := c.serverEP.RecvFirst(); err != nil {
   280  			t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
   281  		}
   282  		// At this point, the client must be blocked in c.clientEP.SendRecv().
   283  		if remoteShutdown {
   284  			c.serverEP.Shutdown()
   285  		} else {
   286  			c.clientEP.Shutdown()
   287  		}
   288  	}()
   289  	defer func() {
   290  		// Ensure that the server goroutine is cleaned up before
   291  		// c.serverEP.Destroy(), even if the test fails.
   292  		c.serverEP.Shutdown()
   293  		serverRun.Wait()
   294  	}()
   295  	if err := c.clientEP.Connect(); err != nil {
   296  		t.Fatalf("client Endpoint.Connect() failed: %v", err)
   297  	}
   298  	if _, err := c.clientEP.SendRecv(0); err == nil {
   299  		t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
   300  	}
   301  }
   302  
   303  func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
   304  	c := newTestConnection(t)
   305  	defer c.destroy()
   306  	testShutdownDuringClientSendRecv(t, c, false)
   307  }
   308  
   309  func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
   310  	c := newTestConnection(t)
   311  	defer c.destroy()
   312  	testShutdownDuringClientSendRecv(t, c, true)
   313  }
   314  
   315  func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
   316  	var serverRun sync.WaitGroup
   317  	serverRun.Add(1)
   318  	go func() {
   319  		defer serverRun.Done()
   320  		if _, err := c.serverEP.RecvFirst(); err != nil {
   321  			t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
   322  			return
   323  		}
   324  		if _, err := c.serverEP.SendRecv(0); err == nil {
   325  			t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
   326  		}
   327  	}()
   328  	defer func() {
   329  		// Ensure that the server goroutine is cleaned up before
   330  		// c.serverEP.Destroy(), even if the test fails.
   331  		c.serverEP.Shutdown()
   332  		serverRun.Wait()
   333  	}()
   334  	if err := c.clientEP.Connect(); err != nil {
   335  		t.Fatalf("client Endpoint.Connect() failed: %v", err)
   336  	}
   337  	if _, err := c.clientEP.SendRecv(0); err != nil {
   338  		t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
   339  	}
   340  	time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
   341  	if remoteShutdown {
   342  		c.clientEP.Shutdown()
   343  	} else {
   344  		c.serverEP.Shutdown()
   345  	}
   346  	serverRun.Wait()
   347  }
   348  
   349  func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
   350  	c := newTestConnection(t)
   351  	defer c.destroy()
   352  	testShutdownDuringServerSendRecv(t, c, false)
   353  }
   354  
   355  func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
   356  	c := newTestConnection(t)
   357  	defer c.destroy()
   358  	testShutdownDuringServerSendRecv(t, c, true)
   359  }
   360  
   361  func benchmarkSendRecv(b *testing.B, c *testConnection) {
   362  	var serverRun sync.WaitGroup
   363  	serverRun.Add(1)
   364  	go func() {
   365  		defer serverRun.Done()
   366  		if b.N == 0 {
   367  			return
   368  		}
   369  		if _, err := c.serverEP.RecvFirst(); err != nil {
   370  			b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
   371  			return
   372  		}
   373  		for i := 1; i < b.N; i++ {
   374  			if _, err := c.serverEP.SendRecv(0); err != nil {
   375  				b.Errorf("server Endpoint.SendRecv() failed: %v", err)
   376  				return
   377  			}
   378  		}
   379  		if err := c.serverEP.SendLast(0); err != nil {
   380  			b.Errorf("server Endpoint.SendLast() failed: %v", err)
   381  		}
   382  	}()
   383  	defer func() {
   384  		c.serverEP.Shutdown()
   385  		serverRun.Wait()
   386  	}()
   387  
   388  	if err := c.clientEP.Connect(); err != nil {
   389  		b.Fatalf("client Endpoint.Connect() failed: %v", err)
   390  	}
   391  	runtime.GC()
   392  	b.ResetTimer()
   393  	for i := 0; i < b.N; i++ {
   394  		if _, err := c.clientEP.SendRecv(0); err != nil {
   395  			b.Fatalf("client Endpoint.SendRecv() failed: %v", err)
   396  		}
   397  	}
   398  	b.StopTimer()
   399  }
   400  
   401  func BenchmarkSendRecv(b *testing.B) {
   402  	c := newTestConnection(b)
   403  	defer c.destroy()
   404  	benchmarkSendRecv(b, c)
   405  }