github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/shard_worker.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"fmt"
    19  	"net"
    20  	"sync/atomic"
    21  	"time"
    22  
    23  	"github.com/pawelgaczynski/gain/logger"
    24  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    25  	"github.com/pawelgaczynski/gain/pkg/socket"
    26  	"github.com/pawelgaczynski/giouring"
    27  	"golang.org/x/sys/unix"
    28  )
    29  
    30  type shardWorkerConfig struct {
    31  	readWriteWorkerConfig
    32  	tcpKeepAlive time.Duration
    33  }
    34  
    35  type shardWorker struct {
    36  	*acceptor
    37  	*readWriteWorkerImpl
    38  	ring               *giouring.Ring
    39  	connectionManager  *connectionManager
    40  	cpuAffinity        bool
    41  	tcpKeepAlive       time.Duration
    42  	connectionProtocol bool
    43  	accepting          atomic.Bool
    44  }
    45  
    46  func (w *shardWorker) onAccept(cqe *giouring.CompletionQueueEvent) error {
    47  	err := w.addAcceptConnRequest()
    48  	if err != nil {
    49  		w.accepting.Store(false)
    50  
    51  		return fmt.Errorf("add accept() request error: %w", err)
    52  	}
    53  
    54  	fileDescriptor := int(cqe.Res)
    55  	w.logDebug().Int("fd", fileDescriptor).Msg("Connection accepted")
    56  
    57  	conn := w.connectionManager.getFd(fileDescriptor)
    58  	conn.fd = fileDescriptor
    59  	conn.localAddr = w.localAddr
    60  
    61  	var clientAddr net.Addr
    62  	if clientAddr, err = w.acceptor.lastClientAddr(); err == nil {
    63  		conn.remoteAddr = clientAddr
    64  	} else {
    65  		w.logError(err).Msg("get last client address failed")
    66  	}
    67  
    68  	if w.cpuAffinity {
    69  		err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, w.index()+1)
    70  		if err != nil {
    71  			return fmt.Errorf("fd: %d, setting SO_INCOMING_CPU error: %w", fileDescriptor, err)
    72  		}
    73  	}
    74  
    75  	if w.tcpKeepAlive > 0 {
    76  		err = socket.SetKeepAlivePeriod(fileDescriptor, int(w.tcpKeepAlive.Seconds()))
    77  		if err != nil {
    78  			return fmt.Errorf("fd: %d, setting tcpKeepAlive error: %w", fileDescriptor, err)
    79  		}
    80  	}
    81  
    82  	conn.setUserSpace()
    83  	w.eventHandler.OnAccept(conn)
    84  
    85  	return w.addNextRequest(conn)
    86  }
    87  
    88  func (w *shardWorker) stopAccept() error {
    89  	return w.syscallCloseSocket(w.fd)
    90  }
    91  
    92  func (w *shardWorker) activeConnections() int {
    93  	return w.connectionManager.activeConnections() - 1
    94  }
    95  
    96  func (w *shardWorker) handleConn(conn *connection, cqe *giouring.CompletionQueueEvent) {
    97  	var (
    98  		err    error
    99  		errMsg string
   100  	)
   101  
   102  	switch conn.state {
   103  	case connAccept:
   104  		err = w.onAccept(cqe)
   105  		if err != nil {
   106  			errMsg = "accept error"
   107  		}
   108  
   109  	case connRead:
   110  		err = w.onRead(cqe, conn)
   111  		if err != nil {
   112  			errMsg = "read error"
   113  		}
   114  
   115  	case connWrite:
   116  		if !w.connectionProtocol && conn.key == 1 {
   117  			errMsg = "main socket cannot be in write mode"
   118  
   119  			break
   120  		}
   121  
   122  		n := int(cqe.Res)
   123  		conn.onKernelWrite(n)
   124  		w.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed")
   125  
   126  		conn.setUserSpace()
   127  		w.eventHandler.OnWrite(conn, n)
   128  
   129  		if w.sendRecvMsg {
   130  			w.connectionManager.release(conn.key)
   131  
   132  			break
   133  		}
   134  
   135  		err = w.addNextRequest(conn)
   136  		if err != nil {
   137  			errMsg = "add request error"
   138  		}
   139  
   140  	case connClose:
   141  		if cqe.UserData&closeConnFlag > 0 {
   142  			w.closeConn(conn, false, nil)
   143  		} else if cqe.UserData&writeDataFlag > 0 {
   144  			n := int(cqe.Res)
   145  			conn.onKernelWrite(n)
   146  			w.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed")
   147  			conn.setUserSpace()
   148  			w.eventHandler.OnWrite(conn, n)
   149  		}
   150  
   151  	default:
   152  		err = gainErrors.ErrorUnknownConnectionState(int(conn.state))
   153  	}
   154  
   155  	if err != nil {
   156  		w.logError(err).Msg(errMsg)
   157  
   158  		w.closeConn(conn, true, err)
   159  	}
   160  }
   161  
   162  func (w *shardWorker) initLoop() {
   163  	if w.connectionProtocol {
   164  		w.prepareHandler = func() error {
   165  			w.startedChan <- done
   166  
   167  			err := w.addAcceptConnRequest()
   168  			if err == nil {
   169  				w.accepting.Store(true)
   170  			}
   171  
   172  			return err
   173  		}
   174  	} else {
   175  		w.prepareHandler = func() error {
   176  			w.startedChan <- done
   177  			// 1 is always index for main socket
   178  			conn := w.connectionManager.get(1, w.fd)
   179  			conn.fd = w.fd
   180  			conn.initMsgHeader()
   181  
   182  			return w.addReadRequest(conn)
   183  		}
   184  	}
   185  	w.shutdownHandler = func() bool {
   186  		if w.needToShutdown() {
   187  			w.onCloseHandler()
   188  			w.markShutdownInProgress()
   189  		}
   190  
   191  		return true
   192  	}
   193  	w.loopFinisher = w.handleAsyncWritesIfEnabled
   194  	w.loopFinishCondition = func() bool {
   195  		if w.connectionManager.allClosed() || (w.connectionProtocol && !w.accepting.Load() && w.activeConnections() == 0) {
   196  			w.close()
   197  			w.notifyFinish()
   198  
   199  			return true
   200  		}
   201  
   202  		return false
   203  	}
   204  }
   205  
   206  func (w *shardWorker) loop(fd int) error {
   207  	w.logInfo().Int("fd", fd).Msg("Starting worker loop...")
   208  	w.fd = fd
   209  	w.initLoop()
   210  
   211  	loopErr := w.startLoop(w.index(), func(cqe *giouring.CompletionQueueEvent) error {
   212  		if exit := w.processEvent(cqe, func(cqe *giouring.CompletionQueueEvent) bool {
   213  			keyOrFd := cqe.UserData & ^allFlagsMask
   214  			if acceptReqFailedAfterStop := keyOrFd == uint64(w.fd) &&
   215  				!w.accepting.Load(); acceptReqFailedAfterStop {
   216  				return true
   217  			}
   218  
   219  			return false
   220  		}); exit {
   221  			return nil
   222  		}
   223  		key := int(cqe.UserData & ^allFlagsMask)
   224  		var connFd int
   225  		if w.connectionProtocol {
   226  			connFd = key
   227  		} else {
   228  			connFd = w.fd
   229  		}
   230  		conn := w.connectionManager.get(key, connFd)
   231  		w.handleConn(conn, cqe)
   232  
   233  		return nil
   234  	})
   235  	_ = w.stopAccept()
   236  
   237  	return loopErr
   238  }
   239  
   240  func (w *shardWorker) closeAllConnsAndRings() {
   241  	w.logWarn().Msg("Closing connections")
   242  	w.accepting.Store(false)
   243  	_ = w.syscallCloseSocket(w.fd)
   244  	w.connectionManager.close(func(conn *connection) bool {
   245  		err := w.addCloseConnRequest(conn)
   246  		if err != nil {
   247  			w.logError(err).Msg("Add close() connection request error")
   248  		}
   249  
   250  		return err == nil
   251  	}, w.fd)
   252  }
   253  
   254  func newShardWorker(
   255  	index int, localAddr net.Addr, config shardWorkerConfig, eventHandler EventHandler,
   256  ) (*shardWorker, error) {
   257  	ring, err := giouring.CreateRing(uint32(config.maxSQEntries))
   258  	if err != nil {
   259  		return nil, fmt.Errorf("creating ring error: %w", err)
   260  	}
   261  	logger := logger.NewLogger("worker", config.loggerLevel, config.prettyLogger)
   262  	connectionManager := newConnectionManager()
   263  	connectionProtocol := !config.sendRecvMsg
   264  	worker := &shardWorker{
   265  		readWriteWorkerImpl: newReadWriteWorkerImpl(
   266  			ring, index, localAddr, eventHandler, connectionManager, config.readWriteWorkerConfig, logger,
   267  		),
   268  		ring:               ring,
   269  		connectionManager:  connectionManager,
   270  		cpuAffinity:        config.cpuAffinity,
   271  		tcpKeepAlive:       config.tcpKeepAlive,
   272  		connectionProtocol: connectionProtocol,
   273  		acceptor:           newAcceptor(ring, connectionManager),
   274  	}
   275  	worker.onCloseHandler = worker.closeAllConnsAndRings
   276  
   277  	return worker, nil
   278  }