github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/acceptor_worker.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"fmt"
    19  	"net"
    20  	"sync/atomic"
    21  	"time"
    22  
    23  	"github.com/pawelgaczynski/gain/logger"
    24  	"github.com/pawelgaczynski/gain/pkg/errors"
    25  	"github.com/pawelgaczynski/gain/pkg/socket"
    26  	"github.com/pawelgaczynski/giouring"
    27  )
    28  
    29  type acceptorWorkerConfig struct {
    30  	workerConfig
    31  	tcpKeepAlive time.Duration
    32  }
    33  
    34  type acceptorWorker struct {
    35  	*acceptor
    36  	*workerImpl
    37  	config        acceptorWorkerConfig
    38  	ring          *giouring.Ring
    39  	loadBalancer  loadBalancer
    40  	eventHandler  EventHandler
    41  	addConnMethod func(consumer, int32) error
    42  
    43  	accepting atomic.Bool
    44  }
    45  
    46  func (a *acceptorWorker) addConnViaRing(worker consumer, fileDescriptor int32) error {
    47  	entry := a.ring.GetSQE()
    48  	if entry == nil {
    49  		return errors.ErrGettingSQE
    50  	}
    51  
    52  	entry.PrepareMsgRing(worker.ringFd(), uint32(fileDescriptor), addConnFlag, 0)
    53  	entry.UserData = addConnFlag
    54  
    55  	return nil
    56  }
    57  
    58  func (a *acceptorWorker) addConnViaQueue(worker consumer, fileDescriptor int32) error {
    59  	err := worker.addConnToQueue(int(fileDescriptor))
    60  	if err != nil {
    61  		return fmt.Errorf("error adding connection to queue: %w", err)
    62  	}
    63  
    64  	return nil
    65  }
    66  
    67  func (a *acceptorWorker) registerConsumer(consumer consumer) {
    68  	a.loadBalancer.register(consumer)
    69  }
    70  
    71  func (a *acceptorWorker) closeRingAndConsumers() {
    72  	_ = a.syscallCloseSocket(a.fd)
    73  	a.accepting.Store(false)
    74  
    75  	err := a.loadBalancer.forEach(func(w consumer) error {
    76  		w.shutdown()
    77  
    78  		return nil
    79  	})
    80  	if err != nil {
    81  		a.logError(err).Msg("Closing consumers error")
    82  	}
    83  }
    84  
    85  func (a *acceptorWorker) loop(fd int) error {
    86  	a.logInfo().Int("fd", fd).Msg("Starting acceptor loop...")
    87  	a.fd = fd
    88  	a.prepareHandler = func() error {
    89  		err := a.addAcceptRequest()
    90  		if err == nil {
    91  			a.accepting.Store(true)
    92  		}
    93  
    94  		return err
    95  	}
    96  	a.shutdownHandler = func() bool {
    97  		if a.needToShutdown() {
    98  			a.onCloseHandler()
    99  			a.markShutdownInProgress()
   100  
   101  			return false
   102  		}
   103  
   104  		return true
   105  	}
   106  	err := a.startLoop(0, func(cqe *giouring.CompletionQueueEvent) error {
   107  		if cqe.UserData&addConnFlag > 0 {
   108  			return nil
   109  		}
   110  		err := a.addAcceptRequest()
   111  		if err != nil {
   112  			a.logError(err).
   113  				Int("fd", fd).
   114  				Msg("Add accept() request error")
   115  
   116  			return err
   117  		}
   118  
   119  		if exit := a.processEvent(cqe, func(cqe *giouring.CompletionQueueEvent) bool {
   120  			return !a.accepting.Load()
   121  		}); exit {
   122  			return nil
   123  		}
   124  
   125  		if a.config.tcpKeepAlive > 0 {
   126  			err = socket.SetKeepAlivePeriod(int(cqe.Res), int(a.config.tcpKeepAlive.Seconds()))
   127  			if err != nil {
   128  				return fmt.Errorf("fd: %d, setting tcpKeepAlive error: %w", cqe.Res, err)
   129  			}
   130  		}
   131  		var clientAddr net.Addr
   132  		clientAddr, err = a.lastClientAddr()
   133  
   134  		if err != nil {
   135  			a.logError(err).
   136  				Int("fd", fd).
   137  				Int32("conn fd", cqe.Res).
   138  				Msg("Getting client address error")
   139  			_ = a.syscallCloseSocket(int(cqe.Res))
   140  
   141  			return nil
   142  		}
   143  
   144  		nextConsumer := a.loadBalancer.next(clientAddr)
   145  		a.logDebug().
   146  			Int("fd", fd).
   147  			Int32("conn fd", cqe.Res).
   148  			Int("consumer", nextConsumer.index()).
   149  			Msg("Forwarding accepted connection to consumer")
   150  
   151  		nextConsumer.setSocketAddr(int(cqe.Res), clientAddr)
   152  
   153  		err = a.addConnMethod(nextConsumer, cqe.Res)
   154  		if err != nil {
   155  			a.logError(err).
   156  				Int("fd", fd).
   157  				Int32("conn fd", cqe.Res).
   158  				Msg("Add connection to consumer error")
   159  
   160  			_ = a.syscallCloseSocket(int(cqe.Res))
   161  
   162  			return nil
   163  		}
   164  
   165  		return nil
   166  	})
   167  
   168  	a.close()
   169  	a.notifyFinish()
   170  
   171  	return err
   172  }
   173  
   174  func newAcceptorWorker(
   175  	config acceptorWorkerConfig, loadBalancer loadBalancer, eventHandler EventHandler, features supportedFeatures,
   176  ) (*acceptorWorker, error) {
   177  	ring, err := giouring.CreateRing(uint32(config.maxSQEntries))
   178  	if err != nil {
   179  		return nil, fmt.Errorf("creating ring error: %w", err)
   180  	}
   181  	logger := logger.NewLogger("acceptor", config.loggerLevel, config.prettyLogger)
   182  	connectionManager := newConnectionManager()
   183  	acceptor := &acceptorWorker{
   184  		workerImpl:   newWorkerImpl(ring, config.workerConfig, 0, logger),
   185  		acceptor:     newAcceptor(ring, connectionManager),
   186  		config:       config,
   187  		ring:         ring,
   188  		eventHandler: eventHandler,
   189  		loadBalancer: loadBalancer,
   190  	}
   191  
   192  	if features.ringsMessaging {
   193  		acceptor.addConnMethod = acceptor.addConnViaRing
   194  	} else {
   195  		acceptor.addConnMethod = acceptor.addConnViaQueue
   196  	}
   197  	acceptor.onCloseHandler = acceptor.closeRingAndConsumers
   198  
   199  	return acceptor, nil
   200  }