github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/consumer_worker.go (about) 1 package gain 2 3 // Copyright (c) 2023 Paweł Gaczyński 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 import ( 18 "fmt" 19 "net" 20 "sync" 21 "syscall" 22 23 "github.com/pawelgaczynski/gain/logger" 24 gainErrors "github.com/pawelgaczynski/gain/pkg/errors" 25 "github.com/pawelgaczynski/gain/pkg/queue" 26 "github.com/pawelgaczynski/giouring" 27 ) 28 29 type consumerConfig struct { 30 readWriteWorkerConfig 31 } 32 33 type consumer interface { 34 readWriteWorker 35 addConnToQueue(fd int) error 36 setSocketAddr(fd int, addr net.Addr) 37 } 38 39 type consumerWorker struct { 40 *readWriteWorkerImpl 41 config consumerConfig 42 43 socketAddresses sync.Map 44 // used for kernels < 5.18 where OP_MSG_RING is not supported 45 connQueue queue.LockFreeQueue[int] 46 } 47 48 func (c *consumerWorker) setSocketAddr(fd int, addr net.Addr) { 49 c.socketAddresses.Store(fd, addr) 50 } 51 52 func (c *consumerWorker) addConnToQueue(fd int) error { 53 if c.connQueue == nil { 54 return gainErrors.ErrConnectionQueueIsNil 55 } 56 57 c.connQueue.Enqueue(fd) 58 59 return nil 60 } 61 62 func (c *consumerWorker) closeAllConns() { 63 c.logWarn().Msg("Closing connections") 64 c.connectionManager.close(func(conn *connection) bool { 65 err := c.addCloseConnRequest(conn) 66 if err != nil { 67 c.logError(err).Msg("Add close() connection request error") 68 } 69 70 return err == nil 71 }, -1) 72 } 73 74 func (c *consumerWorker) activeConnections() int { 75 return c.connectionManager.activeConnections() 76 } 77 78 func (c *consumerWorker) handleConn(conn *connection, cqe *giouring.CompletionQueueEvent) { 79 var ( 80 err error 81 errMsg string 82 ) 83 84 switch conn.state { 85 case connRead: 86 err = c.onRead(cqe, conn) 87 if err != nil { 88 errMsg = "read error" 89 } 90 91 case connWrite: 92 n := int(cqe.Res) 93 conn.onKernelWrite(n) 94 c.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed") 95 96 conn.setUserSpace() 97 c.eventHandler.OnWrite(conn, n) 98 99 err = c.addNextRequest(conn) 100 if err != nil { 101 errMsg = "add request error" 102 } 103 104 case connClose: 105 if cqe.UserData&closeConnFlag > 0 { 106 c.closeConn(conn, false, nil) 107 } else if cqe.UserData&writeDataFlag > 0 { 108 n := int(cqe.Res) 109 conn.onKernelWrite(n) 110 c.logDebug().Int("fd", conn.fd).Int32("count", cqe.Res).Msg("Bytes writed") 111 conn.setUserSpace() 112 c.eventHandler.OnWrite(conn, n) 113 } 114 115 default: 116 err = gainErrors.ErrorUnknownConnectionState(int(conn.state)) 117 } 118 119 if err != nil { 120 c.logError(err).Msg(errMsg) 121 c.closeConn(conn, true, err) 122 } 123 } 124 125 func (c *consumerWorker) handleNewConn(fd int) error { 126 conn := c.connectionManager.getFd(fd) 127 conn.fd = fd 128 conn.localAddr = c.localAddr 129 130 if remoteAddr, ok := c.socketAddresses.Load(fd); ok { 131 conn.remoteAddr, _ = remoteAddr.(net.Addr) 132 133 c.socketAddresses.Delete(fd) 134 } else { 135 c.logError(gainErrors.ErrorAddressNotFound(fd)).Msg("Get connection address error") 136 } 137 138 conn.setUserSpace() 139 c.eventHandler.OnAccept(conn) 140 141 return c.addNextRequest(conn) 142 } 143 144 func (c *consumerWorker) getConnsFromQueue() { 145 for { 146 if c.connQueue.IsEmpty() { 147 break 148 } 149 fd := c.connQueue.Dequeue() 150 151 err := c.handleNewConn(fd) 152 if err != nil { 153 c.logError(err).Msg("add request error") 154 } 155 } 156 } 157 158 func (c *consumerWorker) handleJobsInQueues() { 159 if c.connQueue != nil { 160 c.getConnsFromQueue() 161 } 162 163 c.handleAsyncWritesIfEnabled() 164 } 165 166 func (c *consumerWorker) loop(_ int) error { 167 c.logInfo().Msg("Starting consumer loop...") 168 c.prepareHandler = func() error { 169 c.startedChan <- done 170 171 return nil 172 } 173 c.shutdownHandler = func() bool { 174 if c.needToShutdown() { 175 c.onCloseHandler() 176 c.markShutdownInProgress() 177 } 178 179 return true 180 } 181 c.loopFinisher = c.handleJobsInQueues 182 c.loopFinishCondition = func() bool { 183 if c.connectionManager.allClosed() { 184 c.close() 185 c.notifyFinish() 186 187 return true 188 } 189 190 return false 191 } 192 193 return c.looper.startLoop(c.index(), func(cqe *giouring.CompletionQueueEvent) error { 194 if exit := c.processEvent(cqe, func(cqe *giouring.CompletionQueueEvent) bool { 195 keyOrFd := cqe.UserData & ^allFlagsMask 196 197 return c.connectionManager.get(int(keyOrFd), 0) == nil 198 }); exit { 199 return nil 200 } 201 if cqe.UserData&addConnFlag > 0 { 202 fd := int(cqe.Res) 203 204 return c.handleNewConn(fd) 205 } 206 fileDescriptor := int(cqe.UserData & ^allFlagsMask) 207 if fileDescriptor < syscall.Stderr { 208 c.logError(nil).Int("fd", fileDescriptor).Msg("Invalid file descriptor") 209 210 return nil 211 } 212 conn := c.connectionManager.getFd(fileDescriptor) 213 c.handleConn(conn, cqe) 214 215 return nil 216 }) 217 } 218 219 func newConsumerWorker( 220 index int, localAddr net.Addr, config consumerConfig, eventHandler EventHandler, features supportedFeatures, 221 ) (*consumerWorker, error) { 222 ring, err := giouring.CreateRing(uint32(config.maxSQEntries)) 223 if err != nil { 224 return nil, fmt.Errorf("creating ring error: %w", err) 225 } 226 logger := logger.NewLogger("consumer", config.loggerLevel, config.prettyLogger) 227 connectionManager := newConnectionManager() 228 consumer := &consumerWorker{ 229 config: config, 230 readWriteWorkerImpl: newReadWriteWorkerImpl( 231 ring, index, localAddr, eventHandler, connectionManager, config.readWriteWorkerConfig, logger, 232 ), 233 } 234 235 if !features.ringsMessaging { 236 consumer.connQueue = queue.NewIntQueue() 237 } 238 consumer.onCloseHandler = consumer.closeAllConns 239 240 return consumer, nil 241 }