github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/conn.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"io"
    19  	"net"
    20  	"sync/atomic"
    21  	"syscall"
    22  	"time"
    23  	"unsafe"
    24  
    25  	"github.com/pawelgaczynski/gain/pkg/buffer/magicring"
    26  	"github.com/pawelgaczynski/gain/pkg/errors"
    27  	"github.com/pawelgaczynski/gain/pkg/pool/byteslice"
    28  	"github.com/pawelgaczynski/gain/pkg/pool/ringbuffer"
    29  	"github.com/pawelgaczynski/gain/pkg/socket"
    30  )
    31  
    32  type connectionState int
    33  
    34  const (
    35  	connInvalid connectionState = iota
    36  	connAccept
    37  	connRead
    38  	connWrite
    39  	connClose
    40  )
    41  
    42  func (s connectionState) String() string {
    43  	switch s {
    44  	case connAccept:
    45  		return "accept"
    46  	case connRead:
    47  		return "read"
    48  	case connWrite:
    49  		return "write"
    50  	case connClose:
    51  		return "close"
    52  	default:
    53  		return "invalid"
    54  	}
    55  }
    56  
    57  const (
    58  	kernelSpace = iota
    59  	userSpace
    60  )
    61  
    62  func connModeString(m uint32) string {
    63  	switch m {
    64  	case kernelSpace:
    65  		return "kernelSpace"
    66  	case userSpace:
    67  		return "userSpace"
    68  	default:
    69  		return "invalid"
    70  	}
    71  }
    72  
    73  const (
    74  	msgControlBufferSize = 64
    75  )
    76  
    77  const (
    78  	noOp = iota
    79  	readOp
    80  	writeOp
    81  	closeOp
    82  )
    83  
    84  const (
    85  	tcp = iota
    86  	udp
    87  )
    88  
// connection holds the per-socket state: the socket descriptor, the inbound
// and outbound ring buffers, the mode flag that gates user operations, and
// (for message-based I/O) a prepared syscall.Msghdr. For UDP, fork hands the
// current buffers and header over to another connection value and gives this
// one fresh ones.
type connection struct {
	// fd is the socket descriptor passed to the socket option helpers.
	fd      int
	// key is an identifier assigned by fork. NOTE(review): its meaning is
	// defined by callers outside this chunk.
	key     int
	// network is tcp or udp. NOTE(review): Close reads it with
	// atomic.LoadUint32, while fork writes it with a plain assignment.
	network uint32

	// inboundBuffer receives data read by the kernel (onKernelRead advances
	// its write position); Read/Next/Peek/Discard/WriteTo consume from it.
	inboundBuffer  *magicring.RingBuffer
	// outboundBuffer is filled by Write/ReadFrom and drained by kernel writes
	// (onKernelWrite advances its read position).
	outboundBuffer *magicring.RingBuffer
	state          connectionState
	// mode is kernelSpace or userSpace; userOpAllowed requires userSpace.
	mode           atomic.Uint32
	// closed is set by Close for non-UDP connections.
	closed         atomic.Bool

	// msgHdr and rawSockaddr are built by initMsgHeader; msgHdr.Name points at
	// rawSockaddr.
	msgHdr      *syscall.Msghdr
	rawSockaddr *syscall.RawSockaddrAny

	localAddr  net.Addr
	remoteAddr net.Addr

	// ctx is an arbitrary user value exposed via Context/SetContext.
	ctx interface{}

	// nextAsyncOp is not used in this chunk; presumably one of
	// noOp/readOp/writeOp/closeOp — TODO confirm.
	nextAsyncOp int
}
   110  
   111  func (c *connection) outboundReadAddress() unsafe.Pointer {
   112  	return c.outboundBuffer.ReadAddress()
   113  }
   114  
   115  func (c *connection) inboundWriteAddress() unsafe.Pointer {
   116  	return c.inboundBuffer.WriteAddress()
   117  }
   118  
   119  func (c *connection) setKernelSpace() {
   120  	c.mode.Store(kernelSpace)
   121  }
   122  
   123  func (c *connection) setUserSpace() {
   124  	c.mode.Store(userSpace)
   125  }
   126  
   127  func (c *connection) Context() interface{} {
   128  	return c.ctx
   129  }
   130  
   131  func (c *connection) SetContext(ctx interface{}) {
   132  	c.ctx = ctx
   133  }
   134  
   135  func (c *connection) LocalAddr() net.Addr {
   136  	return c.localAddr
   137  }
   138  
   139  func (c *connection) RemoteAddr() net.Addr {
   140  	return c.remoteAddr
   141  }
   142  
   143  func (c *connection) Fd() int {
   144  	return c.fd
   145  }
   146  
   147  func (c *connection) userOpAllowed(name string) error {
   148  	if c.closed.Load() {
   149  		return errors.ErrConnectionClosed
   150  	}
   151  
   152  	if mode := c.mode.Load(); mode != userSpace {
   153  		return errors.ErrorOpNotAvailableInMode(name, connModeString(mode))
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  func (c *connection) SetReadBuffer(bytes int) error {
   160  	err := c.userOpAllowed("setReadBuffer")
   161  	if err != nil {
   162  		return err
   163  	}
   164  	//nolint:wrapcheck
   165  	return socket.SetRecvBuffer(c.fd, bytes)
   166  }
   167  
   168  func (c *connection) SetWriteBuffer(bytes int) error {
   169  	err := c.userOpAllowed("setWriteBuffer")
   170  	if err != nil {
   171  		return err
   172  	}
   173  	//nolint:wrapcheck
   174  	return socket.SetSendBuffer(c.fd, bytes)
   175  }
   176  
   177  func (c *connection) SetLinger(sec int) error {
   178  	err := c.userOpAllowed("setLinger")
   179  	if err != nil {
   180  		return err
   181  	}
   182  	//nolint:wrapcheck
   183  	return socket.SetLinger(c.fd, sec)
   184  }
   185  
   186  func (c *connection) SetNoDelay(noDelay bool) error {
   187  	err := c.userOpAllowed("setNoDelay")
   188  	if err != nil {
   189  		return err
   190  	}
   191  	//nolint:wrapcheck
   192  	return socket.SetNoDelay(c.fd, boolToInt(noDelay))
   193  }
   194  
   195  func (c *connection) SetKeepAlivePeriod(period time.Duration) error {
   196  	err := c.userOpAllowed("setKeepAlivePeriod")
   197  	if err != nil {
   198  		return err
   199  	}
   200  	//nolint:wrapcheck
   201  	return socket.SetKeepAlivePeriod(c.fd, int(period.Seconds()))
   202  }
   203  
   204  func (c *connection) onKernelRead(n int) {
   205  	c.inboundBuffer.AdvanceWrite(n)
   206  }
   207  
   208  func (c *connection) onKernelWrite(n int) {
   209  	c.outboundBuffer.AdvanceRead(n)
   210  }
   211  
   212  func (c *connection) isClosed() bool {
   213  	return c.closed.Load()
   214  }
   215  
   216  func (c *connection) Close() error {
   217  	if network := atomic.LoadUint32(&c.network); network == udp {
   218  		return nil
   219  	}
   220  
   221  	if c.closed.Load() {
   222  		return errors.ErrConnectionAlreadyClosed
   223  	}
   224  
   225  	c.closed.Store(true)
   226  
   227  	return nil
   228  }
   229  
   230  func (c *connection) Next(n int) ([]byte, error) {
   231  	err := c.userOpAllowed("next")
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	//nolint:wrapcheck
   237  	return c.inboundBuffer.Next(n)
   238  }
   239  
   240  func (c *connection) Discard(n int) (int, error) {
   241  	err := c.userOpAllowed("discard")
   242  	if err != nil {
   243  		return 0, err
   244  	}
   245  
   246  	return c.inboundBuffer.Discard(n), nil
   247  }
   248  
   249  func (c *connection) Peek(n int) ([]byte, error) {
   250  	err := c.userOpAllowed("peek")
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	return c.inboundBuffer.Peek(n), nil
   256  }
   257  
   258  func (c *connection) ReadFrom(reader io.Reader) (int64, error) {
   259  	err := c.userOpAllowed("readFrom")
   260  	if err != nil {
   261  		return 0, err
   262  	}
   263  
   264  	//nolint:wrapcheck
   265  	return c.outboundBuffer.ReadFrom(reader)
   266  }
   267  
   268  func (c *connection) WriteTo(writer io.Writer) (int64, error) {
   269  	err := c.userOpAllowed("writeTo")
   270  	if err != nil {
   271  		return 0, err
   272  	}
   273  
   274  	//nolint:wrapcheck
   275  	return c.inboundBuffer.WriteTo(writer)
   276  }
   277  
   278  func (c *connection) Read(buffer []byte) (int, error) {
   279  	err := c.userOpAllowed("read")
   280  	if err != nil {
   281  		return 0, err
   282  	}
   283  
   284  	//nolint:wrapcheck
   285  	return c.inboundBuffer.Read(buffer)
   286  }
   287  
   288  func (c *connection) Write(buffer []byte) (int, error) {
   289  	err := c.userOpAllowed("write")
   290  	if err != nil {
   291  		return 0, err
   292  	}
   293  
   294  	//nolint:wrapcheck
   295  	return c.outboundBuffer.Write(buffer)
   296  }
   297  
   298  func (c *connection) OutboundBuffered() int {
   299  	return c.outboundBuffer.Buffered()
   300  }
   301  
   302  func (c *connection) InboundBuffered() int {
   303  	return c.inboundBuffer.Buffered()
   304  }
   305  
   306  func (c *connection) setMsgHeaderWrite() {
   307  	c.msgHdr.Iov.Base = (*byte)(c.outboundReadAddress())
   308  	c.msgHdr.Iov.SetLen(c.OutboundBuffered())
   309  }
   310  
   311  func (c *connection) initMsgHeader() {
   312  	var iovec syscall.Iovec
   313  	iovec.Base = (*byte)(c.inboundWriteAddress())
   314  	iovec.SetLen(c.inboundBuffer.Cap())
   315  
   316  	var (
   317  		msg syscall.Msghdr
   318  		rsa syscall.RawSockaddrAny
   319  	)
   320  
   321  	msg.Name = (*byte)(unsafe.Pointer(&rsa))
   322  	msg.Namelen = uint32(syscall.SizeofSockaddrAny)
   323  	msg.Iov = &iovec
   324  	msg.Iovlen = 1
   325  
   326  	controlBuffer := byteslice.Get(msgControlBufferSize)
   327  	msg.Control = (*byte)(unsafe.Pointer(&controlBuffer[0]))
   328  	msg.SetControllen(msgControlBufferSize)
   329  
   330  	c.msgHdr = &msg
   331  	c.rawSockaddr = &rsa
   332  }
   333  
// fork hands c's current I/O state over to newConn and re-arms c with fresh
// resources. It is used for UDP: newConn.network is set to udp.
//
// newConn receives c's inbound/outbound buffers, message header, raw sockaddr,
// state and fd, plus the given key. Its remoteAddr is decoded from the raw
// sockaddr; if decoding fails, remoteAddr is left untouched.
// NOTE(review): if newConn is a reused value, a failed decode leaves a stale
// remoteAddr. localAddr, ctx, mode and closed are not set here either —
// confirm the caller resets them.
//
// When write is true, the (now shared) header's iovec is pointed at newConn's
// pending outbound data. Finally c gets new ring buffers and a new message
// header, so the two connections no longer share buffers.
//
// Returns newConn.
func (c *connection) fork(newConn *connection, key int, write bool) *connection {
	newConn.inboundBuffer = c.inboundBuffer
	newConn.outboundBuffer = c.outboundBuffer
	newConn.msgHdr = c.msgHdr
	newConn.rawSockaddr = c.rawSockaddr
	newConn.state = c.state
	newConn.fd = c.fd
	newConn.key = key
	newConn.network = udp

	// Decode the peer address that the last receive stored in rawSockaddr.
	if sockAddr, err := anyToSockaddr(newConn.rawSockaddr); err == nil {
		newConn.remoteAddr = socket.SockaddrToUDPAddr(sockAddr)
	}

	// Must run after newConn owns the buffers/header, and before c replaces
	// its own below.
	if write {
		newConn.setMsgHeaderWrite()
	}

	// Give c fresh buffers and a header whose iovec targets the new inbound
	// buffer.
	c.inboundBuffer = ringbuffer.Get()
	c.outboundBuffer = ringbuffer.Get()
	c.initMsgHeader()

	return newConn
}
   358  
   359  func newConnection() *connection {
   360  	conn := &connection{
   361  		inboundBuffer:  ringbuffer.Get(),
   362  		outboundBuffer: ringbuffer.Get(),
   363  	}
   364  
   365  	return conn
   366  }