github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/connection.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
    17  import (
    18  	"fmt"
    19  	"net"
    20  	"sync"
    21  	"sync/atomic"
    22  
    23  	"github.com/rs/zerolog/log"
    24  )
    25  
    26  type connectionHolder struct {
    27  	ch   chan *CqlServerConnection
    28  	conn *CqlServerConnection
    29  }
    30  
    31  type clientConnectionHandler struct {
    32  	serverId        string
    33  	maxConnections  int
    34  	connections     map[string]*connectionHolder
    35  	anyConnChan     chan *CqlServerConnection
    36  	connectionsLock *sync.Mutex
    37  	closed          int32
    38  }
    39  
    40  func (h *clientConnectionHandler) String() string {
    41  	return fmt.Sprintf("%v: [conn. handler]", h.serverId)
    42  }
    43  
    44  func newClientConnectionHandler(serverId string, maxClientConnections int) (*clientConnectionHandler, error) {
    45  	if maxClientConnections < 1 {
    46  		return nil, fmt.Errorf("max connections: expecting positive, got: %v", maxClientConnections)
    47  	}
    48  	return &clientConnectionHandler{
    49  		serverId:        serverId,
    50  		maxConnections:  maxClientConnections,
    51  		connections:     make(map[string]*connectionHolder, maxClientConnections),
    52  		anyConnChan:     make(chan *CqlServerConnection, maxClientConnections),
    53  		connectionsLock: &sync.Mutex{},
    54  	}, nil
    55  }
    56  
    57  func (h *clientConnectionHandler) anyConnectionChannel() <-chan *CqlServerConnection {
    58  	return h.anyConnChan
    59  }
    60  
    61  func (h *clientConnectionHandler) allAcceptedClients() []*CqlServerConnection {
    62  	h.connectionsLock.Lock()
    63  	defer h.connectionsLock.Unlock()
    64  	var connections []*CqlServerConnection
    65  	for _, holder := range h.connections {
    66  		if holder.conn != nil && !holder.conn.IsClosed() {
    67  			connections = append(connections, holder.conn)
    68  		}
    69  	}
    70  	return connections
    71  }
    72  
    73  func (h *clientConnectionHandler) onConnectionAcceptRequested(client *CqlClientConnection) (<-chan *CqlServerConnection, error) {
    74  	if h.isClosed() {
    75  		return nil, fmt.Errorf("%v: handler closed", h)
    76  	}
    77  	if clientAddr, err := h.asMapKey(client.conn.LocalAddr()); err != nil {
    78  		return nil, err
    79  	} else {
    80  		log.Trace().Msgf("%v: client accept requested: %v", h, clientAddr)
    81  		h.connectionsLock.Lock()
    82  		defer h.connectionsLock.Unlock()
    83  		holder, found := h.connections[clientAddr]
    84  		if !found {
    85  			log.Trace().Msgf("%v: client address unknown, registering new channel: %v", h, clientAddr)
    86  			if len(h.connections) == h.maxConnections {
    87  				return nil, fmt.Errorf("%v: too many connections: %v", h, h.maxConnections)
    88  			}
    89  			holder = &connectionHolder{
    90  				ch: make(chan *CqlServerConnection, 1),
    91  			}
    92  			h.connections[clientAddr] = holder
    93  		}
    94  		return holder.ch, nil
    95  	}
    96  }
    97  
    98  func (h *clientConnectionHandler) onConnectionAccepted(connection *CqlServerConnection) error {
    99  	if h.isClosed() {
   100  		return fmt.Errorf("%v: handler closed", h)
   101  	}
   102  	if clientAddr, err := h.asMapKey(connection.conn.RemoteAddr()); err != nil {
   103  		return err
   104  	} else {
   105  		log.Trace().Msgf("%v: client accepted: %v", h, connection.conn.RemoteAddr())
   106  		h.connectionsLock.Lock()
   107  		defer h.connectionsLock.Unlock()
   108  		holder, found := h.connections[clientAddr]
   109  		if found {
   110  			holder.conn = connection
   111  		} else {
   112  			log.Trace().Msgf("%v: client address unknown, registering new channel: %v", h, connection.conn.RemoteAddr())
   113  			if len(h.connections) == h.maxConnections {
   114  				return fmt.Errorf("%v: too many connections: %v", h, h.maxConnections)
   115  			}
   116  			holder = &connectionHolder{
   117  				ch:   make(chan *CqlServerConnection, 1),
   118  				conn: connection,
   119  			}
   120  			h.connections[clientAddr] = holder
   121  		}
   122  		holder.ch <- connection
   123  		h.anyConnChan <- connection
   124  		return nil
   125  	}
   126  }
   127  
   128  func (h *clientConnectionHandler) onConnectionClosed(connection *CqlServerConnection) {
   129  	if !h.isClosed() {
   130  		if clientAddr, err := h.asMapKey(connection.conn.RemoteAddr()); err == nil {
   131  			log.Trace().Msgf("%v: client address closed, removing: %v", h, connection.conn.RemoteAddr())
   132  			h.connectionsLock.Lock()
   133  			defer h.connectionsLock.Unlock()
   134  			if holder, found := h.connections[clientAddr]; found {
   135  				log.Trace().Msgf("%v: client address removed: %v", h, connection.conn.RemoteAddr())
   136  				delete(h.connections, clientAddr)
   137  				close(holder.ch)
   138  			} else {
   139  				log.Trace().Msgf("%v: client address not found, ignoring: %v", h, connection.conn.RemoteAddr())
   140  			}
   141  		}
   142  	}
   143  }
   144  
   145  func (h *clientConnectionHandler) isClosed() bool {
   146  	return atomic.LoadInt32(&h.closed) == 1
   147  }
   148  
   149  func (h *clientConnectionHandler) setClosed() bool {
   150  	return atomic.CompareAndSwapInt32(&h.closed, 0, 1)
   151  }
   152  
   153  func (h *clientConnectionHandler) close() {
   154  	if h.setClosed() {
   155  		log.Trace().Msgf("%v: closing", h)
   156  		h.connectionsLock.Lock()
   157  		for clientAddr, holder := range h.connections {
   158  			delete(h.connections, clientAddr)
   159  			if err := holder.conn.Close(); err != nil {
   160  				log.Error().Err(err).Msg(err.Error())
   161  			}
   162  			close(holder.ch)
   163  		}
   164  		anyConnChan := h.anyConnChan
   165  		h.anyConnChan = nil
   166  		close(anyConnChan)
   167  		h.connectionsLock.Unlock()
   168  		log.Trace().Msgf("%v: successfully closed", h)
   169  	}
   170  }
   171  
   172  func (h *clientConnectionHandler) asMapKey(clientAddr net.Addr) (string, error) {
   173  	if tcpAddr, ok := clientAddr.(*net.TCPAddr); !ok {
   174  		return "", fmt.Errorf("%v: expected TCP address, got: %v", h, clientAddr)
   175  	} else {
   176  		return fmt.Sprintf("%v__%v__%v", string(tcpAddr.IP), tcpAddr.Port, tcpAddr.Zone), nil
   177  	}
   178  }