github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/acceptor_worker.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gain 16 17 import ( 18 "fmt" 19 "net" 20 "sync/atomic" 21 "time" 22 23 "github.com/pawelgaczynski/gain/logger" 24 "github.com/pawelgaczynski/gain/pkg/errors" 25 "github.com/pawelgaczynski/gain/pkg/socket" 26 "github.com/pawelgaczynski/giouring" 27 ) 28 29 type acceptorWorkerConfig struct { 30 workerConfig 31 tcpKeepAlive time.Duration 32 } 33 34 type acceptorWorker struct { 35 *acceptor 36 *workerImpl 37 config acceptorWorkerConfig 38 ring *giouring.Ring 39 loadBalancer loadBalancer 40 eventHandler EventHandler 41 addConnMethod func(consumer, int32) error 42 43 accepting atomic.Bool 44 } 45 46 func (a *acceptorWorker) addConnViaRing(worker consumer, fileDescriptor int32) error { 47 entry := a.ring.GetSQE() 48 if entry == nil { 49 return errors.ErrGettingSQE 50 } 51 52 entry.PrepareMsgRing(worker.ringFd(), uint32(fileDescriptor), addConnFlag, 0) 53 entry.UserData = addConnFlag 54 55 return nil 56 } 57 58 func (a *acceptorWorker) addConnViaQueue(worker consumer, fileDescriptor int32) error { 59 err := worker.addConnToQueue(int(fileDescriptor)) 60 if err != nil { 61 return fmt.Errorf("error adding connection to queue: %w", err) 62 } 63 64 return nil 65 } 66 67 func (a *acceptorWorker) registerConsumer(consumer consumer) { 68 a.loadBalancer.register(consumer) 69 } 70 71 func (a *acceptorWorker) closeRingAndConsumers() { 72 _ = a.syscallCloseSocket(a.fd) 73 a.accepting.Store(false) 74 75 err := a.loadBalancer.forEach(func(w consumer) error { 76 w.shutdown() 77 78 return nil 79 }) 80 if err != nil { 81 a.logError(err).Msg("Closing consumers error") 82 } 83 } 84 85 func (a *acceptorWorker) loop(fd int) error { 86 a.logInfo().Int("fd", fd).Msg("Starting acceptor loop...") 87 a.fd = fd 88 a.prepareHandler = func() error { 89 err := a.addAcceptRequest() 90 if err == nil { 91 a.accepting.Store(true) 92 } 93 94 return err 95 } 96 a.shutdownHandler = func() bool { 97 if a.needToShutdown() { 98 a.onCloseHandler() 99 a.markShutdownInProgress() 100 101 return false 102 } 103 104 return true 105 } 106 err := a.startLoop(0, func(cqe *giouring.CompletionQueueEvent) error { 107 if cqe.UserData&addConnFlag > 0 { 108 return nil 109 } 110 err := a.addAcceptRequest() 111 if err != nil { 112 a.logError(err). 113 Int("fd", fd). 114 Msg("Add accept() request error") 115 116 return err 117 } 118 119 if exit := a.processEvent(cqe, func(cqe *giouring.CompletionQueueEvent) bool { 120 return !a.accepting.Load() 121 }); exit { 122 return nil 123 } 124 125 if a.config.tcpKeepAlive > 0 { 126 err = socket.SetKeepAlivePeriod(int(cqe.Res), int(a.config.tcpKeepAlive.Seconds())) 127 if err != nil { 128 return fmt.Errorf("fd: %d, setting tcpKeepAlive error: %w", cqe.Res, err) 129 } 130 } 131 var clientAddr net.Addr 132 clientAddr, err = a.lastClientAddr() 133 134 if err != nil { 135 a.logError(err). 136 Int("fd", fd). 137 Int32("conn fd", cqe.Res). 138 Msg("Getting client address error") 139 _ = a.syscallCloseSocket(int(cqe.Res)) 140 141 return nil 142 } 143 144 nextConsumer := a.loadBalancer.next(clientAddr) 145 a.logDebug(). 146 Int("fd", fd). 147 Int32("conn fd", cqe.Res). 148 Int("consumer", nextConsumer.index()). 149 Msg("Forwarding accepted connection to consumer") 150 151 nextConsumer.setSocketAddr(int(cqe.Res), clientAddr) 152 153 err = a.addConnMethod(nextConsumer, cqe.Res) 154 if err != nil { 155 a.logError(err). 156 Int("fd", fd). 157 Int32("conn fd", cqe.Res). 158 Msg("Add connection to consumer error") 159 160 _ = a.syscallCloseSocket(int(cqe.Res)) 161 162 return nil 163 } 164 165 return nil 166 }) 167 168 a.close() 169 a.notifyFinish() 170 171 return err 172 } 173 174 func newAcceptorWorker( 175 config acceptorWorkerConfig, loadBalancer loadBalancer, eventHandler EventHandler, features supportedFeatures, 176 ) (*acceptorWorker, error) { 177 ring, err := giouring.CreateRing(uint32(config.maxSQEntries)) 178 if err != nil { 179 return nil, fmt.Errorf("creating ring error: %w", err) 180 } 181 logger := logger.NewLogger("acceptor", config.loggerLevel, config.prettyLogger) 182 connectionManager := newConnectionManager() 183 acceptor := &acceptorWorker{ 184 workerImpl: newWorkerImpl(ring, config.workerConfig, 0, logger), 185 acceptor: newAcceptor(ring, connectionManager), 186 config: config, 187 ring: ring, 188 eventHandler: eventHandler, 189 loadBalancer: loadBalancer, 190 } 191 192 if features.ringsMessaging { 193 acceptor.addConnMethod = acceptor.addConnViaRing 194 } else { 195 acceptor.addConnMethod = acceptor.addConnViaQueue 196 } 197 acceptor.onCloseHandler = acceptor.closeRingAndConsumers 198 199 return acceptor, nil 200 }