github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/common_conn_test.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain_test
    16  
import (
	"crypto/rand"
	"fmt"
	"io"
	"log"
	"net"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/pawelgaczynski/gain"
	. "github.com/stretchr/testify/require"
)
    32  
    33  type connServerTester struct {
    34  	*testServerHandler
    35  	mutex                  sync.Mutex
    36  	writeWG                *sync.WaitGroup
    37  	writeCount             uint32
    38  	targetWriteCount       uint32
    39  	removeWGAfterMinWrites bool
    40  }
    41  
    42  func (t *connServerTester) waitForWrites() {
    43  	t.writeWG.Wait()
    44  }
    45  
    46  func (t *connServerTester) onReadCallback(conn gain.Conn, n int, _ string) {
    47  	buf, _ := conn.Next(n)
    48  	_, _ = conn.Write(buf)
    49  }
    50  
    51  func (t *connServerTester) onWriteCallback(_ gain.Conn, _ int, _ string) {
    52  	if t.writeWG != nil {
    53  		t.mutex.Lock()
    54  
    55  		t.writeCount++
    56  		if t.writeCount >= t.targetWriteCount {
    57  			t.writeWG.Done()
    58  
    59  			if t.removeWGAfterMinWrites {
    60  				t.writeWG = nil
    61  			}
    62  		}
    63  		t.mutex.Unlock()
    64  	}
    65  }
    66  
    67  func newConnServerTester(network string, writeCount int, removeWGAfterMinWrites bool) *connServerTester {
    68  	connServerTester := &connServerTester{}
    69  
    70  	if writeCount > 0 {
    71  		var writeWG sync.WaitGroup
    72  
    73  		writeWG.Add(1)
    74  		connServerTester.writeWG = &writeWG
    75  		connServerTester.targetWriteCount = uint32(writeCount)
    76  		connServerTester.removeWGAfterMinWrites = removeWGAfterMinWrites
    77  	}
    78  
    79  	testConnHandler := newTestServerHandler(connServerTester.onReadCallback, network)
    80  
    81  	testConnHandler.onWriteCallback = connServerTester.onWriteCallback
    82  	connServerTester.testServerHandler = testConnHandler
    83  
    84  	return connServerTester
    85  }
    86  
    87  func newEventHandlerTester(callbacks callbacksHolder, network string) *testServerHandler {
    88  	testHandler := &testServerHandler{
    89  		network: network,
    90  	}
    91  
    92  	var (
    93  		startedWg  sync.WaitGroup
    94  		onAcceptWg sync.WaitGroup
    95  		onReadWg   sync.WaitGroup
    96  		onWriteWg  sync.WaitGroup
    97  		onCloseWg  sync.WaitGroup
    98  	)
    99  
   100  	startedWg.Add(1)
   101  	testHandler.startedWg = &startedWg
   102  	testHandler.onAcceptWg = &onAcceptWg
   103  	testHandler.onReadWg = &onReadWg
   104  	testHandler.onWriteWg = &onWriteWg
   105  	testHandler.onCloseWg = &onCloseWg
   106  
   107  	testHandler.onStartCallback = callbacks.onStartCallback
   108  	testHandler.onAcceptCallback = callbacks.onAcceptCallback
   109  	testHandler.onReadCallback = callbacks.onReadCallback
   110  	testHandler.onWriteCallback = callbacks.onWriteCallback
   111  	testHandler.onCloseCallback = callbacks.onCloseCallback
   112  
   113  	return testHandler
   114  }
   115  
   116  type testConnClient struct {
   117  	t       *testing.T
   118  	conn    net.Conn
   119  	network string
   120  	port    int
   121  	idx     int
   122  }
   123  
   124  func (c *testConnClient) Dial() {
   125  	conn, err := net.DialTimeout(c.network, fmt.Sprintf("127.0.0.1:%d", c.port), time.Second)
   126  	Nil(c.t, err)
   127  	NotNil(c.t, conn)
   128  	c.conn = conn
   129  }
   130  
   131  func (c *testConnClient) Close() {
   132  	err := c.conn.Close()
   133  	Nil(c.t, err)
   134  }
   135  
   136  func (c *testConnClient) SetDeadline(t time.Time) {
   137  	err := c.conn.SetDeadline(t)
   138  	Nil(c.t, err)
   139  }
   140  
   141  func (c *testConnClient) Write(buffer []byte) {
   142  	bytesWritten, writeErr := c.conn.Write(buffer)
   143  	Nil(c.t, writeErr)
   144  	Equal(c.t, len(buffer), bytesWritten)
   145  }
   146  
   147  func (c *testConnClient) Read(buffer []byte) {
   148  	bytesRead, readErr := c.conn.Read(buffer)
   149  	Nil(c.t, readErr)
   150  	Equal(c.t, len(buffer), bytesRead)
   151  }
   152  
   153  func newTestConnClient(t *testing.T, idx int, network string, port int) *testConnClient {
   154  	t.Helper()
   155  
   156  	return &testConnClient{
   157  		t:       t,
   158  		network: network,
   159  		port:    port,
   160  		idx:     idx,
   161  	}
   162  }
   163  
   164  type testConnClientGroup struct {
   165  	clients []*testConnClient
   166  }
   167  
   168  func (c *testConnClientGroup) Dial() {
   169  	for i := 0; i < len(c.clients); i++ {
   170  		c.clients[i].Dial()
   171  	}
   172  }
   173  
   174  func (c *testConnClientGroup) Close() {
   175  	for i := 0; i < len(c.clients); i++ {
   176  		c.clients[i].Close()
   177  	}
   178  }
   179  
   180  func (c *testConnClientGroup) SetDeadline(t time.Time) {
   181  	for i := 0; i < len(c.clients); i++ {
   182  		c.clients[i].SetDeadline(t)
   183  	}
   184  }
   185  
   186  func (c *testConnClientGroup) Write(buffer []byte) {
   187  	for i := 0; i < len(c.clients); i++ {
   188  		c.clients[i].Write(buffer)
   189  	}
   190  }
   191  
   192  func (c *testConnClientGroup) Read(buffer []byte) {
   193  	for i := 0; i < len(c.clients); i++ {
   194  		c.clients[i].Read(buffer)
   195  	}
   196  }
   197  
   198  func newTestConnClientGroup(t *testing.T, network string, port int, n int) *testConnClientGroup {
   199  	t.Helper()
   200  	group := &testConnClientGroup{
   201  		clients: make([]*testConnClient, n),
   202  	}
   203  
   204  	for i := 0; i < n; i++ {
   205  		group.clients[i] = newTestConnClient(t, i, network, port)
   206  	}
   207  
   208  	return group
   209  }
   210  
   211  func newTestConnServer(
   212  	t *testing.T, network string, async bool, architecture gain.ServerArchitecture, eventHandler *testServerHandler,
   213  ) (gain.Server, int) {
   214  	t.Helper()
   215  	opts := []gain.ConfigOption{
   216  		gain.WithLoggerLevel(getTestLoggerLevel()),
   217  		gain.WithWorkers(4),
   218  		gain.WithArchitecture(architecture),
   219  		gain.WithAsyncHandler(async),
   220  		gain.WithMaxSQEntries(1024),
   221  		gain.WithMaxCQEvents(1024),
   222  	}
   223  
   224  	config := gain.NewConfig(opts...)
   225  
   226  	server := gain.NewServer(eventHandler, config)
   227  	testPort := getTestPort()
   228  
   229  	go func() {
   230  		err := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", network, testPort))
   231  		if err != nil {
   232  			log.Panic(err)
   233  		}
   234  	}()
   235  
   236  	eventHandler.startedWg.Wait()
   237  
   238  	return server, int(port)
   239  }
   240  
   241  func getIPAndPort(addr net.Addr) (string, int) {
   242  	switch addr := addr.(type) {
   243  	case *net.UDPAddr:
   244  		return addr.IP.String(), addr.Port
   245  	case *net.TCPAddr:
   246  		return addr.IP.String(), addr.Port
   247  	}
   248  
   249  	return "", 0
   250  }
   251  
   252  func testConnAddress(
   253  	t *testing.T, network string, architecture gain.ServerArchitecture,
   254  ) {
   255  	t.Helper()
   256  	numberOfClients := 10
   257  	opts := []gain.ConfigOption{
   258  		gain.WithLoggerLevel(getTestLoggerLevel()),
   259  		gain.WithWorkers(4),
   260  		gain.WithArchitecture(architecture),
   261  	}
   262  
   263  	config := gain.NewConfig(opts...)
   264  
   265  	out, err := exec.Command("bash", "-c", "sysctl net.ipv4.ip_local_port_range | awk '{ print $3; }'").Output()
   266  	if err != nil {
   267  		log.Panic(err)
   268  	}
   269  
   270  	lowestEphemeralPort, err := strconv.Atoi(strings.ReplaceAll(string(out), "\n", ""))
   271  	if err != nil {
   272  		log.Panic(err)
   273  	}
   274  
   275  	verifyAddresses := func(t *testing.T, conn gain.Conn) {
   276  		t.Helper()
   277  		localAddr := conn.LocalAddr()
   278  		NotNil(t, localAddr)
   279  
   280  		ip, port := getIPAndPort(localAddr)
   281  		Equal(t, "127.0.0.1", ip)
   282  		Less(t, port, 10000)
   283  		GreaterOrEqual(t, port, 9000)
   284  		remoteAddr := conn.RemoteAddr()
   285  
   286  		ip, port = getIPAndPort(remoteAddr)
   287  		NotNil(t, remoteAddr)
   288  		Equal(t, "127.0.0.1", ip)
   289  		GreaterOrEqual(t, port, lowestEphemeralPort)
   290  	}
   291  
   292  	var wg sync.WaitGroup
   293  
   294  	wg.Add(numberOfClients)
   295  
   296  	onReadCallback := func(conn gain.Conn, n int, _ string) {
   297  		buf, _ := conn.Next(n)
   298  		_, _ = conn.Write(buf)
   299  
   300  		verifyAddresses(t, conn)
   301  
   302  		wg.Done()
   303  	}
   304  
   305  	testHandler := newTestServerHandler(onReadCallback, network)
   306  
   307  	server := gain.NewServer(testHandler, config)
   308  	testPort := getTestPort()
   309  
   310  	go func() {
   311  		serverErr := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", network, testPort))
   312  		if err != nil {
   313  			log.Panic(serverErr)
   314  		}
   315  	}()
   316  
   317  	testHandler.startedWg.Wait()
   318  
   319  	clientsGroup := newTestConnClientGroup(t, network, testPort, numberOfClients)
   320  	clientsGroup.Dial()
   321  
   322  	data := make([]byte, 1024)
   323  	_, err = rand.Read(data)
   324  	Nil(t, err)
   325  	clientsGroup.Write(data)
   326  	buffer := make([]byte, 1024)
   327  	clientsGroup.Read(buffer)
   328  
   329  	wg.Wait()
   330  	server.Shutdown()
   331  }