github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/buffer/magicring/ringbuffer_iouring_test.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package magicring
    16  
import (
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"net"
	"syscall"
	"testing"
	"time"
	"unsafe"

	gainNet "github.com/pawelgaczynski/gain/pkg/net"
	"github.com/pawelgaczynski/giouring"
	. "github.com/stretchr/testify/require"
	"golang.org/x/sys/unix"
)
    32  
    33  const (
    34  	accept = iota
    35  	recv
    36  	send
    37  )
    38  
    39  type conn struct {
    40  	fd             uint64
    41  	inboundBuffer  *RingBuffer
    42  	outboundBuffer *RingBuffer
    43  	state          int
    44  }
    45  
    46  func loop(t *testing.T, ring *giouring.Ring, socketFd int, connection *conn, testCase *testCase) bool {
    47  	t.Helper()
    48  
    49  	cqe, err := ring.WaitCQE()
    50  	if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) ||
    51  		errors.Is(err, syscall.ETIME) {
    52  		return false
    53  	}
    54  
    55  	Nil(t, err)
    56  	entry := ring.GetSQE()
    57  	NotNil(t, entry)
    58  	ring.CQESeen(cqe)
    59  
    60  	switch connection.state {
    61  	case accept:
    62  		Equal(t, uint64(socketFd), cqe.UserData)
    63  		Greater(t, cqe.Res, int32(0))
    64  		connection.fd = uint64(cqe.Res)
    65  		entry.PrepareRecv(
    66  			int(connection.fd),
    67  			uintptr(connection.inboundBuffer.WriteAddress()),
    68  			uint32(connection.inboundBuffer.Available()),
    69  			0)
    70  		entry.UserData = connection.fd
    71  		connection.state = recv
    72  
    73  	case recv:
    74  		var data []byte
    75  		if testCase.recvIdx == 0 {
    76  			data = testCase.halfLenData
    77  		} else {
    78  			data = testCase.wholeLenData
    79  		}
    80  		testCase.recvIdx++
    81  
    82  		Equal(t, connection.fd, cqe.UserData)
    83  		Equal(t, int32(len(data)), cqe.Res)
    84  
    85  		connection.inboundBuffer.AdvanceWrite(int(cqe.Res))
    86  		readBuf := make([]byte, DefaultMagicBufferSize)
    87  
    88  		var bytesRead int
    89  		bytesRead, err = connection.inboundBuffer.Read(readBuf)
    90  		Nil(t, err)
    91  		Equal(t, len(data), bytesRead)
    92  		Equal(t, data, readBuf[:cqe.Res])
    93  
    94  		var bytesWritten int
    95  		bytesWritten, err = connection.outboundBuffer.Write(data)
    96  		Nil(t, err)
    97  		Equal(t, len(data), bytesWritten)
    98  
    99  		entry.PrepareSend(
   100  			int(connection.fd),
   101  			uintptr(connection.outboundBuffer.ReadAddress()),
   102  			uint32(connection.outboundBuffer.Buffered()),
   103  			0)
   104  		entry.UserData = connection.fd
   105  		connection.state = send
   106  
   107  	case send:
   108  		var res int32
   109  		if testCase.sendIdx == 0 {
   110  			res = int32(DefaultMagicBufferSize / 2)
   111  		} else {
   112  			res = int32(DefaultMagicBufferSize)
   113  		}
   114  
   115  		Equal(t, connection.fd, cqe.UserData)
   116  		Equal(t, res, cqe.Res)
   117  
   118  		connection.outboundBuffer.AdvanceRead(int(cqe.Res))
   119  
   120  		if testCase.sendIdx == 0 {
   121  			entry.PrepareRecv(
   122  				int(connection.fd),
   123  				uintptr(connection.inboundBuffer.WriteAddress()),
   124  				uint32(connection.inboundBuffer.Available()),
   125  				0)
   126  			entry.UserData = connection.fd
   127  			connection.state = recv
   128  			testCase.sendIdx++
   129  		} else {
   130  			err = syscall.Shutdown(int(connection.fd), syscall.SHUT_RDWR)
   131  			Nil(t, err)
   132  
   133  			return true
   134  		}
   135  	}
   136  	cqeNr, err := ring.Submit()
   137  	Nil(t, err)
   138  	Equal(t, uint(1), cqeNr)
   139  
   140  	return false
   141  }
   142  
   143  type testCase struct {
   144  	halfLenData  []byte
   145  	wholeLenData []byte
   146  	recvIdx      int
   147  	sendIdx      int
   148  }
   149  
   150  func TestMagicRingRecvSend(t *testing.T) {
   151  	ring, err := giouring.CreateRing(16)
   152  	Nil(t, err)
   153  
   154  	defer ring.QueueExit()
   155  
   156  	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
   157  	Nil(t, err)
   158  	err = syscall.SetsockoptInt(socketFd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
   159  	Nil(t, err)
   160  	err = syscall.Bind(socketFd, &syscall.SockaddrInet4{
   161  		Port: 9876,
   162  	})
   163  	Nil(t, err)
   164  	err = syscall.SetNonblock(socketFd, false)
   165  	Nil(t, err)
   166  	err = syscall.Listen(socketFd, 128)
   167  	Nil(t, err)
   168  
   169  	defer func() {
   170  		closeErr := syscall.Close(socketFd)
   171  		Nil(t, closeErr)
   172  	}()
   173  
   174  	entry := ring.GetSQE()
   175  	NotNil(t, entry)
   176  	clientLen := new(uint32)
   177  	clientAddr := &unix.RawSockaddrAny{}
   178  	*clientLen = unix.SizeofSockaddrAny
   179  	clientAddrPointer := uintptr(unsafe.Pointer(clientAddr))
   180  	clientLenPointer := uint64(uintptr(unsafe.Pointer(clientLen)))
   181  	entry.PrepareAccept(int(uintptr(socketFd)), clientAddrPointer, clientLenPointer, 0)
   182  	entry.UserData = uint64(socketFd)
   183  	cqeNr, err := ring.Submit()
   184  	Nil(t, err)
   185  	Equal(t, uint(1), cqeNr)
   186  
   187  	wholeLenData := make([]byte, DefaultMagicBufferSize)
   188  	halfLenData := make([]byte, DefaultMagicBufferSize/2)
   189  	bytesRead, err := rand.Read(wholeLenData)
   190  	Nil(t, err)
   191  	Equal(t, DefaultMagicBufferSize, bytesRead)
   192  	bytesRead, err = rand.Read(halfLenData)
   193  	Nil(t, err)
   194  	Equal(t, DefaultMagicBufferSize/2, bytesRead)
   195  	connection := &conn{
   196  		state:          accept,
   197  		inboundBuffer:  NewMagicBuffer(DefaultMagicBufferSize),
   198  		outboundBuffer: NewMagicBuffer(DefaultMagicBufferSize),
   199  	}
   200  
   201  	clientConnChan := make(chan net.Conn)
   202  	go func() {
   203  		conn, cErr := net.DialTimeout(gainNet.TCP, fmt.Sprintf("127.0.0.1:%d", 9876), time.Second)
   204  		Nil(t, cErr)
   205  		NotNil(t, conn)
   206  
   207  		var bytesWritten int
   208  		bytesWritten, cErr = conn.Write(halfLenData)
   209  		Nil(t, cErr)
   210  		Equal(t, DefaultMagicBufferSize/2, bytesWritten)
   211  		buffer := make([]byte, DefaultMagicBufferSize)
   212  		bytesWritten, cErr = conn.Read(buffer)
   213  		Nil(t, cErr)
   214  		Equal(t, len(halfLenData), bytesWritten)
   215  		Equal(t, halfLenData, buffer[:DefaultMagicBufferSize/2])
   216  		bytesWritten, cErr = conn.Write(wholeLenData)
   217  		Nil(t, cErr)
   218  		Equal(t, DefaultMagicBufferSize, bytesWritten)
   219  		bytesWritten, cErr = conn.Read(buffer)
   220  		Nil(t, cErr)
   221  		Equal(t, len(wholeLenData), bytesWritten)
   222  		Equal(t, wholeLenData, buffer[:DefaultMagicBufferSize])
   223  
   224  		clientConnChan <- conn
   225  	}()
   226  
   227  	defer func() {
   228  		conn := <-clientConnChan
   229  		if tcpConn, ok := conn.(*net.TCPConn); ok {
   230  			lErr := tcpConn.SetLinger(0)
   231  			Nil(t, lErr)
   232  		}
   233  	}()
   234  
   235  	testCase := &testCase{
   236  		halfLenData:  halfLenData,
   237  		wholeLenData: wholeLenData,
   238  	}
   239  
   240  	for {
   241  		if loop(t, ring, socketFd, connection, testCase) {
   242  			break
   243  		}
   244  	}
   245  }