github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/buffer/magicring/ringbuffer_iouring_test.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package magicring 16 17 import ( 18 "crypto/rand" 19 "errors" 20 "fmt" 21 "net" 22 "syscall" 23 "testing" 24 "time" 25 "unsafe" 26 27 gainNet "github.com/pawelgaczynski/gain/pkg/net" 28 "github.com/pawelgaczynski/giouring" 29 . "github.com/stretchr/testify/require" 30 "golang.org/x/sys/unix" 31 ) 32 33 const ( 34 accept = iota 35 recv 36 send 37 ) 38 39 type conn struct { 40 fd uint64 41 inboundBuffer *RingBuffer 42 outboundBuffer *RingBuffer 43 state int 44 } 45 46 func loop(t *testing.T, ring *giouring.Ring, socketFd int, connection *conn, testCase *testCase) bool { 47 t.Helper() 48 49 cqe, err := ring.WaitCQE() 50 if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) || 51 errors.Is(err, syscall.ETIME) { 52 return false 53 } 54 55 Nil(t, err) 56 entry := ring.GetSQE() 57 NotNil(t, entry) 58 ring.CQESeen(cqe) 59 60 switch connection.state { 61 case accept: 62 Equal(t, uint64(socketFd), cqe.UserData) 63 Greater(t, cqe.Res, int32(0)) 64 connection.fd = uint64(cqe.Res) 65 entry.PrepareRecv( 66 int(connection.fd), 67 uintptr(connection.inboundBuffer.WriteAddress()), 68 uint32(connection.inboundBuffer.Available()), 69 0) 70 entry.UserData = connection.fd 71 connection.state = recv 72 73 case recv: 74 var data []byte 75 if testCase.recvIdx == 0 { 76 data = testCase.halfLenData 77 } else { 78 data = testCase.wholeLenData 79 } 80 testCase.recvIdx++ 81 82 Equal(t, connection.fd, cqe.UserData) 83 Equal(t, int32(len(data)), cqe.Res) 84 85 connection.inboundBuffer.AdvanceWrite(int(cqe.Res)) 86 readBuf := make([]byte, DefaultMagicBufferSize) 87 88 var bytesRead int 89 bytesRead, err = connection.inboundBuffer.Read(readBuf) 90 Nil(t, err) 91 Equal(t, len(data), bytesRead) 92 Equal(t, data, readBuf[:cqe.Res]) 93 94 var bytesWritten int 95 bytesWritten, err = connection.outboundBuffer.Write(data) 96 Nil(t, err) 97 Equal(t, len(data), bytesWritten) 98 99 entry.PrepareSend( 100 int(connection.fd), 101 uintptr(connection.outboundBuffer.ReadAddress()), 102 uint32(connection.outboundBuffer.Buffered()), 103 0) 104 entry.UserData = connection.fd 105 connection.state = send 106 107 case send: 108 var res int32 109 if testCase.sendIdx == 0 { 110 res = int32(DefaultMagicBufferSize / 2) 111 } else { 112 res = int32(DefaultMagicBufferSize) 113 } 114 115 Equal(t, connection.fd, cqe.UserData) 116 Equal(t, res, cqe.Res) 117 118 connection.outboundBuffer.AdvanceRead(int(cqe.Res)) 119 120 if testCase.sendIdx == 0 { 121 entry.PrepareRecv( 122 int(connection.fd), 123 uintptr(connection.inboundBuffer.WriteAddress()), 124 uint32(connection.inboundBuffer.Available()), 125 0) 126 entry.UserData = connection.fd 127 connection.state = recv 128 testCase.sendIdx++ 129 } else { 130 err = syscall.Shutdown(int(connection.fd), syscall.SHUT_RDWR) 131 Nil(t, err) 132 133 return true 134 } 135 } 136 cqeNr, err := ring.Submit() 137 Nil(t, err) 138 Equal(t, uint(1), cqeNr) 139 140 return false 141 } 142 143 type testCase struct { 144 halfLenData []byte 145 wholeLenData []byte 146 recvIdx int 147 sendIdx int 148 } 149 150 func TestMagicRingRecvSend(t *testing.T) { 151 ring, err := giouring.CreateRing(16) 152 Nil(t, err) 153 154 defer ring.QueueExit() 155 156 socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) 157 Nil(t, err) 158 err = syscall.SetsockoptInt(socketFd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) 159 Nil(t, err) 160 err = syscall.Bind(socketFd, &syscall.SockaddrInet4{ 161 Port: 9876, 162 }) 163 Nil(t, err) 164 err = syscall.SetNonblock(socketFd, false) 165 Nil(t, err) 166 err = syscall.Listen(socketFd, 128) 167 Nil(t, err) 168 169 defer func() { 170 closeErr := syscall.Close(socketFd) 171 Nil(t, closeErr) 172 }() 173 174 entry := ring.GetSQE() 175 NotNil(t, entry) 176 clientLen := new(uint32) 177 clientAddr := &unix.RawSockaddrAny{} 178 *clientLen = unix.SizeofSockaddrAny 179 clientAddrPointer := uintptr(unsafe.Pointer(clientAddr)) 180 clientLenPointer := uint64(uintptr(unsafe.Pointer(clientLen))) 181 entry.PrepareAccept(int(uintptr(socketFd)), clientAddrPointer, clientLenPointer, 0) 182 entry.UserData = uint64(socketFd) 183 cqeNr, err := ring.Submit() 184 Nil(t, err) 185 Equal(t, uint(1), cqeNr) 186 187 wholeLenData := make([]byte, DefaultMagicBufferSize) 188 halfLenData := make([]byte, DefaultMagicBufferSize/2) 189 bytesRead, err := rand.Read(wholeLenData) 190 Nil(t, err) 191 Equal(t, DefaultMagicBufferSize, bytesRead) 192 bytesRead, err = rand.Read(halfLenData) 193 Nil(t, err) 194 Equal(t, DefaultMagicBufferSize/2, bytesRead) 195 connection := &conn{ 196 state: accept, 197 inboundBuffer: NewMagicBuffer(DefaultMagicBufferSize), 198 outboundBuffer: NewMagicBuffer(DefaultMagicBufferSize), 199 } 200 201 clientConnChan := make(chan net.Conn) 202 go func() { 203 conn, cErr := net.DialTimeout(gainNet.TCP, fmt.Sprintf("127.0.0.1:%d", 9876), time.Second) 204 Nil(t, cErr) 205 NotNil(t, conn) 206 207 var bytesWritten int 208 bytesWritten, cErr = conn.Write(halfLenData) 209 Nil(t, cErr) 210 Equal(t, DefaultMagicBufferSize/2, bytesWritten) 211 buffer := make([]byte, DefaultMagicBufferSize) 212 bytesWritten, cErr = conn.Read(buffer) 213 Nil(t, cErr) 214 Equal(t, len(halfLenData), bytesWritten) 215 Equal(t, halfLenData, buffer[:DefaultMagicBufferSize/2]) 216 bytesWritten, cErr = conn.Write(wholeLenData) 217 Nil(t, cErr) 218 Equal(t, DefaultMagicBufferSize, bytesWritten) 219 bytesWritten, cErr = conn.Read(buffer) 220 Nil(t, cErr) 221 Equal(t, len(wholeLenData), bytesWritten) 222 Equal(t, wholeLenData, buffer[:DefaultMagicBufferSize]) 223 224 clientConnChan <- conn 225 }() 226 227 defer func() { 228 conn := <-clientConnChan 229 if tcpConn, ok := conn.(*net.TCPConn); ok { 230 lErr := tcpConn.SetLinger(0) 231 Nil(t, lErr) 232 } 233 }() 234 235 testCase := &testCase{ 236 halfLenData: halfLenData, 237 wholeLenData: wholeLenData, 238 } 239 240 for { 241 if loop(t, ring, socketFd, connection, testCase) { 242 break 243 } 244 } 245 }