github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/read_write_worker.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"net"
    21  
    22  	"github.com/alitto/pond"
    23  	"github.com/pawelgaczynski/gain/pkg/queue"
    24  	"github.com/pawelgaczynski/giouring"
    25  	"github.com/rs/zerolog"
    26  )
    27  
    28  type readWriteWorker interface {
    29  	worker
    30  	activeConnections() int
    31  }
    32  
    33  type readWriteWorkerConfig struct {
    34  	workerConfig
    35  	asyncHandler  bool
    36  	goroutinePool bool
    37  	sendRecvMsg   bool
    38  }
    39  
    40  type readWriteWorkerImpl struct {
    41  	*workerImpl
    42  	*writer
    43  	*reader
    44  	ring              *giouring.Ring
    45  	connectionManager *connectionManager
    46  	asyncOpQueue      queue.LockFreeQueue[*connection]
    47  	pool              *pond.WorkerPool
    48  	eventHandler      EventHandler
    49  	asyncHandler      bool
    50  	goroutinePool     bool
    51  	sendRecvMsg       bool
    52  	localAddr         net.Addr
    53  }
    54  
    55  func (w *readWriteWorkerImpl) handleAsyncWritesIfEnabled() {
    56  	if w.asyncHandler {
    57  		w.handleAsyncWrites()
    58  	}
    59  }
    60  
    61  func (w *readWriteWorkerImpl) handleAsyncWrites() {
    62  	for {
    63  		if w.asyncOpQueue.IsEmpty() {
    64  			break
    65  		}
    66  		conn := w.asyncOpQueue.Dequeue()
    67  
    68  		var err error
    69  
    70  		switch conn.nextAsyncOp {
    71  		case readOp:
    72  			err = w.addReadRequest(conn)
    73  			if err != nil {
    74  				w.logError(err).Int("fd", conn.fd)
    75  
    76  				continue
    77  			}
    78  
    79  		case writeOp:
    80  			closed := conn.isClosed()
    81  
    82  			if w.sendRecvMsg {
    83  				conn.setMsgHeaderWrite()
    84  			}
    85  
    86  			err = w.addWriteRequest(conn, closed)
    87  			if err != nil {
    88  				w.logError(err).Int("fd", conn.fd)
    89  
    90  				continue
    91  			}
    92  
    93  			if closed {
    94  				err = w.addCloseConnRequest(conn)
    95  				if err != nil {
    96  					w.logError(err).Int("fd", conn.fd)
    97  
    98  					continue
    99  				}
   100  			}
   101  
   102  		case closeOp:
   103  			err = w.addCloseConnRequest(conn)
   104  			if err != nil {
   105  				w.logError(err).Int("fd", conn.fd)
   106  
   107  				continue
   108  			}
   109  		}
   110  	}
   111  }
   112  
   113  func (w *readWriteWorkerImpl) work(conn *connection, n int) {
   114  	conn.setUserSpace()
   115  	w.eventHandler.OnRead(conn, n)
   116  }
   117  
   118  func (w *readWriteWorkerImpl) doAsyncWork(conn *connection, n int) func() {
   119  	return func() {
   120  		w.work(conn, n)
   121  
   122  		switch {
   123  		case conn.OutboundBuffered() > 0:
   124  			conn.nextAsyncOp = writeOp
   125  		case conn.isClosed():
   126  			conn.nextAsyncOp = closeOp
   127  		default:
   128  			conn.nextAsyncOp = readOp
   129  		}
   130  
   131  		w.asyncOpQueue.Enqueue(conn)
   132  	}
   133  }
   134  
   135  func (w *readWriteWorkerImpl) writeData(conn *connection) error {
   136  	if w.sendRecvMsg {
   137  		conn.setMsgHeaderWrite()
   138  	}
   139  	closed := conn.isClosed()
   140  
   141  	err := w.addWriteRequest(conn, closed)
   142  	if err != nil {
   143  		return err
   144  	}
   145  
   146  	if closed {
   147  		return w.addCloseConnRequest(conn)
   148  	}
   149  
   150  	return nil
   151  }
   152  
   153  func (w *readWriteWorkerImpl) onRead(cqe *giouring.CompletionQueueEvent, conn *connection) error {
   154  	// https://manpages.debian.org/unstable/manpages-dev/recv.2.en.html
   155  	// These calls return the number of bytes received, or -1 if an error occurred.
   156  	// In the event of an error, errno is set to indicate the error.
   157  	// When a stream socket peer has performed an orderly shutdown,
   158  	// the return value will be 0 (the traditional "end-of-file" return).
   159  	// Datagram sockets in various domains (e.g., the UNIX and Internet domains) permit zero-length datagrams.
   160  	// When such a datagram is received, the return value is 0.
   161  	// The value 0 may also be returned if the requested number of bytes to receive from a stream socket was 0.
   162  	if cqe.Res <= 0 {
   163  		w.closeConn(conn, true, io.EOF)
   164  
   165  		return nil
   166  	}
   167  
   168  	w.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes read")
   169  
   170  	n := int(cqe.Res)
   171  	conn.onKernelRead(n)
   172  
   173  	if w.sendRecvMsg {
   174  		forkedConn := w.connectionManager.fork(conn, true)
   175  		forkedConn.localAddr = w.localAddr
   176  
   177  		err := w.addReadRequest(conn)
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		conn = forkedConn
   183  	}
   184  
   185  	if cqe.Flags&giouring.CQEFSockNonempty > 0 && !conn.isClosed() {
   186  		return w.addReadRequest(conn)
   187  	}
   188  
   189  	if w.asyncHandler {
   190  		if w.goroutinePool {
   191  			w.pool.Submit(w.doAsyncWork(conn, n))
   192  		} else {
   193  			go w.doAsyncWork(conn, n)()
   194  		}
   195  	} else {
   196  		w.work(conn, n)
   197  
   198  		switch {
   199  		case conn.OutboundBuffered() > 0:
   200  			return w.writeData(conn)
   201  		case conn.isClosed():
   202  			err := w.addCloseConnRequest(conn)
   203  			if err != nil {
   204  				return err
   205  			}
   206  		default:
   207  			err := w.addReadRequest(conn)
   208  			if err != nil {
   209  				return err
   210  			}
   211  		}
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  func (w *readWriteWorkerImpl) addNextRequest(conn *connection) error {
   218  	closed := conn.isClosed()
   219  
   220  	switch {
   221  	case conn.OutboundBuffered() > 0:
   222  		err := w.addWriteRequest(conn, closed)
   223  		if err != nil {
   224  			return fmt.Errorf("add read() request error: %w", err)
   225  		}
   226  
   227  		if closed {
   228  			err = w.addCloseConnRequest(conn)
   229  			if err != nil {
   230  				return fmt.Errorf("add close() request error: %w", err)
   231  			}
   232  		}
   233  
   234  	case closed:
   235  		err := w.addCloseConnRequest(conn)
   236  		if err != nil {
   237  			return fmt.Errorf("add close() request error: %w", err)
   238  		}
   239  
   240  	default:
   241  		err := w.addReadRequest(conn)
   242  		if err != nil {
   243  			return fmt.Errorf("add read() request error: %w", err)
   244  		}
   245  	}
   246  
   247  	return nil
   248  }
   249  
   250  func (w *readWriteWorkerImpl) closeConn(conn *connection, syscallClose bool, err error) {
   251  	if syscallClose {
   252  		_ = w.syscallCloseSocket(conn.fd)
   253  	}
   254  
   255  	conn.setUserSpace()
   256  
   257  	if !conn.isClosed() {
   258  		conn.Close()
   259  	}
   260  
   261  	w.eventHandler.OnClose(conn, err)
   262  	w.connectionManager.release(conn.key)
   263  }
   264  
   265  func newReadWriteWorkerImpl(ring *giouring.Ring, index int, localAddr net.Addr, eventHandler EventHandler,
   266  	connectionManager *connectionManager, config readWriteWorkerConfig, logger zerolog.Logger,
   267  ) *readWriteWorkerImpl {
   268  	worker := &readWriteWorkerImpl{
   269  		workerImpl:        newWorkerImpl(ring, config.workerConfig, index, logger),
   270  		reader:            newReader(ring, config.sendRecvMsg),
   271  		writer:            newWriter(ring, config.sendRecvMsg),
   272  		ring:              ring,
   273  		connectionManager: connectionManager,
   274  		asyncOpQueue:      queue.NewQueue[*connection](),
   275  		eventHandler:      eventHandler,
   276  		asyncHandler:      config.asyncHandler,
   277  		goroutinePool:     config.goroutinePool,
   278  		sendRecvMsg:       config.sendRecvMsg,
   279  		localAddr:         localAddr,
   280  	}
   281  	if config.asyncHandler && config.goroutinePool {
   282  		worker.pool = pond.New(goPoolMaxWorkers, goPoolMaxCapacity)
   283  	}
   284  
   285  	return worker
   286  }