github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/common_test.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain_test
    16  
    17  import (
    18  	"crypto/rand"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"log"
    23  	"net"
    24  	"os"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/pawelgaczynski/gain"
    31  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    32  	gainNet "github.com/pawelgaczynski/gain/pkg/net"
    33  	. "github.com/stretchr/testify/require"
    34  )
    35  
// testServerConfig describes one testServer scenario: how the gain server is
// configured and how many clients talk to it.
type testServerConfig struct {
	// protocol is used both in the server address ("<protocol>://...") and as
	// the dial network; testServer panics when it is empty.
	protocol string
	// numberOfClients is how many clients testServer dials.
	numberOfClients int
	// numberOfWorkers is passed to gain.WithWorkers.
	numberOfWorkers int
	// cpuAffinity is passed to gain.WithCPUAffinity.
	cpuAffinity bool
	// asyncHandler is passed to gain.WithAsyncHandler.
	asyncHandler bool
	// goroutinePool is passed to gain.WithGoroutinePool.
	goroutinePool bool
	// waitForDialAllClients makes testServer only dial the clients and wait
	// for one OnAccept per client instead of running afterDial round trips.
	waitForDialAllClients bool
	// afterDial runs on every dialed client; nil selects deafultAfterDial.
	afterDial afterDialCallback
	// writesCount is the number of round trips per client; 0 is treated as 1.
	writesCount int
	// configOptions are appended after testServer's default options.
	configOptions []gain.ConfigOption

	// readHandler is the OnRead callback; nil selects defaultTestOnReadCallback.
	readHandler onReadCallback
}
    50  
    51  var defaultTestOnReadCallback = func(c gain.Conn, n int, network string) {
    52  	buffer := make([]byte, 128)
    53  
    54  	_, err := c.Read(buffer)
    55  	if err != nil {
    56  		if errors.Is(err, gainErrors.ErrIsEmpty) {
    57  			return
    58  		}
    59  
    60  		log.Panic(err)
    61  	}
    62  
    63  	if string(buffer[0:6]) != "cindex" {
    64  		log.Panic(fmt.Errorf("unexpected data: %s", string(buffer[0:6])))
    65  	}
    66  
    67  	_, err = c.Write(append(buffer[0:10], []byte("TESTpayload12345")...))
    68  	if err != nil {
    69  		log.Panic(err)
    70  	}
    71  }
    72  
// callbacksHolder groups the optional per-event hooks that testServerHandler
// forwards to. A nil hook is simply skipped by the corresponding On* method.
type callbacksHolder struct {
	onStartCallback  onStartCallback
	onAcceptCallback onAcceptCallback
	onReadCallback   onReadCallback
	onWriteCallback  onWriteCallback
	onCloseCallback  onCloseCallback
}
    80  
// testServerHandler is the gain event handler used by the tests. Each On*
// method forwards to the matching hook in callbacksHolder, bumps a counter
// and signals an optional WaitGroup.
type testServerHandler struct {
	callbacksHolder

	// Number of events of each kind observed while finished was false.
	onStartCount  atomic.Uint32
	onAcceptCount atomic.Uint32
	onReadCount   atomic.Uint32
	onWriteCount  atomic.Uint32
	onCloseCount  atomic.Uint32

	// startedWg is always signalled by OnStart, so it must be non-nil. The
	// other WaitGroups are optional; Done is called only when they are non-nil.
	startedWg  *sync.WaitGroup
	onAcceptWg *sync.WaitGroup
	onReadWg   *sync.WaitGroup
	onWriteWg  *sync.WaitGroup
	onCloseWg  *sync.WaitGroup

	// finished, once set, turns every On* method into a no-op.
	finished atomic.Bool

	// network is passed as the last argument to every hook.
	network string
}
   100  
   101  func (h *testServerHandler) OnStart(server gain.Server) {
   102  	if !h.finished.Load() {
   103  		h.startedWg.Done()
   104  
   105  		if h.onStartCallback != nil {
   106  			h.onStartCallback(server, h.network)
   107  		}
   108  
   109  		h.onStartCount.Add(1)
   110  	}
   111  }
   112  
   113  func (h *testServerHandler) OnAccept(c gain.Conn) {
   114  	if !h.finished.Load() {
   115  		if h.onAcceptCallback != nil {
   116  			h.onAcceptCallback(c, h.network)
   117  		}
   118  
   119  		h.onAcceptCount.Add(1)
   120  
   121  		if h.onAcceptWg != nil {
   122  			h.onAcceptWg.Done()
   123  		}
   124  	}
   125  }
   126  
   127  func (h *testServerHandler) OnClose(c gain.Conn, err error) {
   128  	if !h.finished.Load() {
   129  		if h.onCloseCallback != nil {
   130  			h.onCloseCallback(c, err, h.network)
   131  		}
   132  
   133  		h.onCloseCount.Add(1)
   134  
   135  		if h.onCloseWg != nil {
   136  			h.onCloseWg.Done()
   137  		}
   138  	}
   139  }
   140  
   141  func (h *testServerHandler) OnRead(conn gain.Conn, n int) {
   142  	if !h.finished.Load() {
   143  		if h.onReadCallback != nil {
   144  			h.onReadCallback(conn, n, h.network)
   145  		}
   146  
   147  		h.onReadCount.Add(1)
   148  
   149  		if h.onReadWg != nil {
   150  			h.onReadWg.Done()
   151  		}
   152  	}
   153  }
   154  
   155  func (h *testServerHandler) OnWrite(c gain.Conn, n int) {
   156  	if !h.finished.Load() {
   157  		if h.onWriteCallback != nil {
   158  			h.onWriteCallback(c, n, h.network)
   159  		}
   160  
   161  		h.onWriteCount.Add(1)
   162  
   163  		if h.onWriteWg != nil {
   164  			h.onWriteWg.Done()
   165  		}
   166  	}
   167  }
   168  
// afterDialCallback is run by dialClientRW on every freshly dialed client
// connection. Its arguments are the test, the connection, the number of
// round trips to perform (testServerConfig.writesCount) and the client index.
type afterDialCallback func(*testing.T, net.Conn, int, int)
   170  
   171  var deafultAfterDial = func(t *testing.T, conn net.Conn, repeats, clientIndex int) {
   172  	t.Helper()
   173  	err := conn.SetDeadline(time.Now().Add(time.Second * 2))
   174  	Nil(t, err)
   175  
   176  	clientIndexBytes := []byte(fmt.Sprintf("cindex%04d", clientIndex))
   177  
   178  	for i := 0; i < repeats; i++ {
   179  		var bytesWritten int
   180  		bytesWritten, err = conn.Write(append(clientIndexBytes, []byte("testdata1234567890")...))
   181  
   182  		Nil(t, err)
   183  		Equal(t, 28, bytesWritten)
   184  		var buffer [64]byte
   185  		var bytesRead int
   186  		bytesRead, err = conn.Read(buffer[:])
   187  
   188  		Nil(t, err)
   189  		Equal(t, 26, bytesRead)
   190  		Equal(t, string(append(clientIndexBytes, "TESTpayload12345"...)),
   191  			string(buffer[:bytesRead]), "CONNFD: %d", getFdFromConn(conn))
   192  	}
   193  }
   194  
   195  func dialClient(t *testing.T, protocol string, port int, clientConnChan chan net.Conn) {
   196  	t.Helper()
   197  	conn, err := net.DialTimeout(protocol, fmt.Sprintf("127.0.0.1:%d", port), time.Second)
   198  	Nil(t, err)
   199  	NotNil(t, conn)
   200  	clientConnChan <- conn
   201  }
   202  
   203  func dialClientRW(t *testing.T, protocol string, port int,
   204  	afterDial afterDialCallback, repeats, clientIndex int, clientConnChan chan net.Conn,
   205  ) {
   206  	t.Helper()
   207  	conn, err := net.DialTimeout(protocol, fmt.Sprintf("127.0.0.1:%d", port), 2*time.Second)
   208  	Nil(t, err)
   209  	NotNil(t, conn)
   210  	afterDial(t, conn, repeats, clientIndex)
   211  	clientConnChan <- conn
   212  }
   213  
   214  func newTestServerHandler(onReadCallback onReadCallback, network string) *testServerHandler {
   215  	testHandler := &testServerHandler{
   216  		network: network,
   217  	}
   218  
   219  	var startedWg sync.WaitGroup
   220  
   221  	startedWg.Add(1)
   222  	testHandler.startedWg = &startedWg
   223  
   224  	if onReadCallback != nil {
   225  		testHandler.onReadCallback = onReadCallback
   226  	} else {
   227  		testHandler.onReadCallback = defaultTestOnReadCallback
   228  	}
   229  
   230  	return testHandler
   231  }
   232  
   233  func testServer(t *testing.T, testConfig testServerConfig, architecture gain.ServerArchitecture) {
   234  	t.Helper()
   235  
   236  	if testConfig.protocol == "" {
   237  		log.Panic("network protocol is missing")
   238  	}
   239  	opts := []gain.ConfigOption{
   240  		gain.WithLoggerLevel(getTestLoggerLevel()),
   241  		gain.WithAsyncHandler(testConfig.asyncHandler),
   242  		gain.WithGoroutinePool(testConfig.goroutinePool),
   243  		gain.WithCPUAffinity(testConfig.cpuAffinity),
   244  		gain.WithWorkers(testConfig.numberOfWorkers),
   245  		gain.WithCBPF(false),
   246  		gain.WithArchitecture(architecture),
   247  	}
   248  
   249  	if testConfig.configOptions != nil {
   250  		opts = append(opts, testConfig.configOptions...)
   251  	}
   252  
   253  	config := gain.NewConfig(opts...)
   254  
   255  	testHandler := newTestServerHandler(testConfig.readHandler, testConfig.protocol)
   256  
   257  	server := gain.NewServer(testHandler, config)
   258  
   259  	defer func() {
   260  		server.Shutdown()
   261  	}()
   262  	testPort := getTestPort()
   263  
   264  	go func() {
   265  		err := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", testConfig.protocol, testPort))
   266  		if err != nil {
   267  			log.Panic(err)
   268  		}
   269  	}()
   270  
   271  	clientConnChan := make(chan net.Conn, testConfig.numberOfClients)
   272  
   273  	testHandler.startedWg.Wait()
   274  
   275  	if testConfig.waitForDialAllClients {
   276  		clientConnectWG := new(sync.WaitGroup)
   277  		clientConnectWG.Add(testConfig.numberOfClients)
   278  		testHandler.onAcceptCallback = func(c gain.Conn, _ string) {
   279  			clientConnectWG.Done()
   280  		}
   281  
   282  		for i := 0; i < testConfig.numberOfClients; i++ {
   283  			go dialClient(t, testConfig.protocol, testPort, clientConnChan)
   284  		}
   285  		clientConnectWG.Wait()
   286  		Equal(t, testConfig.numberOfClients, server.ActiveConnections())
   287  
   288  		for i := 0; i < testConfig.numberOfClients; i++ {
   289  			conn := <-clientConnChan
   290  			NotNil(t, conn)
   291  
   292  			if tcpConn, ok := conn.(*net.TCPConn); ok {
   293  				err := tcpConn.SetLinger(0)
   294  				Nil(t, err)
   295  			}
   296  		}
   297  	} else {
   298  		var clientConnectWG *sync.WaitGroup
   299  		if testConfig.protocol == gainNet.TCP {
   300  			clientConnectWG = new(sync.WaitGroup)
   301  			clientConnectWG.Add(testConfig.numberOfClients)
   302  		}
   303  		clientRWWG := new(sync.WaitGroup)
   304  		if testConfig.writesCount == 0 {
   305  			testConfig.writesCount = 1
   306  		}
   307  		clientRWWG.Add(testConfig.numberOfClients * testConfig.writesCount)
   308  		if testConfig.protocol == gainNet.TCP {
   309  			testHandler.onAcceptCallback = func(c gain.Conn, _ string) {
   310  				clientConnectWG.Done()
   311  			}
   312  		}
   313  		testHandler.onWriteCallback = func(c gain.Conn, n int, network string) {
   314  			clientRWWG.Done()
   315  		}
   316  		afterDial := deafultAfterDial
   317  		if testConfig.afterDial != nil {
   318  			afterDial = testConfig.afterDial
   319  		}
   320  		for i := 0; i < testConfig.numberOfClients; i++ {
   321  			go func(clientIndex int) {
   322  				dialClientRW(t, testConfig.protocol, testPort, afterDial, testConfig.writesCount, clientIndex, clientConnChan)
   323  			}(i)
   324  		}
   325  		if testConfig.protocol == gainNet.TCP {
   326  			clientConnectWG.Wait()
   327  		}
   328  		clientRWWG.Wait()
   329  		for i := 0; i < testConfig.numberOfClients; i++ {
   330  			conn := <-clientConnChan
   331  			NotNil(t, conn)
   332  			if tcpConn, ok := conn.(*net.TCPConn); ok {
   333  				err := tcpConn.SetLinger(0)
   334  				Nil(t, err)
   335  			}
   336  		}
   337  	}
   338  }
   339  
// randomDataSize128 is a shared 128-byte payload. testRingBuffer fills it with
// crypto/rand bytes; both the client loop in testRingBuffer and
// RingBufferTestDataHandler.OnRead send it. NOTE(review): this is mutable
// package-level state, so tests using it should not run in parallel.
var randomDataSize128 = make([]byte, 128)
   341  
// RingBufferTestDataHandler provides the OnRead callback for testRingBuffer.
type RingBufferTestDataHandler struct {
	// t receives the handler's assertions.
	t *testing.T
	// testFinished, once set by the client side, stops assertions and replies.
	testFinished atomic.Bool
}
   346  
   347  func (r *RingBufferTestDataHandler) OnRead(conn gain.Conn, _ int, _ string) {
   348  	buffer := make([]byte, 128)
   349  	bytesRead, readErr := conn.Read(buffer)
   350  
   351  	if !r.testFinished.Load() {
   352  		Equal(r.t, 128, bytesRead)
   353  
   354  		if readErr != nil {
   355  			log.Panic(readErr)
   356  		}
   357  		bytesWritten, writeErr := conn.Write(randomDataSize128)
   358  		Equal(r.t, 128, bytesWritten)
   359  
   360  		if writeErr != nil {
   361  			log.Panic(writeErr)
   362  		}
   363  	}
   364  }
   365  
   366  func testRingBuffer(t *testing.T, protocol string, architecture gain.ServerArchitecture) {
   367  	t.Helper()
   368  	handler := RingBufferTestDataHandler{
   369  		t: t,
   370  	}
   371  	bytesRandom, err := rand.Read(randomDataSize128)
   372  	Nil(t, err)
   373  	Equal(t, 128, bytesRandom)
   374  	writesCount := 1000
   375  	testServer(t, testServerConfig{
   376  		numberOfClients: 1,
   377  		numberOfWorkers: 1,
   378  		protocol:        protocol,
   379  		readHandler:     handler.OnRead,
   380  		writesCount:     writesCount,
   381  		afterDial: func(t *testing.T, conn net.Conn, _, _ int) {
   382  			t.Helper()
   383  			deadlineErr := conn.SetDeadline(time.Now().Add(time.Second * 1))
   384  			Nil(t, deadlineErr)
   385  			var buffer [256]byte
   386  			for i := 0; i < writesCount; i++ {
   387  				bytesWritten, writeErr := conn.Write(randomDataSize128)
   388  				Nil(t, writeErr)
   389  				Equal(t, 128, bytesWritten)
   390  				bytesRead, readErr := conn.Read(buffer[:])
   391  				Nil(t, readErr)
   392  				Equal(t, 128, bytesRead)
   393  				Equal(t, randomDataSize128, buffer[:bytesRead])
   394  			}
   395  			handler.testFinished.Store(true)
   396  		},
   397  	}, architecture)
   398  }
   399  
   400  func testCloseServer(t *testing.T, network string, architecture gain.ServerArchitecture, doubleShutdown bool) {
   401  	t.Helper()
   402  	testHandler := newConnServerTester(network, 10, false)
   403  	server, port := newTestConnServer(t, network, false, architecture, testHandler.testServerHandler)
   404  	clientsGroup := newTestConnClientGroup(t, network, port, 10)
   405  	clientsGroup.Dial()
   406  
   407  	data := make([]byte, 512)
   408  
   409  	_, err := rand.Read(data)
   410  	Nil(t, err)
   411  	clientsGroup.SetDeadline(time.Now().Add(time.Second))
   412  	clientsGroup.Write(data)
   413  	buffer := make([]byte, 512)
   414  
   415  	clientsGroup.SetDeadline(time.Now().Add(time.Second))
   416  	clientsGroup.Read(buffer)
   417  
   418  	clientsGroup.SetDeadline(time.Time{})
   419  
   420  	testHandler.waitForWrites()
   421  	clientsGroup.Close()
   422  	server.Shutdown()
   423  
   424  	if doubleShutdown {
   425  		server.Shutdown()
   426  	}
   427  }
   428  
   429  func testCloseServerWithConnectedClients(t *testing.T, architecture gain.ServerArchitecture) {
   430  	t.Helper()
   431  	testHandler := newConnServerTester(gainNet.TCP, 10, false)
   432  	server, port := newTestConnServer(t, gainNet.TCP, false, architecture, testHandler.testServerHandler)
   433  
   434  	clientsGroup := newTestConnClientGroup(t, gainNet.TCP, port, 10)
   435  	clientsGroup.Dial()
   436  
   437  	data := make([]byte, 1024)
   438  	_, err := rand.Read(data)
   439  	Nil(t, err)
   440  	clientsGroup.Write(data)
   441  	buffer := make([]byte, 1024)
   442  	clientsGroup.Read(buffer)
   443  
   444  	testHandler.waitForWrites()
   445  	server.Shutdown()
   446  }
   447  
   448  func testCloseConn(t *testing.T, async bool, architecture gain.ServerArchitecture, justClose bool) {
   449  	t.Helper()
   450  	testHandler := newTestServerHandler(func(conn gain.Conn, n int, network string) {
   451  		if !justClose {
   452  			buf, err := conn.Next(n)
   453  			if err != nil {
   454  				log.Panic(err)
   455  			}
   456  
   457  			_, err = conn.Write(buf)
   458  			if err != nil {
   459  				log.Panic(err)
   460  			}
   461  		}
   462  
   463  		err := conn.Close()
   464  		if err != nil {
   465  			log.Panic(err)
   466  		}
   467  	}, gainNet.TCP)
   468  
   469  	server, port := newTestConnServer(t, gainNet.TCP, async, architecture, testHandler)
   470  
   471  	var clientDoneWg sync.WaitGroup
   472  
   473  	clientDoneWg.Add(1)
   474  
   475  	go func(wg *sync.WaitGroup) {
   476  		conn, cErr := net.DialTimeout(gainNet.TCP, fmt.Sprintf("127.0.0.1:%d", port), time.Second)
   477  		Nil(t, cErr)
   478  		NotNil(t, conn)
   479  		testData := []byte("testdata1234567890")
   480  		bytesN, cErr := conn.Write(testData)
   481  		Nil(t, cErr)
   482  		Equal(t, len(testData), bytesN)
   483  		buffer := make([]byte, len(testData))
   484  		bytesN, cErr = conn.Read(buffer)
   485  
   486  		if !justClose {
   487  			Nil(t, cErr)
   488  			Equal(t, len(testData), bytesN)
   489  			Equal(t, testData, buffer)
   490  			bytesN, cErr = conn.Write(testData)
   491  			Nil(t, cErr)
   492  			Equal(t, len(testData), bytesN)
   493  			bytesN, cErr = conn.Read(buffer)
   494  		}
   495  
   496  		Equal(t, io.EOF, cErr)
   497  		Equal(t, 0, bytesN)
   498  		wg.Done()
   499  	}(&clientDoneWg)
   500  
   501  	clientDoneWg.Wait()
   502  	server.Shutdown()
   503  }
   504  
   505  func testLargeRead(t *testing.T, network string, architecture gain.ServerArchitecture) {
   506  	t.Helper()
   507  
   508  	if !checkKernelCompatibility(5, 19) {
   509  		//nolint
   510  		fmt.Println("Not supported by kernel")
   511  
   512  		return
   513  	}
   514  
   515  	doublePageSize := os.Getpagesize() * 4
   516  	data := make([]byte, doublePageSize)
   517  	_, err := rand.Read(data)
   518  	Nil(t, err)
   519  
   520  	var doneWg sync.WaitGroup
   521  
   522  	doneWg.Add(1)
   523  	onReadCallback := func(c gain.Conn, _ int, _ string) {
   524  		readBuffer := make([]byte, doublePageSize)
   525  
   526  		n, cErr := c.Read(readBuffer)
   527  		if err != nil {
   528  			log.Panic(cErr)
   529  		}
   530  
   531  		doneWg.Done()
   532  		Equal(t, doublePageSize, n)
   533  
   534  		n, cErr = c.Write(readBuffer)
   535  		if cErr != nil {
   536  			log.Panic(cErr)
   537  		}
   538  
   539  		Equal(t, doublePageSize, n)
   540  	}
   541  
   542  	testConnHandler := newTestServerHandler(onReadCallback, network)
   543  	server, port := newTestConnServer(t, network, false, architecture, testConnHandler)
   544  
   545  	clientsGroup := newTestConnClientGroup(t, network, port, 1)
   546  	clientsGroup.Dial()
   547  
   548  	clientsGroup.Write(data)
   549  	buffer := make([]byte, len(data))
   550  	clientsGroup.Read(buffer)
   551  
   552  	Equal(t, data, buffer)
   553  
   554  	doneWg.Wait()
   555  
   556  	server.Shutdown()
   557  }
   558  
   559  func testMultipleReads(t *testing.T, network string, asyncHandler bool, architecture gain.ServerArchitecture) {
   560  	t.Helper()
   561  
   562  	pageSize := os.Getpagesize()
   563  	data := make([]byte, pageSize)
   564  	_, err := rand.Read(data)
   565  	Nil(t, err)
   566  
   567  	var (
   568  		doneWg        sync.WaitGroup
   569  		expectedReads int64 = 10
   570  		readsCount    atomic.Int64
   571  	)
   572  
   573  	doneWg.Add(int(expectedReads))
   574  	onReadCallback := func(c gain.Conn, _ int, _ string) {
   575  		readBuffer := make([]byte, pageSize)
   576  
   577  		n, cErr := c.Read(readBuffer)
   578  		if err != nil {
   579  			log.Panic(cErr)
   580  		}
   581  
   582  		readsCount.Add(1)
   583  		doneWg.Done()
   584  		Equal(t, pageSize, n)
   585  	}
   586  
   587  	testConnHandler := newTestServerHandler(onReadCallback, network)
   588  	server, port := newTestConnServer(t, network, asyncHandler, architecture, testConnHandler)
   589  
   590  	clientsGroup := newTestConnClientGroup(t, network, port, 1)
   591  	clientsGroup.Dial()
   592  
   593  	go func() {
   594  		for i := 0; i < int(expectedReads); i++ {
   595  			clientsGroup.Write(data)
   596  			time.Sleep(time.Millisecond * 100)
   597  		}
   598  	}()
   599  
   600  	doneWg.Wait()
   601  
   602  	Equal(t, expectedReads, readsCount.Load())
   603  
   604  	server.Shutdown()
   605  }