github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/shard_worker.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gain 16 17 import ( 18 "fmt" 19 "net" 20 "sync/atomic" 21 "time" 22 23 "github.com/pawelgaczynski/gain/logger" 24 gainErrors "github.com/pawelgaczynski/gain/pkg/errors" 25 "github.com/pawelgaczynski/gain/pkg/socket" 26 "github.com/pawelgaczynski/giouring" 27 "golang.org/x/sys/unix" 28 ) 29 30 type shardWorkerConfig struct { 31 readWriteWorkerConfig 32 tcpKeepAlive time.Duration 33 } 34 35 type shardWorker struct { 36 *acceptor 37 *readWriteWorkerImpl 38 ring *giouring.Ring 39 connectionManager *connectionManager 40 cpuAffinity bool 41 tcpKeepAlive time.Duration 42 connectionProtocol bool 43 accepting atomic.Bool 44 } 45 46 func (w *shardWorker) onAccept(cqe *giouring.CompletionQueueEvent) error { 47 err := w.addAcceptConnRequest() 48 if err != nil { 49 w.accepting.Store(false) 50 51 return fmt.Errorf("add accept() request error: %w", err) 52 } 53 54 fileDescriptor := int(cqe.Res) 55 w.logDebug().Int("fd", fileDescriptor).Msg("Connection accepted") 56 57 conn := w.connectionManager.getFd(fileDescriptor) 58 conn.fd = fileDescriptor 59 conn.localAddr = w.localAddr 60 61 var clientAddr net.Addr 62 if clientAddr, err = w.acceptor.lastClientAddr(); err == nil { 63 conn.remoteAddr = clientAddr 64 } else { 65 w.logError(err).Msg("get last client address failed") 66 } 67 68 if w.cpuAffinity { 69 err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, w.index()+1) 70 if err != nil { 71 return fmt.Errorf("fd: %d, setting SO_INCOMING_CPU error: %w", fileDescriptor, err) 72 } 73 } 74 75 if w.tcpKeepAlive > 0 { 76 err = socket.SetKeepAlivePeriod(fileDescriptor, int(w.tcpKeepAlive.Seconds())) 77 if err != nil { 78 return fmt.Errorf("fd: %d, setting tcpKeepAlive error: %w", fileDescriptor, err) 79 } 80 } 81 82 conn.setUserSpace() 83 w.eventHandler.OnAccept(conn) 84 85 return w.addNextRequest(conn) 86 } 87 88 func (w *shardWorker) stopAccept() error { 89 return w.syscallCloseSocket(w.fd) 90 } 91 92 func (w *shardWorker) activeConnections() int { 93 return w.connectionManager.activeConnections() - 1 94 } 95 96 func (w *shardWorker) handleConn(conn *connection, cqe *giouring.CompletionQueueEvent) { 97 var ( 98 err error 99 errMsg string 100 ) 101 102 switch conn.state { 103 case connAccept: 104 err = w.onAccept(cqe) 105 if err != nil { 106 errMsg = "accept error" 107 } 108 109 case connRead: 110 err = w.onRead(cqe, conn) 111 if err != nil { 112 errMsg = "read error" 113 } 114 115 case connWrite: 116 if !w.connectionProtocol && conn.key == 1 { 117 errMsg = "main socket cannot be in write mode" 118 119 break 120 } 121 122 n := int(cqe.Res) 123 conn.onKernelWrite(n) 124 w.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed") 125 126 conn.setUserSpace() 127 w.eventHandler.OnWrite(conn, n) 128 129 if w.sendRecvMsg { 130 w.connectionManager.release(conn.key) 131 132 break 133 } 134 135 err = w.addNextRequest(conn) 136 if err != nil { 137 errMsg = "add request error" 138 } 139 140 case connClose: 141 if cqe.UserData&closeConnFlag > 0 { 142 w.closeConn(conn, false, nil) 143 } else if cqe.UserData&writeDataFlag > 0 { 144 n := int(cqe.Res) 145 conn.onKernelWrite(n) 146 w.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed") 147 conn.setUserSpace() 148 w.eventHandler.OnWrite(conn, n) 149 } 150 151 default: 152 err = gainErrors.ErrorUnknownConnectionState(int(conn.state)) 153 } 154 155 if err != nil { 156 w.logError(err).Msg(errMsg) 157 158 w.closeConn(conn, true, err) 159 } 160 } 161 162 func (w *shardWorker) initLoop() { 163 if w.connectionProtocol { 164 w.prepareHandler = func() error { 165 w.startedChan <- done 166 167 err := w.addAcceptConnRequest() 168 if err == nil { 169 w.accepting.Store(true) 170 } 171 172 return err 173 } 174 } else { 175 w.prepareHandler = func() error { 176 w.startedChan <- done 177 // 1 is always index for main socket 178 conn := w.connectionManager.get(1, w.fd) 179 conn.fd = w.fd 180 conn.initMsgHeader() 181 182 return w.addReadRequest(conn) 183 } 184 } 185 w.shutdownHandler = func() bool { 186 if w.needToShutdown() { 187 w.onCloseHandler() 188 w.markShutdownInProgress() 189 } 190 191 return true 192 } 193 w.loopFinisher = w.handleAsyncWritesIfEnabled 194 w.loopFinishCondition = func() bool { 195 if w.connectionManager.allClosed() || (w.connectionProtocol && !w.accepting.Load() && w.activeConnections() == 0) { 196 w.close() 197 w.notifyFinish() 198 199 return true 200 } 201 202 return false 203 } 204 } 205 206 func (w *shardWorker) loop(fd int) error { 207 w.logInfo().Int("fd", fd).Msg("Starting worker loop...") 208 w.fd = fd 209 w.initLoop() 210 211 loopErr := w.startLoop(w.index(), func(cqe *giouring.CompletionQueueEvent) error { 212 if exit := w.processEvent(cqe, func(cqe *giouring.CompletionQueueEvent) bool { 213 keyOrFd := cqe.UserData & ^allFlagsMask 214 if acceptReqFailedAfterStop := keyOrFd == uint64(w.fd) && 215 !w.accepting.Load(); acceptReqFailedAfterStop { 216 return true 217 } 218 219 return false 220 }); exit { 221 return nil 222 } 223 key := int(cqe.UserData & ^allFlagsMask) 224 var connFd int 225 if w.connectionProtocol { 226 connFd = key 227 } else { 228 connFd = w.fd 229 } 230 conn := w.connectionManager.get(key, connFd) 231 w.handleConn(conn, cqe) 232 233 return nil 234 }) 235 _ = w.stopAccept() 236 237 return loopErr 238 } 239 240 func (w *shardWorker) closeAllConnsAndRings() { 241 w.logWarn().Msg("Closing connections") 242 w.accepting.Store(false) 243 _ = w.syscallCloseSocket(w.fd) 244 w.connectionManager.close(func(conn *connection) bool { 245 err := w.addCloseConnRequest(conn) 246 if err != nil { 247 w.logError(err).Msg("Add close() connection request error") 248 } 249 250 return err == nil 251 }, w.fd) 252 } 253 254 func newShardWorker( 255 index int, localAddr net.Addr, config shardWorkerConfig, eventHandler EventHandler, 256 ) (*shardWorker, error) { 257 ring, err := giouring.CreateRing(uint32(config.maxSQEntries)) 258 if err != nil { 259 return nil, fmt.Errorf("creating ring error: %w", err) 260 } 261 logger := logger.NewLogger("worker", config.loggerLevel, config.prettyLogger) 262 connectionManager := newConnectionManager() 263 connectionProtocol := !config.sendRecvMsg 264 worker := &shardWorker{ 265 readWriteWorkerImpl: newReadWriteWorkerImpl( 266 ring, index, localAddr, eventHandler, connectionManager, config.readWriteWorkerConfig, logger, 267 ), 268 ring: ring, 269 connectionManager: connectionManager, 270 cpuAffinity: config.cpuAffinity, 271 tcpKeepAlive: config.tcpKeepAlive, 272 connectionProtocol: connectionProtocol, 273 acceptor: newAcceptor(ring, connectionManager), 274 } 275 worker.onCloseHandler = worker.closeAllConnsAndRings 276 277 return worker, nil 278 }