github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/common_conn_test.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gain_test 16 17 import ( 18 "crypto/rand" 19 "fmt" 20 "log" 21 "net" 22 "os/exec" 23 "strconv" 24 "strings" 25 "sync" 26 "testing" 27 "time" 28 29 "github.com/pawelgaczynski/gain" 30 . "github.com/stretchr/testify/require" 31 ) 32 33 type connServerTester struct { 34 *testServerHandler 35 mutex sync.Mutex 36 writeWG *sync.WaitGroup 37 writeCount uint32 38 targetWriteCount uint32 39 removeWGAfterMinWrites bool 40 } 41 42 func (t *connServerTester) waitForWrites() { 43 t.writeWG.Wait() 44 } 45 46 func (t *connServerTester) onReadCallback(conn gain.Conn, n int, _ string) { 47 buf, _ := conn.Next(n) 48 _, _ = conn.Write(buf) 49 } 50 51 func (t *connServerTester) onWriteCallback(_ gain.Conn, _ int, _ string) { 52 if t.writeWG != nil { 53 t.mutex.Lock() 54 55 t.writeCount++ 56 if t.writeCount >= t.targetWriteCount { 57 t.writeWG.Done() 58 59 if t.removeWGAfterMinWrites { 60 t.writeWG = nil 61 } 62 } 63 t.mutex.Unlock() 64 } 65 } 66 67 func newConnServerTester(network string, writeCount int, removeWGAfterMinWrites bool) *connServerTester { 68 connServerTester := &connServerTester{} 69 70 if writeCount > 0 { 71 var writeWG sync.WaitGroup 72 73 writeWG.Add(1) 74 connServerTester.writeWG = &writeWG 75 connServerTester.targetWriteCount = uint32(writeCount) 76 connServerTester.removeWGAfterMinWrites = removeWGAfterMinWrites 77 } 78 79 testConnHandler := newTestServerHandler(connServerTester.onReadCallback, network) 80 81 testConnHandler.onWriteCallback = connServerTester.onWriteCallback 82 connServerTester.testServerHandler = testConnHandler 83 84 return connServerTester 85 } 86 87 func newEventHandlerTester(callbacks callbacksHolder, network string) *testServerHandler { 88 testHandler := &testServerHandler{ 89 network: network, 90 } 91 92 var ( 93 startedWg sync.WaitGroup 94 onAcceptWg sync.WaitGroup 95 onReadWg sync.WaitGroup 96 onWriteWg sync.WaitGroup 97 onCloseWg sync.WaitGroup 98 ) 99 100 startedWg.Add(1) 101 testHandler.startedWg = &startedWg 102 testHandler.onAcceptWg = &onAcceptWg 103 testHandler.onReadWg = &onReadWg 104 testHandler.onWriteWg = &onWriteWg 105 testHandler.onCloseWg = &onCloseWg 106 107 testHandler.onStartCallback = callbacks.onStartCallback 108 testHandler.onAcceptCallback = callbacks.onAcceptCallback 109 testHandler.onReadCallback = callbacks.onReadCallback 110 testHandler.onWriteCallback = callbacks.onWriteCallback 111 testHandler.onCloseCallback = callbacks.onCloseCallback 112 113 return testHandler 114 } 115 116 type testConnClient struct { 117 t *testing.T 118 conn net.Conn 119 network string 120 port int 121 idx int 122 } 123 124 func (c *testConnClient) Dial() { 125 conn, err := net.DialTimeout(c.network, fmt.Sprintf("127.0.0.1:%d", c.port), time.Second) 126 Nil(c.t, err) 127 NotNil(c.t, conn) 128 c.conn = conn 129 } 130 131 func (c *testConnClient) Close() { 132 err := c.conn.Close() 133 Nil(c.t, err) 134 } 135 136 func (c *testConnClient) SetDeadline(t time.Time) { 137 err := c.conn.SetDeadline(t) 138 Nil(c.t, err) 139 } 140 141 func (c *testConnClient) Write(buffer []byte) { 142 bytesWritten, writeErr := c.conn.Write(buffer) 143 Nil(c.t, writeErr) 144 Equal(c.t, len(buffer), bytesWritten) 145 } 146 147 func (c *testConnClient) Read(buffer []byte) { 148 bytesRead, readErr := c.conn.Read(buffer) 149 Nil(c.t, readErr) 150 Equal(c.t, len(buffer), bytesRead) 151 } 152 153 func newTestConnClient(t *testing.T, idx int, network string, port int) *testConnClient { 154 t.Helper() 155 156 return &testConnClient{ 157 t: t, 158 network: network, 159 port: port, 160 idx: idx, 161 } 162 } 163 164 type testConnClientGroup struct { 165 clients []*testConnClient 166 } 167 168 func (c *testConnClientGroup) Dial() { 169 for i := 0; i < len(c.clients); i++ { 170 c.clients[i].Dial() 171 } 172 } 173 174 func (c *testConnClientGroup) Close() { 175 for i := 0; i < len(c.clients); i++ { 176 c.clients[i].Close() 177 } 178 } 179 180 func (c *testConnClientGroup) SetDeadline(t time.Time) { 181 for i := 0; i < len(c.clients); i++ { 182 c.clients[i].SetDeadline(t) 183 } 184 } 185 186 func (c *testConnClientGroup) Write(buffer []byte) { 187 for i := 0; i < len(c.clients); i++ { 188 c.clients[i].Write(buffer) 189 } 190 } 191 192 func (c *testConnClientGroup) Read(buffer []byte) { 193 for i := 0; i < len(c.clients); i++ { 194 c.clients[i].Read(buffer) 195 } 196 } 197 198 func newTestConnClientGroup(t *testing.T, network string, port int, n int) *testConnClientGroup { 199 t.Helper() 200 group := &testConnClientGroup{ 201 clients: make([]*testConnClient, n), 202 } 203 204 for i := 0; i < n; i++ { 205 group.clients[i] = newTestConnClient(t, i, network, port) 206 } 207 208 return group 209 } 210 211 func newTestConnServer( 212 t *testing.T, network string, async bool, architecture gain.ServerArchitecture, eventHandler *testServerHandler, 213 ) (gain.Server, int) { 214 t.Helper() 215 opts := []gain.ConfigOption{ 216 gain.WithLoggerLevel(getTestLoggerLevel()), 217 gain.WithWorkers(4), 218 gain.WithArchitecture(architecture), 219 gain.WithAsyncHandler(async), 220 gain.WithMaxSQEntries(1024), 221 gain.WithMaxCQEvents(1024), 222 } 223 224 config := gain.NewConfig(opts...) 225 226 server := gain.NewServer(eventHandler, config) 227 testPort := getTestPort() 228 229 go func() { 230 err := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", network, testPort)) 231 if err != nil { 232 log.Panic(err) 233 } 234 }() 235 236 eventHandler.startedWg.Wait() 237 238 return server, int(port) 239 } 240 241 func getIPAndPort(addr net.Addr) (string, int) { 242 switch addr := addr.(type) { 243 case *net.UDPAddr: 244 return addr.IP.String(), addr.Port 245 case *net.TCPAddr: 246 return addr.IP.String(), addr.Port 247 } 248 249 return "", 0 250 } 251 252 func testConnAddress( 253 t *testing.T, network string, architecture gain.ServerArchitecture, 254 ) { 255 t.Helper() 256 numberOfClients := 10 257 opts := []gain.ConfigOption{ 258 gain.WithLoggerLevel(getTestLoggerLevel()), 259 gain.WithWorkers(4), 260 gain.WithArchitecture(architecture), 261 } 262 263 config := gain.NewConfig(opts...) 264 265 out, err := exec.Command("bash", "-c", "sysctl net.ipv4.ip_local_port_range | awk '{ print $3; }'").Output() 266 if err != nil { 267 log.Panic(err) 268 } 269 270 lowestEphemeralPort, err := strconv.Atoi(strings.ReplaceAll(string(out), "\n", "")) 271 if err != nil { 272 log.Panic(err) 273 } 274 275 verifyAddresses := func(t *testing.T, conn gain.Conn) { 276 t.Helper() 277 localAddr := conn.LocalAddr() 278 NotNil(t, localAddr) 279 280 ip, port := getIPAndPort(localAddr) 281 Equal(t, "127.0.0.1", ip) 282 Less(t, port, 10000) 283 GreaterOrEqual(t, port, 9000) 284 remoteAddr := conn.RemoteAddr() 285 286 ip, port = getIPAndPort(remoteAddr) 287 NotNil(t, remoteAddr) 288 Equal(t, "127.0.0.1", ip) 289 GreaterOrEqual(t, port, lowestEphemeralPort) 290 } 291 292 var wg sync.WaitGroup 293 294 wg.Add(numberOfClients) 295 296 onReadCallback := func(conn gain.Conn, n int, _ string) { 297 buf, _ := conn.Next(n) 298 _, _ = conn.Write(buf) 299 300 verifyAddresses(t, conn) 301 302 wg.Done() 303 } 304 305 testHandler := newTestServerHandler(onReadCallback, network) 306 307 server := gain.NewServer(testHandler, config) 308 testPort := getTestPort() 309 310 go func() { 311 serverErr := server.Start(fmt.Sprintf("%s://127.0.0.1:%d", network, testPort)) 312 if err != nil { 313 log.Panic(serverErr) 314 } 315 }() 316 317 testHandler.startedWg.Wait() 318 319 clientsGroup := newTestConnClientGroup(t, network, testPort, numberOfClients) 320 clientsGroup.Dial() 321 322 data := make([]byte, 1024) 323 _, err = rand.Read(data) 324 Nil(t, err) 325 clientsGroup.Write(data) 326 buffer := make([]byte, 1024) 327 clientsGroup.Read(buffer) 328 329 wg.Wait() 330 server.Shutdown() 331 }