github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/event_handler_test.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain_test
    16  
    17  import (
    18  	"crypto/rand"
    19  	"errors"
    20  	"fmt"
    21  	"log"
    22  	"net"
    23  	"sync"
    24  	"syscall"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/pawelgaczynski/gain"
    29  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    30  	gainNet "github.com/pawelgaczynski/gain/pkg/net"
    31  	. "github.com/stretchr/testify/require"
    32  )
    33  
    34  const (
    35  	tcp = iota
    36  	udp
    37  	both
    38  )
    39  
    40  type clientBehavior func(net.Conn)
    41  
    42  func testHandlerMethod(
    43  	t *testing.T, network string, asyncHandler bool, architecture gain.ServerArchitecture,
    44  	callbacks callbacksHolder, clientBehavior clientBehavior, callCounts []int, shutdown bool,
    45  ) {
    46  	t.Helper()
    47  	Equal(t, 4, len(callCounts))
    48  
    49  	eventHandlerTester := newEventHandlerTester(callbacks, network)
    50  	eventHandlerTester.onAcceptWg.Add(callCounts[0])
    51  	eventHandlerTester.onReadWg.Add(callCounts[1])
    52  	eventHandlerTester.onWriteWg.Add(callCounts[2])
    53  	eventHandlerTester.onCloseWg.Add(callCounts[3])
    54  
    55  	server, port := newTestConnServer(t, network, asyncHandler, architecture, eventHandlerTester)
    56  
    57  	conn, err := net.DialTimeout(network, fmt.Sprintf("127.0.0.1:%d", port), time.Second)
    58  	if err != nil && !errors.Is(err, syscall.ECONNRESET) {
    59  		conn, err = net.DialTimeout(network, fmt.Sprintf("127.0.0.1:%d", port), time.Second)
    60  		if err != nil {
    61  			log.Panic(err)
    62  		}
    63  	}
    64  
    65  	clientBehavior(conn)
    66  
    67  	if callCounts[0] > 0 {
    68  		eventHandlerTester.onAcceptWg.Wait()
    69  	}
    70  
    71  	if callCounts[1] > 0 {
    72  		eventHandlerTester.onReadWg.Wait()
    73  	}
    74  
    75  	if callCounts[2] > 0 {
    76  		eventHandlerTester.onWriteWg.Wait()
    77  	}
    78  
    79  	if callCounts[3] > 0 {
    80  		eventHandlerTester.onCloseWg.Wait()
    81  	}
    82  
    83  	eventHandlerTester.finished.Store(true)
    84  
    85  	Equal(t, 1, int(eventHandlerTester.onStartCount.Load()))
    86  	Equal(t, callCounts[0], int(eventHandlerTester.onAcceptCount.Load()))
    87  	Equal(t, callCounts[1], int(eventHandlerTester.onReadCount.Load()))
    88  	Equal(t, callCounts[2], int(eventHandlerTester.onWriteCount.Load()))
    89  	Equal(t, callCounts[3], int(eventHandlerTester.onCloseCount.Load()))
    90  
    91  	if shutdown {
    92  		server.Shutdown()
    93  	}
    94  }
    95  
    96  const eventHandlerTestDataSize = 512
    97  
    98  var eventHandlerTestData = func() []byte {
    99  	data := make([]byte, eventHandlerTestDataSize)
   100  	_, err := rand.Read(data)
   101  	if err != nil {
   102  		log.Panic(err)
   103  	}
   104  
   105  	return data
   106  }()
   107  
   108  type eventHandlerTestCase struct {
   109  	name           string
   110  	network        string
   111  	async          bool
   112  	architecture   gain.ServerArchitecture
   113  	callbacks      callbacksHolder
   114  	clientBehavior clientBehavior
   115  	callCounts     []int
   116  }
   117  
   118  func testEventHandler(t *testing.T, testCases []eventHandlerTestCase, shutdown bool) {
   119  	t.Helper()
   120  
   121  	for _, testCase := range testCases {
   122  		t.Run(testCase.name, func(t *testing.T) {
   123  			testHandlerMethod(
   124  				t, testCase.network, testCase.async, testCase.architecture,
   125  				testCase.callbacks, testCase.clientBehavior, testCase.callCounts, shutdown,
   126  			)
   127  		})
   128  	}
   129  }
   130  
   131  func createTestCases(
   132  	suffix string, networks int, callbacks callbacksHolder, clientBehavior clientBehavior, callCounts [][]int,
   133  ) []eventHandlerTestCase {
   134  	tcpTestCases := []eventHandlerTestCase{}
   135  
   136  	if networks == tcp || networks == both {
   137  		tcpTestCases = append(tcpTestCases, []eventHandlerTestCase{
   138  			{
   139  				fmt.Sprintf("TestShardingTCPSync%s", suffix),
   140  				gainNet.TCP, false, gain.SocketSharding, callbacks, clientBehavior, callCounts[0],
   141  			},
   142  			{
   143  				fmt.Sprintf("TestShardingTCPAsync%s", suffix),
   144  				gainNet.TCP, true, gain.SocketSharding, callbacks, clientBehavior, callCounts[1],
   145  			},
   146  			{
   147  				fmt.Sprintf("TestReactorTCPSync%s", suffix),
   148  				gainNet.TCP, false, gain.Reactor, callbacks, clientBehavior, callCounts[2],
   149  			},
   150  			{
   151  				fmt.Sprintf("TestReactorTCPAsync%s", suffix),
   152  				gainNet.TCP, true, gain.Reactor, callbacks, clientBehavior, callCounts[3],
   153  			},
   154  		}...)
   155  	}
   156  
   157  	udpTestCases := []eventHandlerTestCase{}
   158  
   159  	if networks == udp || networks == both {
   160  		var index int
   161  		if networks == both {
   162  			index = 4
   163  		}
   164  
   165  		udpTestCases = append(udpTestCases, []eventHandlerTestCase{
   166  			{
   167  				fmt.Sprintf("TestShardingUDPSync%s", suffix),
   168  				gainNet.UDP, false, gain.SocketSharding, callbacks, clientBehavior, callCounts[index],
   169  			},
   170  			{
   171  				fmt.Sprintf("TestShardingUDPAsync%s", suffix),
   172  				gainNet.UDP, true, gain.SocketSharding, callbacks, clientBehavior, callCounts[index+1],
   173  			},
   174  		}...)
   175  	}
   176  	testCases := []eventHandlerTestCase{}
   177  	testCases = append(testCases, tcpTestCases...)
   178  	testCases = append(testCases, udpTestCases...)
   179  
   180  	return testCases
   181  }
   182  
   183  func TestEventHandlerOnRead(t *testing.T) {
   184  	callbacks := callbacksHolder{
   185  		onReadCallback: func(conn gain.Conn, n int, network string) {
   186  			buffer, err := conn.Next(n)
   187  			Nil(t, err)
   188  			Equal(t, eventHandlerTestData, buffer)
   189  		},
   190  	}
   191  	clientBehavior := func(conn net.Conn) {
   192  		err := conn.SetWriteDeadline(time.Now().Add(time.Millisecond * 500))
   193  		if err != nil {
   194  			log.Panic(err)
   195  		}
   196  
   197  		n, err := conn.Write(eventHandlerTestData)
   198  		Equal(t, eventHandlerTestDataSize, n)
   199  		Nil(t, err)
   200  		buffer := make([]byte, 1024)
   201  
   202  		err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * 500))
   203  		if err != nil {
   204  			log.Panic(err)
   205  		}
   206  		n, err = conn.Read(buffer)
   207  		Equal(t, n, 0)
   208  		NotNil(t, err)
   209  		conn.Close()
   210  	}
   211  
   212  	testCases := createTestCases("JustRead", both, callbacks, clientBehavior, [][]int{
   213  		{1, 1, 0, 1},
   214  		{1, 1, 0, 1},
   215  		{1, 1, 0, 1},
   216  		{1, 1, 0, 1},
   217  		{0, 1, 0, 0},
   218  		{0, 1, 0, 0},
   219  	})
   220  
   221  	testEventHandler(t, testCases, true)
   222  
   223  	callbacks = callbacksHolder{
   224  		onReadCallback: func(conn gain.Conn, n int, network string) {
   225  			buffer, err := conn.Next(n)
   226  			Nil(t, err)
   227  			Equal(t, eventHandlerTestData, buffer)
   228  			bytesWritten, err := conn.Write(buffer)
   229  			Nil(t, err)
   230  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   231  		},
   232  		onWriteCallback: func(conn gain.Conn, n int, network string) {
   233  			buf, err := conn.Next(-1)
   234  			Equal(t, 0, len(buf))
   235  			Nil(t, err)
   236  		},
   237  	}
   238  	clientBehavior = func(conn net.Conn) {
   239  		n, err := conn.Write(eventHandlerTestData)
   240  		Equal(t, eventHandlerTestDataSize, n)
   241  		Nil(t, err)
   242  		buffer := make([]byte, eventHandlerTestDataSize*2)
   243  		n, err = conn.Read(buffer)
   244  		Equal(t, eventHandlerTestDataSize, n)
   245  		Nil(t, err)
   246  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   247  		conn.Close()
   248  	}
   249  
   250  	testCases = createTestCases("ReadAndWrite", both, callbacks, clientBehavior, [][]int{
   251  		{1, 1, 1, 1},
   252  		{1, 1, 1, 1},
   253  		{1, 1, 1, 1},
   254  		{1, 1, 1, 1},
   255  		{0, 1, 1, 0},
   256  		{0, 1, 1, 0},
   257  	})
   258  
   259  	testEventHandler(t, testCases, true)
   260  
   261  	callbacks = callbacksHolder{
   262  		onReadCallback: func(conn gain.Conn, n int, network string) {
   263  			buffer, err := conn.Next(-1)
   264  			Nil(t, err)
   265  			Equal(t, eventHandlerTestData, buffer)
   266  			bytesWritten, err := conn.Write(buffer)
   267  			Nil(t, err)
   268  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   269  			err = conn.Close()
   270  			Nil(t, err)
   271  		},
   272  		onWriteCallback: func(conn gain.Conn, n int, network string) {
   273  			buf, err := conn.Next(-1)
   274  			Equal(t, 0, len(buf))
   275  			Equal(t, gainErrors.ErrConnectionClosed, err)
   276  		},
   277  	}
   278  	clientBehavior = func(conn net.Conn) {
   279  		n, err := conn.Write(eventHandlerTestData)
   280  		Equal(t, eventHandlerTestDataSize, n)
   281  		Nil(t, err)
   282  		buffer := make([]byte, eventHandlerTestDataSize*2)
   283  		n, err = conn.Read(buffer)
   284  		Equal(t, eventHandlerTestDataSize, n)
   285  		Nil(t, err)
   286  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   287  	}
   288  
   289  	testCases = createTestCases("ReadWriteAndClose", tcp, callbacks, clientBehavior, [][]int{
   290  		{1, 1, 1, 1},
   291  		{1, 1, 1, 1},
   292  		{1, 1, 1, 1},
   293  		{1, 1, 1, 1},
   294  	})
   295  
   296  	testEventHandler(t, testCases, true)
   297  }
   298  
   299  func TestEventHandlerOnAccept(t *testing.T) {
   300  	callbacks := callbacksHolder{
   301  		onAcceptCallback: func(conn gain.Conn, network string) {
   302  			err := conn.SetLinger(0)
   303  			Nil(t, err)
   304  			err = conn.Close()
   305  			Nil(t, err)
   306  		},
   307  	}
   308  	clientBehavior := func(conn net.Conn) {
   309  		if conn != nil {
   310  			time.Sleep(time.Millisecond * 50)
   311  			n, err := conn.Write(eventHandlerTestData)
   312  			Equal(t, 0, n)
   313  			NotNil(t, err)
   314  		}
   315  	}
   316  
   317  	testCases := createTestCases("JustClose", tcp, callbacks, clientBehavior, [][]int{
   318  		{1, 0, 0, 1},
   319  		{1, 0, 0, 1},
   320  		{1, 0, 0, 1},
   321  		{1, 0, 0, 1},
   322  	})
   323  
   324  	testEventHandler(t, testCases, true)
   325  
   326  	callbacks = callbacksHolder{
   327  		onAcceptCallback: func(conn gain.Conn, network string) {
   328  			err := conn.SetLinger(0)
   329  			Nil(t, err)
   330  			err = conn.Close()
   331  			Nil(t, err)
   332  		},
   333  		onReadCallback: func(conn gain.Conn, n int, network string) {
   334  			buffer, err := conn.Next(n)
   335  			Nil(t, err)
   336  			Equal(t, eventHandlerTestData, buffer)
   337  			bytesWritten, err := conn.Write(buffer)
   338  			Nil(t, err)
   339  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   340  			err = conn.Close()
   341  			Nil(t, err)
   342  		},
   343  	}
   344  	clientBehavior = func(conn net.Conn) {
   345  		n, err := conn.Write(eventHandlerTestData)
   346  		Equal(t, eventHandlerTestDataSize, n)
   347  		Nil(t, err)
   348  		buffer := make([]byte, eventHandlerTestDataSize*2)
   349  		n, err = conn.Read(buffer)
   350  		Equal(t, eventHandlerTestDataSize, n)
   351  		Nil(t, err)
   352  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   353  	}
   354  
   355  	testCases = createTestCases("JustClose", udp, callbacks, clientBehavior, [][]int{
   356  		{0, 1, 1, 0},
   357  		{0, 1, 1, 0},
   358  	})
   359  
   360  	testEventHandler(t, testCases, true)
   361  
   362  	callbacks = callbacksHolder{
   363  		onAcceptCallback: func(conn gain.Conn, network string) {
   364  			n, err := conn.Write(eventHandlerTestData)
   365  			Nil(t, err)
   366  			Equal(t, eventHandlerTestDataSize, n)
   367  		},
   368  	}
   369  	clientBehavior = func(conn net.Conn) {
   370  		buffer := make([]byte, eventHandlerTestDataSize*2)
   371  		n, err := conn.Read(buffer)
   372  		Equal(t, eventHandlerTestDataSize, n)
   373  		Nil(t, err)
   374  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   375  		conn.Close()
   376  	}
   377  
   378  	testCases = createTestCases("Write", tcp, callbacks, clientBehavior, [][]int{
   379  		{1, 0, 1, 1},
   380  		{1, 0, 1, 1},
   381  		{1, 0, 1, 1},
   382  		{1, 0, 1, 1},
   383  	})
   384  
   385  	testEventHandler(t, testCases, true)
   386  
   387  	callbacks = callbacksHolder{
   388  		onAcceptCallback: func(conn gain.Conn, network string) {
   389  			n, err := conn.Write(eventHandlerTestData)
   390  			Nil(t, err)
   391  			Equal(t, eventHandlerTestDataSize, n)
   392  			err = conn.Close()
   393  			Nil(t, err)
   394  		},
   395  	}
   396  	clientBehavior = func(conn net.Conn) {
   397  		buffer := make([]byte, eventHandlerTestDataSize*2)
   398  		n, err := conn.Read(buffer)
   399  		Equal(t, eventHandlerTestDataSize, n)
   400  		Nil(t, err)
   401  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   402  	}
   403  
   404  	testCases = createTestCases("WriteAndClose", tcp, callbacks, clientBehavior, [][]int{
   405  		{1, 0, 1, 1},
   406  		{1, 0, 1, 1},
   407  		{1, 0, 1, 1},
   408  		{1, 0, 1, 1},
   409  	})
   410  
   411  	testEventHandler(t, testCases, true)
   412  }
   413  
   414  func TestEventHandlerOnWrite(t *testing.T) {
   415  	callbacks := callbacksHolder{
   416  		onAcceptCallback: func(conn gain.Conn, network string) {
   417  			var once sync.Once
   418  			conn.SetContext(&once)
   419  		},
   420  		onReadCallback: func(conn gain.Conn, n int, network string) {
   421  			buffer, err := conn.Next(n)
   422  			Nil(t, err)
   423  			Equal(t, eventHandlerTestData, buffer)
   424  			bytesWritten, err := conn.Write(buffer)
   425  			Nil(t, err)
   426  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   427  		},
   428  		onWriteCallback: func(conn gain.Conn, n int, network string) {
   429  			time.Sleep(time.Millisecond * 100)
   430  			once, ok := conn.Context().(*sync.Once)
   431  			if !ok {
   432  				log.Panic()
   433  			}
   434  
   435  			once.Do(func() {
   436  				bytesWritten, err := conn.Write(eventHandlerTestData)
   437  				Nil(t, err)
   438  				Equal(t, eventHandlerTestDataSize, bytesWritten)
   439  			})
   440  		},
   441  	}
   442  	clientBehavior := func(conn net.Conn) {
   443  		n, err := conn.Write(eventHandlerTestData)
   444  		Equal(t, eventHandlerTestDataSize, n)
   445  		Nil(t, err)
   446  
   447  		for i := 0; i < 2; i++ {
   448  			buffer := make([]byte, eventHandlerTestDataSize*2)
   449  			n, err = conn.Read(buffer)
   450  			Equal(t, eventHandlerTestDataSize, n)
   451  			Nil(t, err)
   452  			Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   453  		}
   454  
   455  		conn.Close()
   456  	}
   457  
   458  	testCases := createTestCases("AdditionalWrite", tcp, callbacks, clientBehavior, [][]int{
   459  		{1, 1, 2, 1},
   460  		{1, 1, 2, 1},
   461  		{1, 1, 2, 1},
   462  		{1, 1, 2, 1},
   463  	})
   464  
   465  	testEventHandler(t, testCases, true)
   466  }
   467  
   468  func TestEventHandlerSetConnectionProperties(t *testing.T) {
   469  	setConnectionProperties := func(conn gain.Conn, network string) {
   470  		err := conn.SetLinger(30)
   471  		NoError(t, err)
   472  		err = conn.SetReadBuffer(2048)
   473  		NoError(t, err)
   474  		err = conn.SetWriteBuffer(2048)
   475  		NoError(t, err)
   476  
   477  		if network == gainNet.TCP {
   478  			err = conn.SetKeepAlivePeriod(time.Minute)
   479  			NoError(t, err)
   480  			err = conn.SetNoDelay(true)
   481  			NoError(t, err)
   482  		}
   483  	}
   484  	callbacks := callbacksHolder{
   485  		onAcceptCallback: func(conn gain.Conn, network string) {
   486  			setConnectionProperties(conn, network)
   487  		},
   488  		onReadCallback: func(conn gain.Conn, n int, network string) {
   489  			setConnectionProperties(conn, network)
   490  			buffer, err := conn.Next(n)
   491  			Nil(t, err)
   492  			Equal(t, eventHandlerTestData, buffer)
   493  			bytesWritten, err := conn.Write(buffer)
   494  			Nil(t, err)
   495  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   496  		},
   497  		onWriteCallback: func(conn gain.Conn, n int, network string) {
   498  			setConnectionProperties(conn, network)
   499  			conn.Close()
   500  		},
   501  	}
   502  	clientBehavior := func(conn net.Conn) {
   503  		n, err := conn.Write(eventHandlerTestData)
   504  		Equal(t, eventHandlerTestDataSize, n)
   505  		Nil(t, err)
   506  
   507  		buffer := make([]byte, eventHandlerTestDataSize*2)
   508  		n, err = conn.Read(buffer)
   509  		Equal(t, eventHandlerTestDataSize, n)
   510  		Nil(t, err)
   511  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   512  	}
   513  
   514  	testCases := createTestCases("All", both, callbacks, clientBehavior, [][]int{
   515  		{1, 1, 1, 1},
   516  		{1, 1, 1, 1},
   517  		{1, 1, 1, 1},
   518  		{1, 1, 1, 1},
   519  		{0, 1, 1, 0},
   520  		{0, 1, 1, 0},
   521  	})
   522  
   523  	testEventHandler(t, testCases, true)
   524  }
   525  
   526  func TestEventHandlerAsyncShutdown(t *testing.T) {
   527  	var server gain.Server
   528  	callbacks := callbacksHolder{
   529  		onStartCallback: func(s gain.Server, network string) {
   530  			server = s
   531  		},
   532  		onReadCallback: func(conn gain.Conn, n int, network string) {
   533  			buffer, err := conn.Next(n)
   534  			Nil(t, err)
   535  			Equal(t, eventHandlerTestData, buffer)
   536  			bytesWritten, err := conn.Write(buffer)
   537  			Nil(t, err)
   538  			Equal(t, eventHandlerTestDataSize, bytesWritten)
   539  		},
   540  	}
   541  	clientBehavior := func(conn net.Conn) {
   542  		n, err := conn.Write(eventHandlerTestData)
   543  		Equal(t, eventHandlerTestDataSize, n)
   544  		Nil(t, err)
   545  
   546  		buffer := make([]byte, eventHandlerTestDataSize*2)
   547  		n, err = conn.Read(buffer)
   548  		Equal(t, eventHandlerTestDataSize, n)
   549  		Nil(t, err)
   550  		Equal(t, eventHandlerTestData, buffer[:eventHandlerTestDataSize])
   551  
   552  		conn.Close()
   553  
   554  		server.AsyncShutdown()
   555  	}
   556  
   557  	testCases := createTestCases("All", both, callbacks, clientBehavior, [][]int{
   558  		{1, 1, 1, 1},
   559  		{1, 1, 1, 1},
   560  		{1, 1, 1, 1},
   561  		{1, 1, 1, 1},
   562  		{0, 1, 1, 0},
   563  		{0, 1, 1, 0},
   564  	})
   565  
   566  	testEventHandler(t, testCases, false)
   567  }