github.com/pawelgaczynski/giouring@v0.0.0-20230826085535-69588b89acb9/udp_recv_send_test.go (about)

     1  // MIT License
     2  //
     3  // Copyright (c) 2023 Paweł Gaczyński
     4  //
     5  // Permission is hereby granted, free of charge, to any person obtaining a
     6  // copy of this software and associated documentation files (the
     7  // "Software"), to deal in the Software without restriction, including
     8  // without limitation the rights to use, copy, modify, merge, publish,
     9  // distribute, sublicense, and/or sell copies of the Software, and to
    10  // permit persons to whom the Software is furnished to do so, subject to
    11  // the following conditions:
    12  //
    13  // The above copyright notice and this permission notice shall be included
    14  // in all copies or substantial portions of the Software.
    15  //
    16  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
    17  // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
    18  // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    19  // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
    20  // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    21  // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    22  // SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    23  
    24  package giouring
    25  
    26  import (
    27  	"errors"
    28  	"fmt"
    29  	"log"
    30  	"net"
    31  	"syscall"
    32  	"testing"
    33  	"time"
    34  	"unsafe"
    35  
    36  	. "github.com/stretchr/testify/require"
    37  )
    38  
    39  const (
    40  	udpRecv = iota
    41  	udpSend
    42  )
    43  
    44  func anyToSockaddrInet4(rsa *syscall.RawSockaddrAny) (*syscall.SockaddrInet4, error) {
    45  	if rsa == nil {
    46  		return nil, syscall.EINVAL
    47  	}
    48  
    49  	if rsa.Addr.Family != syscall.AF_INET {
    50  		return nil, syscall.EAFNOSUPPORT
    51  	}
    52  
    53  	rsaPointer := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rsa))
    54  	sockAddr := new(syscall.SockaddrInet4)
    55  	p := (*[2]byte)(unsafe.Pointer(&rsaPointer.Port))
    56  	sockAddr.Port = int(p[0])<<8 + int(p[1])
    57  
    58  	for i := 0; i < len(sockAddr.Addr); i++ {
    59  		sockAddr.Addr[i] = rsaPointer.Addr[i]
    60  	}
    61  
    62  	return sockAddr, nil
    63  }
    64  
    65  type udpConnection struct {
    66  	msg           *syscall.Msghdr
    67  	rsa           *syscall.RawSockaddrAny
    68  	buffer        []byte
    69  	controlBuffer []byte
    70  	fd            uint64
    71  	state         int
    72  }
    73  
    74  func udpLoop(t *testing.T, ring *Ring, socketFd int, connection *udpConnection) bool {
    75  	t.Helper()
    76  
    77  	cqe, err := ring.WaitCQE()
    78  	if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) ||
    79  		errors.Is(err, syscall.ETIME) {
    80  		return false
    81  	}
    82  
    83  	Nil(t, err)
    84  	entry := ring.GetSQE()
    85  	NotNil(t, entry)
    86  	ring.CQESeen(cqe)
    87  
    88  	switch connection.state {
    89  	case udpRecv:
    90  		_, err = anyToSockaddrInet4(connection.rsa)
    91  		if err != nil {
    92  			log.Panic(err)
    93  		}
    94  
    95  		Equal(t, "testdata1234567890", string(connection.buffer[:18]))
    96  		connection.buffer = connection.buffer[:0]
    97  		data := []byte("responsedata0123456789")
    98  		copied := copy(connection.buffer[:len(data)], data)
    99  		Equal(t, 22, copied)
   100  		buffer := connection.buffer[:len(data)]
   101  
   102  		connection.msg.Iov.Base = (*byte)(unsafe.Pointer(&buffer[0]))
   103  		connection.msg.Iov.SetLen(len(buffer))
   104  		entry.PrepareSendMsg(socketFd, connection.msg, 0)
   105  
   106  		entry.UserData = connection.fd
   107  		connection.state = udpSend
   108  
   109  	case udpSend:
   110  		Equal(t, connection.fd, cqe.UserData)
   111  		Equal(t, cqe.Res, int32(22))
   112  
   113  		return true
   114  	}
   115  	cqeNr, err := ring.Submit()
   116  	Nil(t, err)
   117  	Equal(t, uint(1), cqeNr)
   118  
   119  	return false
   120  }
   121  
   122  func TestUDPRecvSend(t *testing.T) {
   123  	ring, err := CreateRing(16)
   124  	Nil(t, err)
   125  
   126  	defer ring.QueueExit()
   127  
   128  	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
   129  	Nil(t, err)
   130  	err = syscall.SetsockoptInt(socketFd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
   131  	Nil(t, err)
   132  	testPort := getTestPort()
   133  	err = syscall.Bind(socketFd, &syscall.SockaddrInet4{
   134  		Port: testPort,
   135  	})
   136  	Nil(t, err)
   137  	err = syscall.SetNonblock(socketFd, false)
   138  	Nil(t, err)
   139  
   140  	defer func() {
   141  		closerErr := syscall.Close(socketFd)
   142  		Nil(t, closerErr)
   143  	}()
   144  
   145  	entry := ring.GetSQE()
   146  	NotNil(t, entry)
   147  
   148  	buffer := make([]byte, 64)
   149  
   150  	var iovec syscall.Iovec
   151  	iovec.Base = (*byte)(unsafe.Pointer(&buffer[0]))
   152  	iovec.SetLen(len(buffer))
   153  
   154  	controlBuffer := make([]byte, 64)
   155  
   156  	var (
   157  		msg syscall.Msghdr
   158  		rsa syscall.RawSockaddrAny
   159  	)
   160  	msg.Name = (*byte)(unsafe.Pointer(&rsa))
   161  	msg.Namelen = uint32(syscall.SizeofSockaddrAny)
   162  	msg.Iov = &iovec
   163  	msg.Iovlen = 1
   164  	msg.Control = (*byte)(unsafe.Pointer(&controlBuffer[0]))
   165  	msg.SetControllen(len(controlBuffer))
   166  
   167  	entry.PrepareRecvMsg(socketFd, &msg, 0)
   168  	entry.UserData = uint64(socketFd)
   169  
   170  	cqeNr, err := ring.Submit()
   171  	Nil(t, err)
   172  	Equal(t, uint(1), cqeNr)
   173  
   174  	connection := &udpConnection{state: udpRecv, buffer: buffer, msg: &msg, rsa: &rsa, controlBuffer: controlBuffer}
   175  
   176  	clientConnChan := make(chan net.Conn)
   177  	go func() {
   178  		conn, cErr := net.DialTimeout("udp", fmt.Sprintf("127.0.0.1:%d", testPort), time.Second)
   179  		Nil(t, cErr)
   180  		NotNil(t, conn)
   181  		bytesWritten, cErr := conn.Write([]byte("testdata1234567890"))
   182  		Nil(t, cErr)
   183  		Equal(t, 18, bytesWritten)
   184  
   185  		var buffer [22]byte
   186  		bytesWritten, cErr = conn.Read(buffer[:])
   187  		Nil(t, cErr)
   188  		Equal(t, 22, bytesWritten)
   189  		Equal(t, "responsedata0123456789", string(buffer[:]))
   190  		clientConnChan <- conn
   191  	}()
   192  
   193  	defer func() {
   194  		<-clientConnChan
   195  	}()
   196  
   197  	for {
   198  		if udpLoop(t, ring, socketFd, connection) {
   199  			break
   200  		}
   201  	}
   202  }