github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/connection.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client 16 17 import ( 18 "fmt" 19 "net" 20 "sync" 21 "sync/atomic" 22 23 "github.com/rs/zerolog/log" 24 ) 25 26 type connectionHolder struct { 27 ch chan *CqlServerConnection 28 conn *CqlServerConnection 29 } 30 31 type clientConnectionHandler struct { 32 serverId string 33 maxConnections int 34 connections map[string]*connectionHolder 35 anyConnChan chan *CqlServerConnection 36 connectionsLock *sync.Mutex 37 closed int32 38 } 39 40 func (h *clientConnectionHandler) String() string { 41 return fmt.Sprintf("%v: [conn. handler]", h.serverId) 42 } 43 44 func newClientConnectionHandler(serverId string, maxClientConnections int) (*clientConnectionHandler, error) { 45 if maxClientConnections < 1 { 46 return nil, fmt.Errorf("max connections: expecting positive, got: %v", maxClientConnections) 47 } 48 return &clientConnectionHandler{ 49 serverId: serverId, 50 maxConnections: maxClientConnections, 51 connections: make(map[string]*connectionHolder, maxClientConnections), 52 anyConnChan: make(chan *CqlServerConnection, maxClientConnections), 53 connectionsLock: &sync.Mutex{}, 54 }, nil 55 } 56 57 func (h *clientConnectionHandler) anyConnectionChannel() <-chan *CqlServerConnection { 58 return h.anyConnChan 59 } 60 61 func (h *clientConnectionHandler) allAcceptedClients() []*CqlServerConnection { 62 h.connectionsLock.Lock() 63 defer h.connectionsLock.Unlock() 64 var connections []*CqlServerConnection 65 for _, holder := range h.connections { 66 if holder.conn != nil && !holder.conn.IsClosed() { 67 connections = append(connections, holder.conn) 68 } 69 } 70 return connections 71 } 72 73 func (h *clientConnectionHandler) onConnectionAcceptRequested(client *CqlClientConnection) (<-chan *CqlServerConnection, error) { 74 if h.isClosed() { 75 return nil, fmt.Errorf("%v: handler closed", h) 76 } 77 if clientAddr, err := h.asMapKey(client.conn.LocalAddr()); err != nil { 78 return nil, err 79 } else { 80 log.Trace().Msgf("%v: client accept requested: %v", h, clientAddr) 81 h.connectionsLock.Lock() 82 defer h.connectionsLock.Unlock() 83 holder, found := h.connections[clientAddr] 84 if !found { 85 log.Trace().Msgf("%v: client address unknown, registering new channel: %v", h, clientAddr) 86 if len(h.connections) == h.maxConnections { 87 return nil, fmt.Errorf("%v: too many connections: %v", h, h.maxConnections) 88 } 89 holder = &connectionHolder{ 90 ch: make(chan *CqlServerConnection, 1), 91 } 92 h.connections[clientAddr] = holder 93 } 94 return holder.ch, nil 95 } 96 } 97 98 func (h *clientConnectionHandler) onConnectionAccepted(connection *CqlServerConnection) error { 99 if h.isClosed() { 100 return fmt.Errorf("%v: handler closed", h) 101 } 102 if clientAddr, err := h.asMapKey(connection.conn.RemoteAddr()); err != nil { 103 return err 104 } else { 105 log.Trace().Msgf("%v: client accepted: %v", h, connection.conn.RemoteAddr()) 106 h.connectionsLock.Lock() 107 defer h.connectionsLock.Unlock() 108 holder, found := h.connections[clientAddr] 109 if found { 110 holder.conn = connection 111 } else { 112 log.Trace().Msgf("%v: client address unknown, registering new channel: %v", h, connection.conn.RemoteAddr()) 113 if len(h.connections) == h.maxConnections { 114 return fmt.Errorf("%v: too many connections: %v", h, h.maxConnections) 115 } 116 holder = &connectionHolder{ 117 ch: make(chan *CqlServerConnection, 1), 118 conn: connection, 119 } 120 h.connections[clientAddr] = holder 121 } 122 holder.ch <- connection 123 h.anyConnChan <- connection 124 return nil 125 } 126 } 127 128 func (h *clientConnectionHandler) onConnectionClosed(connection *CqlServerConnection) { 129 if !h.isClosed() { 130 if clientAddr, err := h.asMapKey(connection.conn.RemoteAddr()); err == nil { 131 log.Trace().Msgf("%v: client address closed, removing: %v", h, connection.conn.RemoteAddr()) 132 h.connectionsLock.Lock() 133 defer h.connectionsLock.Unlock() 134 if holder, found := h.connections[clientAddr]; found { 135 log.Trace().Msgf("%v: client address removed: %v", h, connection.conn.RemoteAddr()) 136 delete(h.connections, clientAddr) 137 close(holder.ch) 138 } else { 139 log.Trace().Msgf("%v: client address not found, ignoring: %v", h, connection.conn.RemoteAddr()) 140 } 141 } 142 } 143 } 144 145 func (h *clientConnectionHandler) isClosed() bool { 146 return atomic.LoadInt32(&h.closed) == 1 147 } 148 149 func (h *clientConnectionHandler) setClosed() bool { 150 return atomic.CompareAndSwapInt32(&h.closed, 0, 1) 151 } 152 153 func (h *clientConnectionHandler) close() { 154 if h.setClosed() { 155 log.Trace().Msgf("%v: closing", h) 156 h.connectionsLock.Lock() 157 for clientAddr, holder := range h.connections { 158 delete(h.connections, clientAddr) 159 if err := holder.conn.Close(); err != nil { 160 log.Error().Err(err).Msg(err.Error()) 161 } 162 close(holder.ch) 163 } 164 anyConnChan := h.anyConnChan 165 h.anyConnChan = nil 166 close(anyConnChan) 167 h.connectionsLock.Unlock() 168 log.Trace().Msgf("%v: successfully closed", h) 169 } 170 } 171 172 func (h *clientConnectionHandler) asMapKey(clientAddr net.Addr) (string, error) { 173 if tcpAddr, ok := clientAddr.(*net.TCPAddr); !ok { 174 return "", fmt.Errorf("%v: expected TCP address, got: %v", h, clientAddr) 175 } else { 176 return fmt.Sprintf("%v__%v__%v", string(tcpAddr.IP), tcpAddr.Port, tcpAddr.Zone), nil 177 } 178 }