github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/server.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/tls"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"math"
    25  	"net"
    26  	"sync"
    27  	"sync/atomic"
    28  	"time"
    29  
    30  	"github.com/datastax/go-cassandra-native-protocol/message"
    31  	"github.com/datastax/go-cassandra-native-protocol/segment"
    32  
    33  	"github.com/rs/zerolog/log"
    34  
    35  	"github.com/datastax/go-cassandra-native-protocol/frame"
    36  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    37  )
    38  
    39  const (
    40  	DefaultAcceptTimeout = time.Second * 60
    41  	DefaultIdleTimeout   = time.Hour
    42  )
    43  
    44  const DefaultMaxConnections = 128
    45  
    46  const (
    47  	ServerStateNotStarted = int32(iota)
    48  	ServerStateRunning    = int32(iota)
    49  	ServerStateClosed     = int32(iota)
    50  )
    51  
    52  // RequestHandlerContext is the RequestHandler invocation context. Each invocation of a given RequestHandler will be
    53  // passed one instance of a RequestHandlerContext, that remains the same between invocations. This allows
    54  // handlers to become stateful if required.
    55  type RequestHandlerContext interface {
    56  	// PutAttribute puts the given value in this context under the given key name.
    57  	// Will override any previously-stored value under that key.
    58  	PutAttribute(name string, value interface{})
    59  	// GetAttribute retrieves the value stored in this context under the given key name.
    60  	// Returns nil if nil is stored, or if the key does not exist.
    61  	GetAttribute(name string) interface{}
    62  }
    63  
    64  type requestHandlerContext map[string]interface{}
    65  
    66  func (ctx requestHandlerContext) PutAttribute(name string, value interface{}) {
    67  	ctx[name] = value
    68  }
    69  
    70  func (ctx requestHandlerContext) GetAttribute(name string) interface{} {
    71  	return ctx[name]
    72  }
    73  
    74  // RequestHandler is a callback function that gets invoked whenever a CqlServerConnection receives an incoming
    75  // frame. The handler function should inspect the request frame and determine if it can handle the response for it.
    76  // If so, it should return a non-nil response frame. When that happens, no further handlers will be tried for the
    77  // incoming request.
    78  // If a handler returns nil, it is assumed that it was not able to handle the request, in which case another handler,
    79  // if any, may be tried.
    80  type RequestHandler func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (response *frame.Frame)
    81  
    82  // RawRequestHandler is similar to RequestHandler but returns an already encoded response in byte slice format, this can be used to return responses that the
    83  // embedded codecs can't encode
    84  type RawRequestHandler func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (encodedResponse []byte)
    85  
    86  // CqlServer is a minimalistic server stub that can be used to mimic CQL-compatible backends. It is preferable to
    87  // create CqlServer instances using the constructor function NewCqlServer. Once the server is properly created and
    88  // configured, use Start to start the server, then call Accept or AcceptAny to accept incoming client connections.
    89  type CqlServer struct {
    90  	// ListenAddress is the address to listen to.
    91  	ListenAddress string
    92  	// Credentials is the AuthCredentials to use. If nil, no authentication will be used; otherwise, clients will be
    93  	// required to authenticate with plain-text auth using the same credentials.
    94  	Credentials *AuthCredentials
    95  	// MaxConnections is the maximum number of open client connections to accept. Must be strictly positive.
    96  	MaxConnections int
    97  	// MaxInFlight is the maximum number of in-flight requests to apply for each connection created with Accept. Must
    98  	// be strictly positive.
    99  	MaxInFlight int
   100  	// AcceptTimeout is the timeout to apply when accepting new connections.
   101  	AcceptTimeout time.Duration
   102  	// IdleTimeout is the timeout to apply for closing idle connections.
   103  	IdleTimeout time.Duration
   104  	// RequestHandlers is an optional list of handlers to handle incoming requests.
   105  	RequestHandlers []RequestHandler
   106  	// RequestRawHandlers is an optional list of handlers to handle incoming requests and return a response in a byte slice format.
   107  	RequestRawHandlers []RawRequestHandler
   108  	// TLSConfig is the TLS configuration to use.
   109  	TLSConfig *tls.Config
   110  
   111  	ctx                context.Context
   112  	cancel             context.CancelFunc
   113  	listener           net.Listener
   114  	connectionsHandler *clientConnectionHandler
   115  	waitGroup          *sync.WaitGroup
   116  	state              int32
   117  }
   118  
   119  // NewCqlServer creates a new CqlServer with default options. Leave credentials nil to opt out from authentication.
   120  func NewCqlServer(listenAddress string, credentials *AuthCredentials) *CqlServer {
   121  	return &CqlServer{
   122  		ListenAddress:  listenAddress,
   123  		Credentials:    credentials,
   124  		MaxConnections: DefaultMaxConnections,
   125  		MaxInFlight:    DefaultMaxInFlight,
   126  		AcceptTimeout:  DefaultAcceptTimeout,
   127  		IdleTimeout:    DefaultIdleTimeout,
   128  	}
   129  }
   130  
   131  func (server *CqlServer) String() string {
   132  	return fmt.Sprintf("CQL server [%v]", server.ListenAddress)
   133  }
   134  
   135  func (server *CqlServer) getState() int32 {
   136  	return atomic.LoadInt32(&server.state)
   137  }
   138  
   139  func (server *CqlServer) IsNotStarted() bool {
   140  	return server.getState() == ServerStateNotStarted
   141  }
   142  
   143  func (server *CqlServer) IsRunning() bool {
   144  	return server.getState() == ServerStateRunning
   145  }
   146  
   147  func (server *CqlServer) IsClosed() bool {
   148  	return server.getState() == ServerStateClosed
   149  }
   150  
   151  func (server *CqlServer) transitionState(old int32, new int32) bool {
   152  	return atomic.CompareAndSwapInt32(&server.state, old, new)
   153  }
   154  
   155  // Start starts the server and binds to its listen address. This method must be called before calling Accept.
   156  // Set ctx to context.Background if no parent context exists.
   157  func (server *CqlServer) Start(ctx context.Context) (err error) {
   158  	if ctx == nil {
   159  		return fmt.Errorf("context cannot be nil")
   160  	}
   161  	if server.transitionState(ServerStateNotStarted, ServerStateRunning) {
   162  		log.Debug().Msgf("%v: server is starting", server)
   163  		server.connectionsHandler, err = newClientConnectionHandler(server.String(), server.MaxConnections)
   164  		if err != nil {
   165  			return fmt.Errorf("%v: start failed: %w", server, err)
   166  		}
   167  		if server.TLSConfig != nil {
   168  			server.listener, err = tls.Listen("tcp", server.ListenAddress, server.TLSConfig)
   169  		} else {
   170  			server.listener, err = net.Listen("tcp", server.ListenAddress)
   171  		}
   172  		if err != nil {
   173  			return fmt.Errorf("%v: start failed: %w", server, err)
   174  		}
   175  		server.ctx, server.cancel = context.WithCancel(ctx)
   176  		server.waitGroup = &sync.WaitGroup{}
   177  		server.acceptLoop()
   178  		server.awaitDone()
   179  		log.Info().Msgf("%v: successfully started", server)
   180  	} else {
   181  		log.Debug().Msgf("%v: already started or closed", server)
   182  	}
   183  	return err
   184  }
   185  
   186  func (server *CqlServer) Close() (err error) {
   187  	if server.transitionState(ServerStateRunning, ServerStateClosed) {
   188  		log.Debug().Msgf("%v: closing", server)
   189  		err = server.listener.Close()
   190  		server.connectionsHandler.close()
   191  		server.cancel()
   192  		server.waitGroup.Wait()
   193  		if err != nil {
   194  			log.Debug().Err(err).Msgf("%v: could not close server", server)
   195  			err = fmt.Errorf("%v: could not close server: %w", server, err)
   196  		} else {
   197  			log.Info().Msgf("%v: successfully closed", server)
   198  		}
   199  	} else {
   200  		log.Debug().Msgf("%v: not started or already closed", server)
   201  	}
   202  	return err
   203  }
   204  
   205  func (server *CqlServer) abort() {
   206  	log.Debug().Msgf("%v: forcefully closing", server)
   207  	if err := server.Close(); err != nil {
   208  		log.Error().Err(err).Msgf("%v: error closing", server)
   209  	}
   210  }
   211  
   212  func (server *CqlServer) acceptLoop() {
   213  	server.waitGroup.Add(1)
   214  	go func() {
   215  		abort := false
   216  		for server.IsRunning() {
   217  			if conn, err := server.listener.Accept(); err != nil {
   218  				if !server.IsClosed() {
   219  					log.Error().Err(err).Msgf("%v: error accepting client connections, closing server", server)
   220  					abort = true
   221  				}
   222  				break
   223  			} else {
   224  				log.Debug().Msgf("%v: new TCP connection accepted", server)
   225  				if connection, err := newCqlServerConnection(
   226  					conn,
   227  					server.ctx,
   228  					server.Credentials,
   229  					server.MaxInFlight,
   230  					server.IdleTimeout,
   231  					server.RequestHandlers,
   232  					server.RequestRawHandlers,
   233  					server.connectionsHandler.onConnectionClosed,
   234  				); err != nil {
   235  					log.Error().Msgf("%v: failed to accept incoming CQL client connection: %v", server, connection)
   236  					_ = conn.Close()
   237  				} else if err := server.connectionsHandler.onConnectionAccepted(connection); err != nil {
   238  					log.Error().Msgf("%v: handler rejected incoming CQL client connection: %v", server, connection)
   239  					_ = conn.Close()
   240  				} else {
   241  					log.Info().Msgf("%v: accepted new incoming CQL client connection: %v", server, connection)
   242  				}
   243  			}
   244  		}
   245  		server.waitGroup.Done()
   246  		if abort {
   247  			server.abort()
   248  		}
   249  	}()
   250  }
   251  
   252  func (server *CqlServer) awaitDone() {
   253  	server.waitGroup.Add(1)
   254  	go func() {
   255  		<-server.ctx.Done()
   256  		log.Debug().Err(server.ctx.Err()).Msgf("%v: context was closed", server)
   257  		server.waitGroup.Done()
   258  		server.abort()
   259  	}()
   260  }
   261  
   262  // Accept waits until the given client address is accepted, the configured timeout is triggered, or the server is
   263  // closed, whichever happens first.
   264  func (server *CqlServer) Accept(client *CqlClientConnection) (*CqlServerConnection, error) {
   265  	if server.IsClosed() {
   266  		return nil, fmt.Errorf("%v: server closed", server)
   267  	}
   268  	log.Debug().Msgf("%v: waiting for incoming client connection to be accepted: %v", server, client)
   269  	if serverConnectionChannel, err := server.connectionsHandler.onConnectionAcceptRequested(client); err != nil {
   270  		return nil, err
   271  	} else {
   272  		select {
   273  		case serverConnection, ok := <-serverConnectionChannel:
   274  			if !ok {
   275  				return nil, fmt.Errorf("%v: incoming client connection channel closed unexpectedly", server)
   276  			}
   277  			log.Debug().Msgf("%v: returning accepted client connection: %v", server, serverConnection)
   278  			return serverConnection, nil
   279  		case <-time.After(server.AcceptTimeout):
   280  			return nil, fmt.Errorf("%v: timed out waiting for incoming client connection", server)
   281  		}
   282  	}
   283  }
   284  
   285  // AcceptAny waits until any client is accepted, the configured timeout is triggered, or the server is closed,
   286  // whichever happens first. This method is useful when the client is not known in advance.
   287  func (server *CqlServer) AcceptAny() (*CqlServerConnection, error) {
   288  	if server.IsClosed() {
   289  		return nil, fmt.Errorf("%v: server closed", server)
   290  	}
   291  	log.Debug().Msgf("%v: waiting for any incoming client connection to be accepted", server)
   292  	anyConn := server.connectionsHandler.anyConnectionChannel()
   293  	select {
   294  	case serverConnection, ok := <-anyConn:
   295  		if !ok {
   296  			return nil, fmt.Errorf("%v: incoming client connection channel closed unexpectedly", server)
   297  		}
   298  		log.Debug().Msgf("%v: returning accepted client connection: %v", server, serverConnection)
   299  		return serverConnection, nil
   300  	case <-time.After(server.AcceptTimeout):
   301  		return nil, fmt.Errorf("%v: timed out waiting for incoming client connection", server)
   302  	}
   303  }
   304  
   305  // AllAcceptedClients returns a list of all the currently active server connections.
   306  func (server *CqlServer) AllAcceptedClients() ([]*CqlServerConnection, error) {
   307  	if server.IsClosed() {
   308  		return nil, fmt.Errorf("%v: server closed", server)
   309  	}
   310  	return server.connectionsHandler.allAcceptedClients(), nil
   311  }
   312  
   313  // Bind is a convenience method to connect a CqlClient to this CqlServer. The returned connections will be open, but not
   314  // initialized (i.e., no handshake performed). The server must be started prior to calling this method.
   315  func (server *CqlServer) Bind(client *CqlClient, ctx context.Context) (*CqlClientConnection, *CqlServerConnection, error) {
   316  	if server.IsNotStarted() {
   317  		return nil, nil, fmt.Errorf("%v: server not started", server)
   318  	} else if server.IsClosed() {
   319  		return nil, nil, fmt.Errorf("%v: server closed", server)
   320  	} else if clientConn, err := client.Connect(ctx); err != nil {
   321  		return nil, nil, fmt.Errorf("%v: bind failed, client %v could not connect: %w", server, client, err)
   322  	} else if serverConn, err := server.Accept(clientConn); err != nil {
   323  		return nil, nil, fmt.Errorf("%v: bind failed, client %v wasn't accepted: %w", server, client, err)
   324  	} else {
   325  		log.Debug().Msgf("%v: bind successful: %v", server, serverConn)
   326  		return clientConn, serverConn, nil
   327  	}
   328  }
   329  
   330  // BindAndInit is a convenience method to connect a CqlClient to this CqlServer. The returned connections will be open
   331  // and initialized (i.e., handshake is already performed). The server must be started prior to calling this method.
   332  // Use stream id zero to activate automatic stream id management.
   333  func (server *CqlServer) BindAndInit(
   334  	client *CqlClient,
   335  	ctx context.Context,
   336  	version primitive.ProtocolVersion,
   337  	streamId int16,
   338  ) (*CqlClientConnection, *CqlServerConnection, error) {
   339  	if clientConn, serverConn, err := server.Bind(client, ctx); err != nil {
   340  		return nil, nil, err
   341  	} else {
   342  		return clientConn, serverConn, PerformHandshake(clientConn, serverConn, version, streamId)
   343  	}
   344  }
   345  
   346  type response struct {
   347  	responseFrame *frame.Frame
   348  	rawResponse   []byte
   349  }
   350  
   351  func newFrameResponse(frameResponse *frame.Frame) *response {
   352  	return &response{
   353  		responseFrame: frameResponse,
   354  	}
   355  }
   356  
   357  func newRawResponse(rawResponse []byte) *response {
   358  	return &response{
   359  		rawResponse: rawResponse,
   360  	}
   361  }
   362  
   363  // CqlServerConnection encapsulates a TCP server connection to a remote CQL client.
   364  // CqlServerConnection instances should be created by calling CqlServer.Accept or CqlServer.Bind.
   365  type CqlServerConnection struct {
   366  	conn               net.Conn
   367  	credentials        *AuthCredentials
   368  	frameCodec         frame.Codec
   369  	segmentCodec       segment.Codec
   370  	compression        primitive.Compression
   371  	modernLayout       bool
   372  	idleTimeout        time.Duration
   373  	handlers           []RequestHandler
   374  	rawHandlers        []RawRequestHandler
   375  	handlerCtx         []RequestHandlerContext
   376  	incoming           chan *frame.Frame
   377  	outgoing           chan *response
   378  	waitGroup          *sync.WaitGroup
   379  	closed             int32
   380  	onClose            func(*CqlServerConnection)
   381  	ctx                context.Context
   382  	cancel             context.CancelFunc
   383  	payloadAccumulator *payloadAccumulator
   384  }
   385  
   386  func newCqlServerConnection(
   387  	conn net.Conn,
   388  	ctx context.Context,
   389  	credentials *AuthCredentials,
   390  	maxInFlight int,
   391  	idleTimeout time.Duration,
   392  	handlers []RequestHandler,
   393  	rawHandlers []RawRequestHandler,
   394  	onClose func(*CqlServerConnection),
   395  ) (*CqlServerConnection, error) {
   396  	if conn == nil {
   397  		return nil, fmt.Errorf("TCP connection cannot be nil")
   398  	}
   399  	if maxInFlight < 1 {
   400  		return nil, fmt.Errorf("max in-flight: expecting positive, got: %v", maxInFlight)
   401  	} else if maxInFlight > math.MaxInt16 {
   402  		return nil, fmt.Errorf("max in-flight: expecting <= %v, got: %v", math.MaxInt16, maxInFlight)
   403  	}
   404  	frameCodec := frame.NewCodec()
   405  	segmentCodec := segment.NewCodec()
   406  	connection := &CqlServerConnection{
   407  		conn:         conn,
   408  		frameCodec:   frameCodec,
   409  		segmentCodec: segmentCodec,
   410  		compression:  primitive.CompressionNone,
   411  		credentials:  credentials,
   412  		idleTimeout:  idleTimeout,
   413  		handlers:     handlers,
   414  		rawHandlers:  rawHandlers,
   415  		handlerCtx:   make([]RequestHandlerContext, len(handlers)),
   416  		incoming:     make(chan *frame.Frame, maxInFlight),
   417  		outgoing:     make(chan *response, maxInFlight),
   418  		waitGroup:    &sync.WaitGroup{},
   419  		onClose:      onClose,
   420  	}
   421  	for i := range handlers {
   422  		connection.handlerCtx[i] = requestHandlerContext{}
   423  	}
   424  	connection.ctx, connection.cancel = context.WithCancel(ctx)
   425  	connection.incomingLoop()
   426  	connection.outgoingLoop()
   427  	connection.awaitDone()
   428  	return connection, nil
   429  }
   430  
   431  func (c *CqlServerConnection) String() string {
   432  	return fmt.Sprintf("CQL server conn [L:%v <-> R:%v]", c.conn.LocalAddr(), c.conn.RemoteAddr())
   433  }
   434  
   435  // LocalAddr Returns the connection's local address (that is, the client address).
   436  func (c *CqlServerConnection) LocalAddr() net.Addr {
   437  	return c.conn.LocalAddr()
   438  }
   439  
   440  // RemoteAddr Returns the connection's remote address (that is, the server address).
   441  func (c *CqlServerConnection) RemoteAddr() net.Addr {
   442  	return c.conn.RemoteAddr()
   443  }
   444  
   445  // Credentials Returns a copy of the connection's AuthCredentials, if any, or nil if no authentication was configured.
   446  func (c *CqlServerConnection) Credentials() *AuthCredentials {
   447  	if c.credentials == nil {
   448  		return nil
   449  	}
   450  	return c.credentials.Copy()
   451  }
   452  
   453  func (c *CqlServerConnection) GetConn() net.Conn {
   454  	return c.conn
   455  }
   456  
   457  func (c *CqlServerConnection) incomingLoop() {
   458  	log.Debug().Msgf("%v: listening for incoming frames...", c)
   459  	c.waitGroup.Add(1)
   460  	go func() {
   461  		abort := false
   462  		for !abort && !c.IsClosed() {
   463  			if abort = c.setIdleTimeout(); !abort {
   464  				if source, err := c.waitForIncomingData(); err != nil {
   465  					abort = c.reportConnectionFailure(err, true)
   466  				} else if c.modernLayout {
   467  					abort = c.readSegment(source)
   468  				} else {
   469  					abort = c.readFrame(source)
   470  				}
   471  			}
   472  		}
   473  		c.waitGroup.Done()
   474  		if abort {
   475  			c.abort()
   476  		}
   477  	}()
   478  }
   479  
   480  func (c *CqlServerConnection) outgoingLoop() {
   481  	log.Debug().Msgf("%v: listening for outgoing frames...", c)
   482  	c.waitGroup.Add(1)
   483  	go func() {
   484  		abort := false
   485  		for !c.IsClosed() {
   486  			if outgoing, ok := <-c.outgoing; !ok {
   487  				if !c.IsClosed() {
   488  					log.Error().Msgf("%v: outgoing frame channel was closed unexpectedly, closing connection", c)
   489  					abort = true
   490  				}
   491  				break
   492  			} else {
   493  				if outgoing.rawResponse != nil {
   494  					abort = c.writeRawResponse(outgoing.rawResponse, c.conn)
   495  					log.Debug().Msgf("%v: sending outgoing raw response: %v", c, outgoing.rawResponse)
   496  				} else {
   497  					if c.compression != primitive.CompressionNone {
   498  						outgoing.responseFrame.Header.Flags = outgoing.responseFrame.Header.Flags.Add(primitive.HeaderFlagCompressed)
   499  					}
   500  					log.Debug().Msgf("%v: sending outgoing frame: %v", c, outgoing.responseFrame)
   501  					if c.modernLayout {
   502  						// TODO write coalescer
   503  						abort = c.writeSegment(outgoing.responseFrame, c.conn)
   504  					} else {
   505  						abort = c.writeFrame(outgoing.responseFrame, c.conn)
   506  					}
   507  				}
   508  			}
   509  		}
   510  		c.waitGroup.Done()
   511  		if abort {
   512  			c.abort()
   513  		}
   514  	}()
   515  }
   516  
   517  func (c *CqlServerConnection) waitForIncomingData() (io.Reader, error) {
   518  	buf := make([]byte, 1)
   519  	if _, err := io.ReadFull(c.conn, buf); err != nil {
   520  		return nil, err
   521  	} else {
   522  		return io.MultiReader(bytes.NewReader(buf), c.conn), nil
   523  	}
   524  }
   525  
   526  func (c *CqlServerConnection) setIdleTimeout() (abort bool) {
   527  	if err := c.conn.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil {
   528  		if !c.IsClosed() {
   529  			log.Error().Err(err).Msgf("%v: error setting idle timeout, closing connection", c)
   530  			abort = true
   531  		}
   532  	}
   533  	return abort
   534  }
   535  
   536  func (c *CqlServerConnection) readSegment(source io.Reader) (abort bool) {
   537  	if incoming, err := c.segmentCodec.DecodeSegment(source); err != nil {
   538  		abort = c.reportConnectionFailure(err, true)
   539  	} else if incoming.Header.IsSelfContained {
   540  		log.Debug().Msgf("%v: received incoming self-contained segment: %v", c, incoming)
   541  		abort = c.readSelfContainedSegment(incoming, abort)
   542  	} else {
   543  		log.Debug().Msgf("%v: received incoming multi-segment part: %v", c, incoming)
   544  		abort = c.addMultiSegmentPayload(incoming.Payload)
   545  	}
   546  	return abort
   547  }
   548  
   549  func (c *CqlServerConnection) readSelfContainedSegment(incoming *segment.Segment, abort bool) bool {
   550  	payloadReader := bytes.NewReader(incoming.Payload.UncompressedData)
   551  	for payloadReader.Len() > 0 {
   552  		if abort = c.readFrame(payloadReader); abort {
   553  			break
   554  		}
   555  	}
   556  	return abort
   557  }
   558  
   559  func (c *CqlServerConnection) addMultiSegmentPayload(payload *segment.Payload) (abort bool) {
   560  	accumulator := c.payloadAccumulator
   561  	if accumulator.targetLength == 0 {
   562  		// First reader, read ahead to find the target length
   563  		if header, err := accumulator.frameCodec.DecodeHeader(bytes.NewReader(payload.UncompressedData)); err != nil {
   564  			log.Error().Err(err).Msgf("%v: error decoding first frame header in multi-segment payload, closing connection", c)
   565  			return true
   566  		} else {
   567  			accumulator.targetLength = int(primitive.FrameHeaderLengthV3AndHigher + header.BodyLength)
   568  		}
   569  	}
   570  	accumulator.accumulatedData = append(accumulator.accumulatedData, payload.UncompressedData...)
   571  	if accumulator.targetLength == len(accumulator.accumulatedData) {
   572  		// We've received enough data to reassemble the whole frame
   573  		encodedFrame := bytes.NewReader(accumulator.accumulatedData)
   574  		accumulator.reset()
   575  		return c.readFrame(encodedFrame)
   576  	}
   577  	return false
   578  }
   579  
   580  func (c *CqlServerConnection) writeSegment(outgoing *frame.Frame, dest io.Writer) (abort bool) {
   581  	// never compress frames individually when included in a segment
   582  	outgoing.Header.Flags.Remove(primitive.HeaderFlagCompressed)
   583  	encodedFrame := &bytes.Buffer{}
   584  	if abort = c.writeFrame(outgoing, encodedFrame); abort {
   585  		abort = true
   586  	} else {
   587  		seg := &segment.Segment{
   588  			Header:  &segment.Header{IsSelfContained: true},
   589  			Payload: &segment.Payload{UncompressedData: encodedFrame.Bytes()},
   590  		}
   591  		if err := c.segmentCodec.EncodeSegment(seg, dest); err != nil {
   592  			abort = c.reportConnectionFailure(err, false)
   593  		} else {
   594  			log.Debug().Msgf("%v: outgoing segment successfully written: %v (frame: %v)", c, seg, outgoing)
   595  		}
   596  	}
   597  	return abort
   598  }
   599  
   600  func (c *CqlServerConnection) readFrame(source io.Reader) (abort bool) {
   601  	if incoming, err := c.frameCodec.DecodeFrame(source); err != nil {
   602  		abort = c.reportConnectionFailure(err, true)
   603  	} else {
   604  		if startup, ok := incoming.Body.Message.(*message.Startup); ok {
   605  			c.compression = startup.GetCompression()
   606  			c.frameCodec = frame.NewCodecWithCompression(NewBodyCompressor(c.compression))
   607  			c.segmentCodec = segment.NewCodecWithCompression(NewPayloadCompressor(c.compression))
   608  		}
   609  		c.processIncomingFrame(incoming)
   610  	}
   611  	return abort
   612  }
   613  
   614  func (c *CqlServerConnection) writeFrame(outgoing *frame.Frame, dest io.Writer) (abort bool) {
   615  	c.maybeSwitchToModernLayout(outgoing)
   616  	if err := c.frameCodec.EncodeFrame(outgoing, dest); err != nil {
   617  		abort = c.reportConnectionFailure(err, false)
   618  	} else {
   619  		log.Debug().Msgf("%v: outgoing frame successfully written: %v", c, outgoing)
   620  	}
   621  	return abort
   622  }
   623  
   624  func (c *CqlServerConnection) writeRawResponse(outgoing []byte, dest io.Writer) (abort bool) {
   625  	if _, err := dest.Write(outgoing); err != nil {
   626  		abort = c.reportConnectionFailure(err, false)
   627  	} else {
   628  		log.Debug().Msgf("%v: outgoing raw response successfully written: %v", c, outgoing)
   629  	}
   630  	return abort
   631  }
   632  
   633  func (c *CqlServerConnection) maybeSwitchToModernLayout(outgoing *frame.Frame) {
   634  	if !c.modernLayout &&
   635  		outgoing.Header.Version.SupportsModernFramingLayout() &&
   636  		(isReady(outgoing) || isAuthenticate(outgoing)) {
   637  		// Changing this value could be racy if some incoming frame is being processed;
   638  		// but in theory, this should never happen during handshake.
   639  		log.Debug().Msgf("%v: switching to modern framing layout", c)
   640  		c.modernLayout = true
   641  	}
   642  }
   643  
   644  func (c *CqlServerConnection) reportConnectionFailure(err error, read bool) (abort bool) {
   645  	if !c.IsClosed() {
   646  		if errors.Is(err, io.EOF) {
   647  			log.Info().Msgf("%v: connection reset by peer, closing", c)
   648  		} else {
   649  			if read {
   650  				log.Error().Err(err).Msgf("%v: error reading, closing connection", c)
   651  			} else {
   652  				log.Error().Err(err).Msgf("%v: error writing, closing connection", c)
   653  			}
   654  		}
   655  		abort = true
   656  	}
   657  	return abort
   658  }
   659  
   660  func (c *CqlServerConnection) processIncomingFrame(incoming *frame.Frame) {
   661  	log.Debug().Msgf("%v: received incoming frame: %v", c, incoming)
   662  	select {
   663  	case c.incoming <- incoming:
   664  		log.Debug().Msgf("%v: incoming frame successfully delivered: %v", c, incoming)
   665  	default:
   666  		log.Error().Msgf("%v: incoming frames queue is full, discarding frame: %v", c, incoming)
   667  	}
   668  	if len(c.handlers) > 0 {
   669  		c.invokeRequestHandlers(incoming)
   670  	}
   671  }
   672  
   673  func (c *CqlServerConnection) awaitDone() {
   674  	c.waitGroup.Add(1)
   675  	go func() {
   676  		<-c.ctx.Done()
   677  		log.Debug().Err(c.ctx.Err()).Msgf("%v: context was closed", c)
   678  		c.waitGroup.Done()
   679  		c.abort()
   680  	}()
   681  }
   682  
   683  func (c *CqlServerConnection) invokeRequestHandlers(request *frame.Frame) {
   684  	c.waitGroup.Add(1)
   685  	go func() {
   686  		log.Debug().Msgf("%v: invoking request handlers for incoming request: %v", c, request)
   687  		var err error
   688  		var rawResponse []byte
   689  		for i, rawHandler := range c.rawHandlers {
   690  			if rawResponse = rawHandler(request, c, c.handlerCtx[i]); rawResponse != nil {
   691  				log.Debug().Msgf("%v: raw request handler %v produced response: %v", c, i, rawResponse)
   692  				if err = c.SendRaw(rawResponse); err != nil {
   693  					log.Error().Err(err).Msgf("%v: send failed for frame: %v", c, rawResponse)
   694  				}
   695  				break
   696  			}
   697  		}
   698  		if rawResponse == nil {
   699  			var response *frame.Frame
   700  			for i, handler := range c.handlers {
   701  				if response = handler(request, c, c.handlerCtx[i]); response != nil {
   702  					log.Debug().Msgf("%v: request handler %v produced response: %v", c, i, response)
   703  					if err = c.Send(response); err != nil {
   704  						log.Error().Err(err).Msgf("%v: send failed for frame: %v", c, response)
   705  					}
   706  					break
   707  				}
   708  			}
   709  			if response == nil {
   710  				log.Debug().Msgf("%v: no request handler could handle the request: %v", c, request)
   711  			}
   712  		}
   713  		c.waitGroup.Done()
   714  	}()
   715  }
   716  
   717  // Send sends the given response frame.
   718  func (c *CqlServerConnection) Send(f *frame.Frame) error {
   719  	if c.IsClosed() {
   720  		return fmt.Errorf("%v: connection closed", c)
   721  	}
   722  	log.Debug().Msgf("%v: enqueuing outgoing frame: %v", c, f)
   723  	select {
   724  	case c.outgoing <- newFrameResponse(f):
   725  		log.Debug().Msgf("%v: outgoing frame successfully enqueued: %v", c, f)
   726  		return nil
   727  	default:
   728  		return fmt.Errorf("%v: failed to enqueue outgoing frame: %v", c, f)
   729  	}
   730  }
   731  
   732  // SendRaw sends the given response frame (already encoded).
   733  func (c *CqlServerConnection) SendRaw(rawResponse []byte) error {
   734  	if c.IsClosed() {
   735  		return fmt.Errorf("%v: connection closed", c)
   736  	}
   737  	log.Debug().Msgf("%v: enqueuing outgoing raw response: %v", c, rawResponse)
   738  	select {
   739  	case c.outgoing <- newRawResponse(rawResponse):
   740  		log.Debug().Msgf("%v: outgoing frame successfully enqueued: %v", c, rawResponse)
   741  		return nil
   742  	default:
   743  		return fmt.Errorf("%v: failed to send outgoing raw response: %v", c, rawResponse)
   744  	}
   745  }
   746  
   747  // Receive waits until the next request frame is received, or the configured idle timeout is triggered, or the
   748  // connection itself is closed, whichever happens first.
   749  func (c *CqlServerConnection) Receive() (*frame.Frame, error) {
   750  	if c.IsClosed() {
   751  		return nil, fmt.Errorf("%v: connection closed", c)
   752  	}
   753  	log.Debug().Msgf("%v: waiting for incoming frame", c)
   754  	if incoming, ok := <-c.incoming; !ok {
   755  		if c.IsClosed() {
   756  			return nil, fmt.Errorf("%v: connection closed", c)
   757  		} else {
   758  			return nil, fmt.Errorf("%v: incoming frame channel closed unexpectedly", c)
   759  		}
   760  	} else {
   761  		log.Debug().Msgf("%v: incoming frame successfully received: %v", c, incoming)
   762  		return incoming, nil
   763  	}
   764  }
   765  
   766  func (c *CqlServerConnection) IsClosed() bool {
   767  	return atomic.LoadInt32(&c.closed) == 1
   768  }
   769  
   770  func (c *CqlServerConnection) setClosed() bool {
   771  	return atomic.CompareAndSwapInt32(&c.closed, 0, 1)
   772  }
   773  
   774  func (c *CqlServerConnection) Close() (err error) {
   775  	if c.setClosed() {
   776  		log.Debug().Msgf("%v: closing", c)
   777  		c.cancel()
   778  		err = c.conn.Close()
   779  		incoming := c.incoming
   780  		outgoing := c.outgoing
   781  		c.incoming = nil
   782  		c.outgoing = nil
   783  		close(incoming)
   784  		close(outgoing)
   785  		c.waitGroup.Wait()
   786  		c.onClose(c)
   787  		if err != nil {
   788  			err = fmt.Errorf("%v: error closing: %w", c, err)
   789  		} else {
   790  			log.Info().Msgf("%v: successfully closed", c)
   791  		}
   792  	} else {
   793  		log.Debug().Err(err).Msgf("%v: already closed", c)
   794  	}
   795  	return err
   796  }
   797  
   798  func (c *CqlServerConnection) abort() {
   799  	log.Debug().Msgf("%v: forcefully closing", c)
   800  	if err := c.Close(); err != nil {
   801  		log.Error().Err(err).Msgf("%v: error closing", c)
   802  	}
   803  }