github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/common_test.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gain_test 16 17 import ( 18 "crypto/rand" 19 "errors" 20 "fmt" 21 "io" 22 "log" 23 "net" 24 "os" 25 "sync" 26 "sync/atomic" 27 "testing" 28 "time" 29 30 "github.com/pawelgaczynski/gain" 31 gainErrors "github.com/pawelgaczynski/gain/pkg/errors" 32 gainNet "github.com/pawelgaczynski/gain/pkg/net" 33 . "github.com/stretchr/testify/require" 34 ) 35 36 type testServerConfig struct { 37 protocol string 38 numberOfClients int 39 numberOfWorkers int 40 cpuAffinity bool 41 asyncHandler bool 42 goroutinePool bool 43 waitForDialAllClients bool 44 afterDial afterDialCallback 45 writesCount int 46 configOptions []gain.ConfigOption 47 48 readHandler onReadCallback 49 } 50 51 var defaultTestOnReadCallback = func(c gain.Conn, n int, network string) { 52 buffer := make([]byte, 128) 53 54 _, err := c.Read(buffer) 55 if err != nil { 56 if errors.Is(err, gainErrors.ErrIsEmpty) { 57 return 58 } 59 60 log.Panic(err) 61 } 62 63 if string(buffer[0:6]) != "cindex" { 64 log.Panic(fmt.Errorf("unexpected data: %s", string(buffer[0:6]))) 65 } 66 67 _, err = c.Write(append(buffer[0:10], []byte("TESTpayload12345")...)) 68 if err != nil { 69 log.Panic(err) 70 } 71 } 72 73 type callbacksHolder struct { 74 onStartCallback onStartCallback 75 onAcceptCallback onAcceptCallback 76 onReadCallback onReadCallback 77 onWriteCallback onWriteCallback 78 onCloseCallback onCloseCallback 79 } 80 81 type testServerHandler struct { 82 callbacksHolder 83 84 onStartCount atomic.Uint32 85 onAcceptCount atomic.Uint32 86 onReadCount atomic.Uint32 87 onWriteCount atomic.Uint32 88 onCloseCount atomic.Uint32 89 90 startedWg *sync.WaitGroup 91 onAcceptWg *sync.WaitGroup 92 onReadWg *sync.WaitGroup 93 onWriteWg *sync.WaitGroup 94 onCloseWg *sync.WaitGroup 95 96 finished atomic.Bool 97 98 network string 99 } 100 101 func (h *testServerHandler) OnStart(server gain.Server) { 102 if !h.finished.Load() { 103 h.startedWg.Done() 104 105 if h.onStartCallback != nil { 106 h.onStartCallback(server, h.network) 107 } 108 109 h.onStartCount.Add(1) 110 } 111 } 112 113 func (h *testServerHandler) OnAccept(c gain.Conn) { 114 if !h.finished.Load() { 115 if h.onAcceptCallback != nil { 116 h.onAcceptCallback(c, h.network) 117 } 118 119 h.onAcceptCount.Add(1) 120 121 if h.onAcceptWg != nil { 122 h.onAcceptWg.Done() 123 } 124 } 125 } 126 127 func (h *testServerHandler) OnClose(c gain.Conn, err error) { 128 if !h.finished.Load() { 129 if h.onCloseCallback != nil { 130 h.onCloseCallback(c, err, h.network) 131 } 132 133 h.onCloseCount.Add(1) 134 135 if h.onCloseWg != nil { 136 h.onCloseWg.Done() 137 } 138 } 139 } 140 141 func (h *testServerHandler) OnRead(conn gain.Conn, n int) { 142 if !h.finished.Load() { 143 if h.onReadCallback != nil { 144 h.onReadCallback(conn, n, h.network) 145 } 146 147 h.onReadCount.Add(1) 148 149 if h.onReadWg != nil { 150 h.onReadWg.Done() 151 } 152 } 153 } 154 155 func (h *testServerHandler) OnWrite(c gain.Conn, n int) { 156 if !h.finished.Load() { 157 if h.onWriteCallback != nil { 158 h.onWriteCallback(c, n, h.network) 159 } 160 161 h.onWriteCount.Add(1) 162 163 if h.onWriteWg != nil { 164 h.onWriteWg.Done() 165 } 166 } 167 } 168 169 type afterDialCallback func(*testing.T, net.Conn, int, int) 170 171 var deafultAfterDial = func(t *testing.T, conn net.Conn, repeats, clientIndex int) { 172 t.Helper() 173 err := conn.SetDeadline(time.Now().Add(time.Second * 2)) 174 Nil(t, err) 175 176 clientIndexBytes := []byte(fmt.Sprintf("cindex%04d", clientIndex)) 177 178 for i := 0; i < repeats; i++ { 179 var bytesWritten int 180 bytesWritten, err = conn.Write(append(clientIndexBytes, []byte("testdata1234567890")...)) 181 182 Nil(t, err) 183 Equal(t, 28, bytesWritten) 184 var buffer [64]byte 185 var bytesRead int 186 bytesRead, err = conn.Read(buffer[:]) 187 188 Nil(t, err) 189 Equal(t, 26, bytesRead) 190 Equal(t, string(append(clientIndexBytes, "TESTpayload12345"...)), 191 string(buffer[:bytesRead]), "CONNFD: %d", getFdFromConn(conn)) 192 } 193 } 194 195 func dialClient(t *testing.T, protocol string, port int, clientConnChan chan net.Conn) { 196 t.Helper() 197 conn, err := net.DialTimeout(protocol, fmt.Sprintf("127.0.0.1:%d", port), time.Second) 198 Nil(t, err) 199 NotNil(t, conn) 200 clientConnChan <- conn 201 } 202 203 func dialClientRW(t *testing.T, protocol string, port int, 204 afterDial afterDialCallback, repeats, clientIndex int, clientConnChan chan net.Conn, 205 ) { 206 t.Helper() 207 conn, err := net.DialTimeout(protocol, fmt.Sprintf("127.0.0.1:%d", port), 2*time.Second) 208 Nil(t, err) 209 NotNil(t, conn) 210 afterDial(t, conn, repeats, clientIndex) 211 clientConnChan <- conn 212 } 213 214 func newTestServerHandler(onReadCallback onReadCallback, network string) *testServerHandler { 215 testHandler := &testServerHandler{ 216 network: network, 217 } 218 219 var startedWg sync.WaitGroup 220 221 startedWg.Add(1) 222 testHandler.startedWg = &startedWg 223 224 if onReadCallback != nil { 225 testHandler.onReadCallback = onReadCallback 226 } else { 227 testHandler.onReadCallback = defaultTestOnReadCallback 228 } 229 230 return testHandler 231 } 232 233 func testServer(t *testing.T, testConfig testServerConfig, architecture gain.ServerArchitecture) { 234 t.Helper() 235 236 if testConfig.protocol == "" { 237 log.Panic("network protocol is missing") 238 } 239 opts := []gain.ConfigOption{ 240 gain.WithLoggerLevel(getTestLoggerLevel()), 241 gain.WithAsyncHandler(testConfig.asyncHandler), 242 gain.WithGoroutinePool(testConfig.goroutinePool), 243 gain.WithCPUAffinity(testConfig.cpuAffinity), 244 gain.WithWorkers(testConfig.numberOfWorkers), 245 gain.WithCBPF(false), 246 gain.WithArchitecture(architecture), 247 } 248 249 if testConfig.configOptions != nil { 250 opts = append(opts, testConfig.configOptions...) 251 } 252 253 config := gain.NewConfig(opts...) 254 255 testHandler := newTestServerHandler(testConfig.readHandler, testConfig.protocol) 256 257 server := gain.NewServer(testHandler, config) 258 259 defer func() { 260 server.Shutdown() 261 }() 262 testPort := getTestPort() 263 264 go func() { 265 err := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", testConfig.protocol, testPort)) 266 if err != nil { 267 log.Panic(err) 268 } 269 }() 270 271 clientConnChan := make(chan net.Conn, testConfig.numberOfClients) 272 273 testHandler.startedWg.Wait() 274 275 if testConfig.waitForDialAllClients { 276 clientConnectWG := new(sync.WaitGroup) 277 clientConnectWG.Add(testConfig.numberOfClients) 278 testHandler.onAcceptCallback = func(c gain.Conn, _ string) { 279 clientConnectWG.Done() 280 } 281 282 for i := 0; i < testConfig.numberOfClients; i++ { 283 go dialClient(t, testConfig.protocol, testPort, clientConnChan) 284 } 285 clientConnectWG.Wait() 286 Equal(t, testConfig.numberOfClients, server.ActiveConnections()) 287 288 for i := 0; i < testConfig.numberOfClients; i++ { 289 conn := <-clientConnChan 290 NotNil(t, conn) 291 292 if tcpConn, ok := conn.(*net.TCPConn); ok { 293 err := tcpConn.SetLinger(0) 294 Nil(t, err) 295 } 296 } 297 } else { 298 var clientConnectWG *sync.WaitGroup 299 if testConfig.protocol == gainNet.TCP { 300 clientConnectWG = new(sync.WaitGroup) 301 clientConnectWG.Add(testConfig.numberOfClients) 302 } 303 clientRWWG := new(sync.WaitGroup) 304 if testConfig.writesCount == 0 { 305 testConfig.writesCount = 1 306 } 307 clientRWWG.Add(testConfig.numberOfClients * testConfig.writesCount) 308 if testConfig.protocol == gainNet.TCP { 309 testHandler.onAcceptCallback = func(c gain.Conn, _ string) { 310 clientConnectWG.Done() 311 } 312 } 313 testHandler.onWriteCallback = func(c gain.Conn, n int, network string) { 314 clientRWWG.Done() 315 } 316 afterDial := deafultAfterDial 317 if testConfig.afterDial != nil { 318 afterDial = testConfig.afterDial 319 } 320 for i := 0; i < testConfig.numberOfClients; i++ { 321 go func(clientIndex int) { 322 dialClientRW(t, testConfig.protocol, testPort, afterDial, testConfig.writesCount, clientIndex, clientConnChan) 323 }(i) 324 } 325 if testConfig.protocol == gainNet.TCP { 326 clientConnectWG.Wait() 327 } 328 clientRWWG.Wait() 329 for i := 0; i < testConfig.numberOfClients; i++ { 330 conn := <-clientConnChan 331 NotNil(t, conn) 332 if tcpConn, ok := conn.(*net.TCPConn); ok { 333 err := tcpConn.SetLinger(0) 334 Nil(t, err) 335 } 336 } 337 } 338 } 339 340 var randomDataSize128 = make([]byte, 128) 341 342 type RingBufferTestDataHandler struct { 343 t *testing.T 344 testFinished atomic.Bool 345 } 346 347 func (r *RingBufferTestDataHandler) OnRead(conn gain.Conn, _ int, _ string) { 348 buffer := make([]byte, 128) 349 bytesRead, readErr := conn.Read(buffer) 350 351 if !r.testFinished.Load() { 352 Equal(r.t, 128, bytesRead) 353 354 if readErr != nil { 355 log.Panic(readErr) 356 } 357 bytesWritten, writeErr := conn.Write(randomDataSize128) 358 Equal(r.t, 128, bytesWritten) 359 360 if writeErr != nil { 361 log.Panic(writeErr) 362 } 363 } 364 } 365 366 func testRingBuffer(t *testing.T, protocol string, architecture gain.ServerArchitecture) { 367 t.Helper() 368 handler := RingBufferTestDataHandler{ 369 t: t, 370 } 371 bytesRandom, err := rand.Read(randomDataSize128) 372 Nil(t, err) 373 Equal(t, 128, bytesRandom) 374 writesCount := 1000 375 testServer(t, testServerConfig{ 376 numberOfClients: 1, 377 numberOfWorkers: 1, 378 protocol: protocol, 379 readHandler: handler.OnRead, 380 writesCount: writesCount, 381 afterDial: func(t *testing.T, conn net.Conn, _, _ int) { 382 t.Helper() 383 deadlineErr := conn.SetDeadline(time.Now().Add(time.Second * 1)) 384 Nil(t, deadlineErr) 385 var buffer [256]byte 386 for i := 0; i < writesCount; i++ { 387 bytesWritten, writeErr := conn.Write(randomDataSize128) 388 Nil(t, writeErr) 389 Equal(t, 128, bytesWritten) 390 bytesRead, readErr := conn.Read(buffer[:]) 391 Nil(t, readErr) 392 Equal(t, 128, bytesRead) 393 Equal(t, randomDataSize128, buffer[:bytesRead]) 394 } 395 handler.testFinished.Store(true) 396 }, 397 }, architecture) 398 } 399 400 func testCloseServer(t *testing.T, network string, architecture gain.ServerArchitecture, doubleShutdown bool) { 401 t.Helper() 402 testHandler := newConnServerTester(network, 10, false) 403 server, port := newTestConnServer(t, network, false, architecture, testHandler.testServerHandler) 404 clientsGroup := newTestConnClientGroup(t, network, port, 10) 405 clientsGroup.Dial() 406 407 data := make([]byte, 512) 408 409 _, err := rand.Read(data) 410 Nil(t, err) 411 clientsGroup.SetDeadline(time.Now().Add(time.Second)) 412 clientsGroup.Write(data) 413 buffer := make([]byte, 512) 414 415 clientsGroup.SetDeadline(time.Now().Add(time.Second)) 416 clientsGroup.Read(buffer) 417 418 clientsGroup.SetDeadline(time.Time{}) 419 420 testHandler.waitForWrites() 421 clientsGroup.Close() 422 server.Shutdown() 423 424 if doubleShutdown { 425 server.Shutdown() 426 } 427 } 428 429 func testCloseServerWithConnectedClients(t *testing.T, architecture gain.ServerArchitecture) { 430 t.Helper() 431 testHandler := newConnServerTester(gainNet.TCP, 10, false) 432 server, port := newTestConnServer(t, gainNet.TCP, false, architecture, testHandler.testServerHandler) 433 434 clientsGroup := newTestConnClientGroup(t, gainNet.TCP, port, 10) 435 clientsGroup.Dial() 436 437 data := make([]byte, 1024) 438 _, err := rand.Read(data) 439 Nil(t, err) 440 clientsGroup.Write(data) 441 buffer := make([]byte, 1024) 442 clientsGroup.Read(buffer) 443 444 testHandler.waitForWrites() 445 server.Shutdown() 446 } 447 448 func testCloseConn(t *testing.T, async bool, architecture gain.ServerArchitecture, justClose bool) { 449 t.Helper() 450 testHandler := newTestServerHandler(func(conn gain.Conn, n int, network string) { 451 if !justClose { 452 buf, err := conn.Next(n) 453 if err != nil { 454 log.Panic(err) 455 } 456 457 _, err = conn.Write(buf) 458 if err != nil { 459 log.Panic(err) 460 } 461 } 462 463 err := conn.Close() 464 if err != nil { 465 log.Panic(err) 466 } 467 }, gainNet.TCP) 468 469 server, port := newTestConnServer(t, gainNet.TCP, async, architecture, testHandler) 470 471 var clientDoneWg sync.WaitGroup 472 473 clientDoneWg.Add(1) 474 475 go func(wg *sync.WaitGroup) { 476 conn, cErr := net.DialTimeout(gainNet.TCP, fmt.Sprintf("127.0.0.1:%d", port), time.Second) 477 Nil(t, cErr) 478 NotNil(t, conn) 479 testData := []byte("testdata1234567890") 480 bytesN, cErr := conn.Write(testData) 481 Nil(t, cErr) 482 Equal(t, len(testData), bytesN) 483 buffer := make([]byte, len(testData)) 484 bytesN, cErr = conn.Read(buffer) 485 486 if !justClose { 487 Nil(t, cErr) 488 Equal(t, len(testData), bytesN) 489 Equal(t, testData, buffer) 490 bytesN, cErr = conn.Write(testData) 491 Nil(t, cErr) 492 Equal(t, len(testData), bytesN) 493 bytesN, cErr = conn.Read(buffer) 494 } 495 496 Equal(t, io.EOF, cErr) 497 Equal(t, 0, bytesN) 498 wg.Done() 499 }(&clientDoneWg) 500 501 clientDoneWg.Wait() 502 server.Shutdown() 503 } 504 505 func testLargeRead(t *testing.T, network string, architecture gain.ServerArchitecture) { 506 t.Helper() 507 508 if !checkKernelCompatibility(5, 19) { 509 //nolint 510 fmt.Println("Not supported by kernel") 511 512 return 513 } 514 515 doublePageSize := os.Getpagesize() * 4 516 data := make([]byte, doublePageSize) 517 _, err := rand.Read(data) 518 Nil(t, err) 519 520 var doneWg sync.WaitGroup 521 522 doneWg.Add(1) 523 onReadCallback := func(c gain.Conn, _ int, _ string) { 524 readBuffer := make([]byte, doublePageSize) 525 526 n, cErr := c.Read(readBuffer) 527 if err != nil { 528 log.Panic(cErr) 529 } 530 531 doneWg.Done() 532 Equal(t, doublePageSize, n) 533 534 n, cErr = c.Write(readBuffer) 535 if cErr != nil { 536 log.Panic(cErr) 537 } 538 539 Equal(t, doublePageSize, n) 540 } 541 542 testConnHandler := newTestServerHandler(onReadCallback, network) 543 server, port := newTestConnServer(t, network, false, architecture, testConnHandler) 544 545 clientsGroup := newTestConnClientGroup(t, network, port, 1) 546 clientsGroup.Dial() 547 548 clientsGroup.Write(data) 549 buffer := make([]byte, len(data)) 550 clientsGroup.Read(buffer) 551 552 Equal(t, data, buffer) 553 554 doneWg.Wait() 555 556 server.Shutdown() 557 } 558 559 func testMultipleReads(t *testing.T, network string, asyncHandler bool, architecture gain.ServerArchitecture) { 560 t.Helper() 561 562 pageSize := os.Getpagesize() 563 data := make([]byte, pageSize) 564 _, err := rand.Read(data) 565 Nil(t, err) 566 567 var ( 568 doneWg sync.WaitGroup 569 expectedReads int64 = 10 570 readsCount atomic.Int64 571 ) 572 573 doneWg.Add(int(expectedReads)) 574 onReadCallback := func(c gain.Conn, _ int, _ string) { 575 readBuffer := make([]byte, pageSize) 576 577 n, cErr := c.Read(readBuffer) 578 if err != nil { 579 log.Panic(cErr) 580 } 581 582 readsCount.Add(1) 583 doneWg.Done() 584 Equal(t, pageSize, n) 585 } 586 587 testConnHandler := newTestServerHandler(onReadCallback, network) 588 server, port := newTestConnServer(t, network, asyncHandler, architecture, testConnHandler) 589 590 clientsGroup := newTestConnClientGroup(t, network, port, 1) 591 clientsGroup.Dial() 592 593 go func() { 594 for i := 0; i < int(expectedReads); i++ { 595 clientsGroup.Write(data) 596 time.Sleep(time.Millisecond * 100) 597 } 598 }() 599 600 doneWg.Wait() 601 602 Equal(t, expectedReads, readsCount.Load()) 603 604 server.Shutdown() 605 }