github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/server.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/tls" 21 "errors" 22 "fmt" 23 "io" 24 "math" 25 "net" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 "github.com/datastax/go-cassandra-native-protocol/message" 31 "github.com/datastax/go-cassandra-native-protocol/segment" 32 33 "github.com/rs/zerolog/log" 34 35 "github.com/datastax/go-cassandra-native-protocol/frame" 36 "github.com/datastax/go-cassandra-native-protocol/primitive" 37 ) 38 39 const ( 40 DefaultAcceptTimeout = time.Second * 60 41 DefaultIdleTimeout = time.Hour 42 ) 43 44 const DefaultMaxConnections = 128 45 46 const ( 47 ServerStateNotStarted = int32(iota) 48 ServerStateRunning = int32(iota) 49 ServerStateClosed = int32(iota) 50 ) 51 52 // RequestHandlerContext is the RequestHandler invocation context. Each invocation of a given RequestHandler will be 53 // passed one instance of a RequestHandlerContext, that remains the same between invocations. This allows 54 // handlers to become stateful if required. 55 type RequestHandlerContext interface { 56 // PutAttribute puts the given value in this context under the given key name. 57 // Will override any previously-stored value under that key. 58 PutAttribute(name string, value interface{}) 59 // GetAttribute retrieves the value stored in this context under the given key name. 60 // Returns nil if nil is stored, or if the key does not exist. 61 GetAttribute(name string) interface{} 62 } 63 64 type requestHandlerContext map[string]interface{} 65 66 func (ctx requestHandlerContext) PutAttribute(name string, value interface{}) { 67 ctx[name] = value 68 } 69 70 func (ctx requestHandlerContext) GetAttribute(name string) interface{} { 71 return ctx[name] 72 } 73 74 // RequestHandler is a callback function that gets invoked whenever a CqlServerConnection receives an incoming 75 // frame. The handler function should inspect the request frame and determine if it can handle the response for it. 76 // If so, it should return a non-nil response frame. When that happens, no further handlers will be tried for the 77 // incoming request. 78 // If a handler returns nil, it is assumed that it was not able to handle the request, in which case another handler, 79 // if any, may be tried. 80 type RequestHandler func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (response *frame.Frame) 81 82 // RawRequestHandler is similar to RequestHandler but returns an already encoded response in byte slice format, this can be used to return responses that the 83 // embedded codecs can't encode 84 type RawRequestHandler func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (encodedResponse []byte) 85 86 // CqlServer is a minimalistic server stub that can be used to mimic CQL-compatible backends. It is preferable to 87 // create CqlServer instances using the constructor function NewCqlServer. Once the server is properly created and 88 // configured, use Start to start the server, then call Accept or AcceptAny to accept incoming client connections. 89 type CqlServer struct { 90 // ListenAddress is the address to listen to. 91 ListenAddress string 92 // Credentials is the AuthCredentials to use. If nil, no authentication will be used; otherwise, clients will be 93 // required to authenticate with plain-text auth using the same credentials. 94 Credentials *AuthCredentials 95 // MaxConnections is the maximum number of open client connections to accept. Must be strictly positive. 96 MaxConnections int 97 // MaxInFlight is the maximum number of in-flight requests to apply for each connection created with Accept. Must 98 // be strictly positive. 99 MaxInFlight int 100 // AcceptTimeout is the timeout to apply when accepting new connections. 101 AcceptTimeout time.Duration 102 // IdleTimeout is the timeout to apply for closing idle connections. 103 IdleTimeout time.Duration 104 // RequestHandlers is an optional list of handlers to handle incoming requests. 105 RequestHandlers []RequestHandler 106 // RequestRawHandlers is an optional list of handlers to handle incoming requests and return a response in a byte slice format. 107 RequestRawHandlers []RawRequestHandler 108 // TLSConfig is the TLS configuration to use. 109 TLSConfig *tls.Config 110 111 ctx context.Context 112 cancel context.CancelFunc 113 listener net.Listener 114 connectionsHandler *clientConnectionHandler 115 waitGroup *sync.WaitGroup 116 state int32 117 } 118 119 // NewCqlServer creates a new CqlServer with default options. Leave credentials nil to opt out from authentication. 120 func NewCqlServer(listenAddress string, credentials *AuthCredentials) *CqlServer { 121 return &CqlServer{ 122 ListenAddress: listenAddress, 123 Credentials: credentials, 124 MaxConnections: DefaultMaxConnections, 125 MaxInFlight: DefaultMaxInFlight, 126 AcceptTimeout: DefaultAcceptTimeout, 127 IdleTimeout: DefaultIdleTimeout, 128 } 129 } 130 131 func (server *CqlServer) String() string { 132 return fmt.Sprintf("CQL server [%v]", server.ListenAddress) 133 } 134 135 func (server *CqlServer) getState() int32 { 136 return atomic.LoadInt32(&server.state) 137 } 138 139 func (server *CqlServer) IsNotStarted() bool { 140 return server.getState() == ServerStateNotStarted 141 } 142 143 func (server *CqlServer) IsRunning() bool { 144 return server.getState() == ServerStateRunning 145 } 146 147 func (server *CqlServer) IsClosed() bool { 148 return server.getState() == ServerStateClosed 149 } 150 151 func (server *CqlServer) transitionState(old int32, new int32) bool { 152 return atomic.CompareAndSwapInt32(&server.state, old, new) 153 } 154 155 // Start starts the server and binds to its listen address. This method must be called before calling Accept. 156 // Set ctx to context.Background if no parent context exists. 157 func (server *CqlServer) Start(ctx context.Context) (err error) { 158 if ctx == nil { 159 return fmt.Errorf("context cannot be nil") 160 } 161 if server.transitionState(ServerStateNotStarted, ServerStateRunning) { 162 log.Debug().Msgf("%v: server is starting", server) 163 server.connectionsHandler, err = newClientConnectionHandler(server.String(), server.MaxConnections) 164 if err != nil { 165 return fmt.Errorf("%v: start failed: %w", server, err) 166 } 167 if server.TLSConfig != nil { 168 server.listener, err = tls.Listen("tcp", server.ListenAddress, server.TLSConfig) 169 } else { 170 server.listener, err = net.Listen("tcp", server.ListenAddress) 171 } 172 if err != nil { 173 return fmt.Errorf("%v: start failed: %w", server, err) 174 } 175 server.ctx, server.cancel = context.WithCancel(ctx) 176 server.waitGroup = &sync.WaitGroup{} 177 server.acceptLoop() 178 server.awaitDone() 179 log.Info().Msgf("%v: successfully started", server) 180 } else { 181 log.Debug().Msgf("%v: already started or closed", server) 182 } 183 return err 184 } 185 186 func (server *CqlServer) Close() (err error) { 187 if server.transitionState(ServerStateRunning, ServerStateClosed) { 188 log.Debug().Msgf("%v: closing", server) 189 err = server.listener.Close() 190 server.connectionsHandler.close() 191 server.cancel() 192 server.waitGroup.Wait() 193 if err != nil { 194 log.Debug().Err(err).Msgf("%v: could not close server", server) 195 err = fmt.Errorf("%v: could not close server: %w", server, err) 196 } else { 197 log.Info().Msgf("%v: successfully closed", server) 198 } 199 } else { 200 log.Debug().Msgf("%v: not started or already closed", server) 201 } 202 return err 203 } 204 205 func (server *CqlServer) abort() { 206 log.Debug().Msgf("%v: forcefully closing", server) 207 if err := server.Close(); err != nil { 208 log.Error().Err(err).Msgf("%v: error closing", server) 209 } 210 } 211 212 func (server *CqlServer) acceptLoop() { 213 server.waitGroup.Add(1) 214 go func() { 215 abort := false 216 for server.IsRunning() { 217 if conn, err := server.listener.Accept(); err != nil { 218 if !server.IsClosed() { 219 log.Error().Err(err).Msgf("%v: error accepting client connections, closing server", server) 220 abort = true 221 } 222 break 223 } else { 224 log.Debug().Msgf("%v: new TCP connection accepted", server) 225 if connection, err := newCqlServerConnection( 226 conn, 227 server.ctx, 228 server.Credentials, 229 server.MaxInFlight, 230 server.IdleTimeout, 231 server.RequestHandlers, 232 server.RequestRawHandlers, 233 server.connectionsHandler.onConnectionClosed, 234 ); err != nil { 235 log.Error().Msgf("%v: failed to accept incoming CQL client connection: %v", server, connection) 236 _ = conn.Close() 237 } else if err := server.connectionsHandler.onConnectionAccepted(connection); err != nil { 238 log.Error().Msgf("%v: handler rejected incoming CQL client connection: %v", server, connection) 239 _ = conn.Close() 240 } else { 241 log.Info().Msgf("%v: accepted new incoming CQL client connection: %v", server, connection) 242 } 243 } 244 } 245 server.waitGroup.Done() 246 if abort { 247 server.abort() 248 } 249 }() 250 } 251 252 func (server *CqlServer) awaitDone() { 253 server.waitGroup.Add(1) 254 go func() { 255 <-server.ctx.Done() 256 log.Debug().Err(server.ctx.Err()).Msgf("%v: context was closed", server) 257 server.waitGroup.Done() 258 server.abort() 259 }() 260 } 261 262 // Accept waits until the given client address is accepted, the configured timeout is triggered, or the server is 263 // closed, whichever happens first. 264 func (server *CqlServer) Accept(client *CqlClientConnection) (*CqlServerConnection, error) { 265 if server.IsClosed() { 266 return nil, fmt.Errorf("%v: server closed", server) 267 } 268 log.Debug().Msgf("%v: waiting for incoming client connection to be accepted: %v", server, client) 269 if serverConnectionChannel, err := server.connectionsHandler.onConnectionAcceptRequested(client); err != nil { 270 return nil, err 271 } else { 272 select { 273 case serverConnection, ok := <-serverConnectionChannel: 274 if !ok { 275 return nil, fmt.Errorf("%v: incoming client connection channel closed unexpectedly", server) 276 } 277 log.Debug().Msgf("%v: returning accepted client connection: %v", server, serverConnection) 278 return serverConnection, nil 279 case <-time.After(server.AcceptTimeout): 280 return nil, fmt.Errorf("%v: timed out waiting for incoming client connection", server) 281 } 282 } 283 } 284 285 // AcceptAny waits until any client is accepted, the configured timeout is triggered, or the server is closed, 286 // whichever happens first. This method is useful when the client is not known in advance. 287 func (server *CqlServer) AcceptAny() (*CqlServerConnection, error) { 288 if server.IsClosed() { 289 return nil, fmt.Errorf("%v: server closed", server) 290 } 291 log.Debug().Msgf("%v: waiting for any incoming client connection to be accepted", server) 292 anyConn := server.connectionsHandler.anyConnectionChannel() 293 select { 294 case serverConnection, ok := <-anyConn: 295 if !ok { 296 return nil, fmt.Errorf("%v: incoming client connection channel closed unexpectedly", server) 297 } 298 log.Debug().Msgf("%v: returning accepted client connection: %v", server, serverConnection) 299 return serverConnection, nil 300 case <-time.After(server.AcceptTimeout): 301 return nil, fmt.Errorf("%v: timed out waiting for incoming client connection", server) 302 } 303 } 304 305 // AllAcceptedClients returns a list of all the currently active server connections. 306 func (server *CqlServer) AllAcceptedClients() ([]*CqlServerConnection, error) { 307 if server.IsClosed() { 308 return nil, fmt.Errorf("%v: server closed", server) 309 } 310 return server.connectionsHandler.allAcceptedClients(), nil 311 } 312 313 // Bind is a convenience method to connect a CqlClient to this CqlServer. The returned connections will be open, but not 314 // initialized (i.e., no handshake performed). The server must be started prior to calling this method. 315 func (server *CqlServer) Bind(client *CqlClient, ctx context.Context) (*CqlClientConnection, *CqlServerConnection, error) { 316 if server.IsNotStarted() { 317 return nil, nil, fmt.Errorf("%v: server not started", server) 318 } else if server.IsClosed() { 319 return nil, nil, fmt.Errorf("%v: server closed", server) 320 } else if clientConn, err := client.Connect(ctx); err != nil { 321 return nil, nil, fmt.Errorf("%v: bind failed, client %v could not connect: %w", server, client, err) 322 } else if serverConn, err := server.Accept(clientConn); err != nil { 323 return nil, nil, fmt.Errorf("%v: bind failed, client %v wasn't accepted: %w", server, client, err) 324 } else { 325 log.Debug().Msgf("%v: bind successful: %v", server, serverConn) 326 return clientConn, serverConn, nil 327 } 328 } 329 330 // BindAndInit is a convenience method to connect a CqlClient to this CqlServer. The returned connections will be open 331 // and initialized (i.e., handshake is already performed). The server must be started prior to calling this method. 332 // Use stream id zero to activate automatic stream id management. 333 func (server *CqlServer) BindAndInit( 334 client *CqlClient, 335 ctx context.Context, 336 version primitive.ProtocolVersion, 337 streamId int16, 338 ) (*CqlClientConnection, *CqlServerConnection, error) { 339 if clientConn, serverConn, err := server.Bind(client, ctx); err != nil { 340 return nil, nil, err 341 } else { 342 return clientConn, serverConn, PerformHandshake(clientConn, serverConn, version, streamId) 343 } 344 } 345 346 type response struct { 347 responseFrame *frame.Frame 348 rawResponse []byte 349 } 350 351 func newFrameResponse(frameResponse *frame.Frame) *response { 352 return &response{ 353 responseFrame: frameResponse, 354 } 355 } 356 357 func newRawResponse(rawResponse []byte) *response { 358 return &response{ 359 rawResponse: rawResponse, 360 } 361 } 362 363 // CqlServerConnection encapsulates a TCP server connection to a remote CQL client. 364 // CqlServerConnection instances should be created by calling CqlServer.Accept or CqlServer.Bind. 365 type CqlServerConnection struct { 366 conn net.Conn 367 credentials *AuthCredentials 368 frameCodec frame.Codec 369 segmentCodec segment.Codec 370 compression primitive.Compression 371 modernLayout bool 372 idleTimeout time.Duration 373 handlers []RequestHandler 374 rawHandlers []RawRequestHandler 375 handlerCtx []RequestHandlerContext 376 incoming chan *frame.Frame 377 outgoing chan *response 378 waitGroup *sync.WaitGroup 379 closed int32 380 onClose func(*CqlServerConnection) 381 ctx context.Context 382 cancel context.CancelFunc 383 payloadAccumulator *payloadAccumulator 384 } 385 386 func newCqlServerConnection( 387 conn net.Conn, 388 ctx context.Context, 389 credentials *AuthCredentials, 390 maxInFlight int, 391 idleTimeout time.Duration, 392 handlers []RequestHandler, 393 rawHandlers []RawRequestHandler, 394 onClose func(*CqlServerConnection), 395 ) (*CqlServerConnection, error) { 396 if conn == nil { 397 return nil, fmt.Errorf("TCP connection cannot be nil") 398 } 399 if maxInFlight < 1 { 400 return nil, fmt.Errorf("max in-flight: expecting positive, got: %v", maxInFlight) 401 } else if maxInFlight > math.MaxInt16 { 402 return nil, fmt.Errorf("max in-flight: expecting <= %v, got: %v", math.MaxInt16, maxInFlight) 403 } 404 frameCodec := frame.NewCodec() 405 segmentCodec := segment.NewCodec() 406 connection := &CqlServerConnection{ 407 conn: conn, 408 frameCodec: frameCodec, 409 segmentCodec: segmentCodec, 410 compression: primitive.CompressionNone, 411 credentials: credentials, 412 idleTimeout: idleTimeout, 413 handlers: handlers, 414 rawHandlers: rawHandlers, 415 handlerCtx: make([]RequestHandlerContext, len(handlers)), 416 incoming: make(chan *frame.Frame, maxInFlight), 417 outgoing: make(chan *response, maxInFlight), 418 waitGroup: &sync.WaitGroup{}, 419 onClose: onClose, 420 } 421 for i := range handlers { 422 connection.handlerCtx[i] = requestHandlerContext{} 423 } 424 connection.ctx, connection.cancel = context.WithCancel(ctx) 425 connection.incomingLoop() 426 connection.outgoingLoop() 427 connection.awaitDone() 428 return connection, nil 429 } 430 431 func (c *CqlServerConnection) String() string { 432 return fmt.Sprintf("CQL server conn [L:%v <-> R:%v]", c.conn.LocalAddr(), c.conn.RemoteAddr()) 433 } 434 435 // LocalAddr Returns the connection's local address (that is, the client address). 436 func (c *CqlServerConnection) LocalAddr() net.Addr { 437 return c.conn.LocalAddr() 438 } 439 440 // RemoteAddr Returns the connection's remote address (that is, the server address). 441 func (c *CqlServerConnection) RemoteAddr() net.Addr { 442 return c.conn.RemoteAddr() 443 } 444 445 // Credentials Returns a copy of the connection's AuthCredentials, if any, or nil if no authentication was configured. 446 func (c *CqlServerConnection) Credentials() *AuthCredentials { 447 if c.credentials == nil { 448 return nil 449 } 450 return c.credentials.Copy() 451 } 452 453 func (c *CqlServerConnection) GetConn() net.Conn { 454 return c.conn 455 } 456 457 func (c *CqlServerConnection) incomingLoop() { 458 log.Debug().Msgf("%v: listening for incoming frames...", c) 459 c.waitGroup.Add(1) 460 go func() { 461 abort := false 462 for !abort && !c.IsClosed() { 463 if abort = c.setIdleTimeout(); !abort { 464 if source, err := c.waitForIncomingData(); err != nil { 465 abort = c.reportConnectionFailure(err, true) 466 } else if c.modernLayout { 467 abort = c.readSegment(source) 468 } else { 469 abort = c.readFrame(source) 470 } 471 } 472 } 473 c.waitGroup.Done() 474 if abort { 475 c.abort() 476 } 477 }() 478 } 479 480 func (c *CqlServerConnection) outgoingLoop() { 481 log.Debug().Msgf("%v: listening for outgoing frames...", c) 482 c.waitGroup.Add(1) 483 go func() { 484 abort := false 485 for !c.IsClosed() { 486 if outgoing, ok := <-c.outgoing; !ok { 487 if !c.IsClosed() { 488 log.Error().Msgf("%v: outgoing frame channel was closed unexpectedly, closing connection", c) 489 abort = true 490 } 491 break 492 } else { 493 if outgoing.rawResponse != nil { 494 abort = c.writeRawResponse(outgoing.rawResponse, c.conn) 495 log.Debug().Msgf("%v: sending outgoing raw response: %v", c, outgoing.rawResponse) 496 } else { 497 if c.compression != primitive.CompressionNone { 498 outgoing.responseFrame.Header.Flags = outgoing.responseFrame.Header.Flags.Add(primitive.HeaderFlagCompressed) 499 } 500 log.Debug().Msgf("%v: sending outgoing frame: %v", c, outgoing.responseFrame) 501 if c.modernLayout { 502 // TODO write coalescer 503 abort = c.writeSegment(outgoing.responseFrame, c.conn) 504 } else { 505 abort = c.writeFrame(outgoing.responseFrame, c.conn) 506 } 507 } 508 } 509 } 510 c.waitGroup.Done() 511 if abort { 512 c.abort() 513 } 514 }() 515 } 516 517 func (c *CqlServerConnection) waitForIncomingData() (io.Reader, error) { 518 buf := make([]byte, 1) 519 if _, err := io.ReadFull(c.conn, buf); err != nil { 520 return nil, err 521 } else { 522 return io.MultiReader(bytes.NewReader(buf), c.conn), nil 523 } 524 } 525 526 func (c *CqlServerConnection) setIdleTimeout() (abort bool) { 527 if err := c.conn.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil { 528 if !c.IsClosed() { 529 log.Error().Err(err).Msgf("%v: error setting idle timeout, closing connection", c) 530 abort = true 531 } 532 } 533 return abort 534 } 535 536 func (c *CqlServerConnection) readSegment(source io.Reader) (abort bool) { 537 if incoming, err := c.segmentCodec.DecodeSegment(source); err != nil { 538 abort = c.reportConnectionFailure(err, true) 539 } else if incoming.Header.IsSelfContained { 540 log.Debug().Msgf("%v: received incoming self-contained segment: %v", c, incoming) 541 abort = c.readSelfContainedSegment(incoming, abort) 542 } else { 543 log.Debug().Msgf("%v: received incoming multi-segment part: %v", c, incoming) 544 abort = c.addMultiSegmentPayload(incoming.Payload) 545 } 546 return abort 547 } 548 549 func (c *CqlServerConnection) readSelfContainedSegment(incoming *segment.Segment, abort bool) bool { 550 payloadReader := bytes.NewReader(incoming.Payload.UncompressedData) 551 for payloadReader.Len() > 0 { 552 if abort = c.readFrame(payloadReader); abort { 553 break 554 } 555 } 556 return abort 557 } 558 559 func (c *CqlServerConnection) addMultiSegmentPayload(payload *segment.Payload) (abort bool) { 560 accumulator := c.payloadAccumulator 561 if accumulator.targetLength == 0 { 562 // First reader, read ahead to find the target length 563 if header, err := accumulator.frameCodec.DecodeHeader(bytes.NewReader(payload.UncompressedData)); err != nil { 564 log.Error().Err(err).Msgf("%v: error decoding first frame header in multi-segment payload, closing connection", c) 565 return true 566 } else { 567 accumulator.targetLength = int(primitive.FrameHeaderLengthV3AndHigher + header.BodyLength) 568 } 569 } 570 accumulator.accumulatedData = append(accumulator.accumulatedData, payload.UncompressedData...) 571 if accumulator.targetLength == len(accumulator.accumulatedData) { 572 // We've received enough data to reassemble the whole frame 573 encodedFrame := bytes.NewReader(accumulator.accumulatedData) 574 accumulator.reset() 575 return c.readFrame(encodedFrame) 576 } 577 return false 578 } 579 580 func (c *CqlServerConnection) writeSegment(outgoing *frame.Frame, dest io.Writer) (abort bool) { 581 // never compress frames individually when included in a segment 582 outgoing.Header.Flags.Remove(primitive.HeaderFlagCompressed) 583 encodedFrame := &bytes.Buffer{} 584 if abort = c.writeFrame(outgoing, encodedFrame); abort { 585 abort = true 586 } else { 587 seg := &segment.Segment{ 588 Header: &segment.Header{IsSelfContained: true}, 589 Payload: &segment.Payload{UncompressedData: encodedFrame.Bytes()}, 590 } 591 if err := c.segmentCodec.EncodeSegment(seg, dest); err != nil { 592 abort = c.reportConnectionFailure(err, false) 593 } else { 594 log.Debug().Msgf("%v: outgoing segment successfully written: %v (frame: %v)", c, seg, outgoing) 595 } 596 } 597 return abort 598 } 599 600 func (c *CqlServerConnection) readFrame(source io.Reader) (abort bool) { 601 if incoming, err := c.frameCodec.DecodeFrame(source); err != nil { 602 abort = c.reportConnectionFailure(err, true) 603 } else { 604 if startup, ok := incoming.Body.Message.(*message.Startup); ok { 605 c.compression = startup.GetCompression() 606 c.frameCodec = frame.NewCodecWithCompression(NewBodyCompressor(c.compression)) 607 c.segmentCodec = segment.NewCodecWithCompression(NewPayloadCompressor(c.compression)) 608 } 609 c.processIncomingFrame(incoming) 610 } 611 return abort 612 } 613 614 func (c *CqlServerConnection) writeFrame(outgoing *frame.Frame, dest io.Writer) (abort bool) { 615 c.maybeSwitchToModernLayout(outgoing) 616 if err := c.frameCodec.EncodeFrame(outgoing, dest); err != nil { 617 abort = c.reportConnectionFailure(err, false) 618 } else { 619 log.Debug().Msgf("%v: outgoing frame successfully written: %v", c, outgoing) 620 } 621 return abort 622 } 623 624 func (c *CqlServerConnection) writeRawResponse(outgoing []byte, dest io.Writer) (abort bool) { 625 if _, err := dest.Write(outgoing); err != nil { 626 abort = c.reportConnectionFailure(err, false) 627 } else { 628 log.Debug().Msgf("%v: outgoing raw response successfully written: %v", c, outgoing) 629 } 630 return abort 631 } 632 633 func (c *CqlServerConnection) maybeSwitchToModernLayout(outgoing *frame.Frame) { 634 if !c.modernLayout && 635 outgoing.Header.Version.SupportsModernFramingLayout() && 636 (isReady(outgoing) || isAuthenticate(outgoing)) { 637 // Changing this value could be racy if some incoming frame is being processed; 638 // but in theory, this should never happen during handshake. 639 log.Debug().Msgf("%v: switching to modern framing layout", c) 640 c.modernLayout = true 641 } 642 } 643 644 func (c *CqlServerConnection) reportConnectionFailure(err error, read bool) (abort bool) { 645 if !c.IsClosed() { 646 if errors.Is(err, io.EOF) { 647 log.Info().Msgf("%v: connection reset by peer, closing", c) 648 } else { 649 if read { 650 log.Error().Err(err).Msgf("%v: error reading, closing connection", c) 651 } else { 652 log.Error().Err(err).Msgf("%v: error writing, closing connection", c) 653 } 654 } 655 abort = true 656 } 657 return abort 658 } 659 660 func (c *CqlServerConnection) processIncomingFrame(incoming *frame.Frame) { 661 log.Debug().Msgf("%v: received incoming frame: %v", c, incoming) 662 select { 663 case c.incoming <- incoming: 664 log.Debug().Msgf("%v: incoming frame successfully delivered: %v", c, incoming) 665 default: 666 log.Error().Msgf("%v: incoming frames queue is full, discarding frame: %v", c, incoming) 667 } 668 if len(c.handlers) > 0 { 669 c.invokeRequestHandlers(incoming) 670 } 671 } 672 673 func (c *CqlServerConnection) awaitDone() { 674 c.waitGroup.Add(1) 675 go func() { 676 <-c.ctx.Done() 677 log.Debug().Err(c.ctx.Err()).Msgf("%v: context was closed", c) 678 c.waitGroup.Done() 679 c.abort() 680 }() 681 } 682 683 func (c *CqlServerConnection) invokeRequestHandlers(request *frame.Frame) { 684 c.waitGroup.Add(1) 685 go func() { 686 log.Debug().Msgf("%v: invoking request handlers for incoming request: %v", c, request) 687 var err error 688 var rawResponse []byte 689 for i, rawHandler := range c.rawHandlers { 690 if rawResponse = rawHandler(request, c, c.handlerCtx[i]); rawResponse != nil { 691 log.Debug().Msgf("%v: raw request handler %v produced response: %v", c, i, rawResponse) 692 if err = c.SendRaw(rawResponse); err != nil { 693 log.Error().Err(err).Msgf("%v: send failed for frame: %v", c, rawResponse) 694 } 695 break 696 } 697 } 698 if rawResponse == nil { 699 var response *frame.Frame 700 for i, handler := range c.handlers { 701 if response = handler(request, c, c.handlerCtx[i]); response != nil { 702 log.Debug().Msgf("%v: request handler %v produced response: %v", c, i, response) 703 if err = c.Send(response); err != nil { 704 log.Error().Err(err).Msgf("%v: send failed for frame: %v", c, response) 705 } 706 break 707 } 708 } 709 if response == nil { 710 log.Debug().Msgf("%v: no request handler could handle the request: %v", c, request) 711 } 712 } 713 c.waitGroup.Done() 714 }() 715 } 716 717 // Send sends the given response frame. 718 func (c *CqlServerConnection) Send(f *frame.Frame) error { 719 if c.IsClosed() { 720 return fmt.Errorf("%v: connection closed", c) 721 } 722 log.Debug().Msgf("%v: enqueuing outgoing frame: %v", c, f) 723 select { 724 case c.outgoing <- newFrameResponse(f): 725 log.Debug().Msgf("%v: outgoing frame successfully enqueued: %v", c, f) 726 return nil 727 default: 728 return fmt.Errorf("%v: failed to enqueue outgoing frame: %v", c, f) 729 } 730 } 731 732 // SendRaw sends the given response frame (already encoded). 733 func (c *CqlServerConnection) SendRaw(rawResponse []byte) error { 734 if c.IsClosed() { 735 return fmt.Errorf("%v: connection closed", c) 736 } 737 log.Debug().Msgf("%v: enqueuing outgoing raw response: %v", c, rawResponse) 738 select { 739 case c.outgoing <- newRawResponse(rawResponse): 740 log.Debug().Msgf("%v: outgoing frame successfully enqueued: %v", c, rawResponse) 741 return nil 742 default: 743 return fmt.Errorf("%v: failed to send outgoing raw response: %v", c, rawResponse) 744 } 745 } 746 747 // Receive waits until the next request frame is received, or the configured idle timeout is triggered, or the 748 // connection itself is closed, whichever happens first. 749 func (c *CqlServerConnection) Receive() (*frame.Frame, error) { 750 if c.IsClosed() { 751 return nil, fmt.Errorf("%v: connection closed", c) 752 } 753 log.Debug().Msgf("%v: waiting for incoming frame", c) 754 if incoming, ok := <-c.incoming; !ok { 755 if c.IsClosed() { 756 return nil, fmt.Errorf("%v: connection closed", c) 757 } else { 758 return nil, fmt.Errorf("%v: incoming frame channel closed unexpectedly", c) 759 } 760 } else { 761 log.Debug().Msgf("%v: incoming frame successfully received: %v", c, incoming) 762 return incoming, nil 763 } 764 } 765 766 func (c *CqlServerConnection) IsClosed() bool { 767 return atomic.LoadInt32(&c.closed) == 1 768 } 769 770 func (c *CqlServerConnection) setClosed() bool { 771 return atomic.CompareAndSwapInt32(&c.closed, 0, 1) 772 } 773 774 func (c *CqlServerConnection) Close() (err error) { 775 if c.setClosed() { 776 log.Debug().Msgf("%v: closing", c) 777 c.cancel() 778 err = c.conn.Close() 779 incoming := c.incoming 780 outgoing := c.outgoing 781 c.incoming = nil 782 c.outgoing = nil 783 close(incoming) 784 close(outgoing) 785 c.waitGroup.Wait() 786 c.onClose(c) 787 if err != nil { 788 err = fmt.Errorf("%v: error closing: %w", c, err) 789 } else { 790 log.Info().Msgf("%v: successfully closed", c) 791 } 792 } else { 793 log.Debug().Err(err).Msgf("%v: already closed", c) 794 } 795 return err 796 } 797 798 func (c *CqlServerConnection) abort() { 799 log.Debug().Msgf("%v: forcefully closing", c) 800 if err := c.Close(); err != nil { 801 log.Error().Err(err).Msgf("%v: error closing", c) 802 } 803 }