github.com/cilium/cilium@v1.16.2/pkg/fqdn/dnsproxy/shared_client.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package dnsproxy 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "sync" 13 14 "github.com/cilium/dns" 15 16 "github.com/cilium/cilium/pkg/lock" 17 "github.com/cilium/cilium/pkg/time" 18 ) 19 20 // SharedClients holds a set of SharedClient instances. 21 type SharedClients struct { 22 // SharedClient's lock must not be taken while this lock is held! 23 lock lock.Mutex 24 // clients are created and destroyed on demand, hence 'Mutex' needs to be taken. 25 clients map[string]*SharedClient 26 } 27 28 func NewSharedClients() *SharedClients { 29 return &SharedClients{ 30 clients: make(map[string]*SharedClient), 31 } 32 } 33 34 // sharedClientKey returns an identifier for this five-tuple used to find a shared client. 35 func sharedClientKey(protocol, client, server string) string { 36 return protocol + "-" + client + "-" + server 37 } 38 39 func (s *SharedClients) Exchange(key string, conf *dns.Client, m *dns.Msg, serverAddrStr string) (r *dns.Msg, rtt time.Duration, closer func(), err error) { 40 return s.ExchangeContext(context.Background(), key, conf, m, serverAddrStr) 41 } 42 43 func (s *SharedClients) ExchangeContext(ctx context.Context, key string, conf *dns.Client, m *dns.Msg, serverAddrStr string) (r *dns.Msg, rtt time.Duration, closer func(), err error) { 44 client, closer := s.GetSharedClient(key, conf, serverAddrStr) 45 r, rtt, err = client.ExchangeSharedContext(ctx, m) 46 return r, rtt, closer, err 47 } 48 49 // Called by wrapped TCP connection once the downstream connection breaks/is 50 // closed. Used as a signal to close the upstream connection. 51 func (s *SharedClients) ShutdownTCPClient(key string) { 52 // lock for s.clients access 53 s.lock.Lock() 54 client := s.clients[key] 55 if client == nil { 56 s.lock.Unlock() 57 return 58 } 59 delete(s.clients, key) 60 s.lock.Unlock() 61 62 // The reference counting in the shared client is not authoritative for TCP. 63 // It is increased on acquiring a reference to it, but never decreased, 64 // since we don't want it to hit zero while the downstream connection is 65 // open. Hence we also don't check for it being zero when closing. 66 client.Lock() 67 defer client.Unlock() 68 client.close() 69 } 70 71 // GetSharedClient gets or creates an instance of SharedClient keyed with 'key'. if 'key' is an 72 // empty sting, a new client is always created and it is not actually shared. The returned 'closer' 73 // must be called once the client is no longer needed. Conversely, the returned 'client' must not be 74 // used after the closer is called. 75 func (s *SharedClients) GetSharedClient(key string, conf *dns.Client, serverAddrStr string) (client *SharedClient, closer func()) { 76 if key == "" { 77 // Simplified case when the client is actually not shared 78 client = newSharedClient(conf, serverAddrStr) 79 return client, client.close 80 } 81 for { 82 // lock for s.clients access 83 s.lock.Lock() 84 // locate client to re-use if possible. 85 client = s.clients[key] 86 if client == nil { 87 client = newSharedClient(conf, serverAddrStr) 88 s.clients[key] = client 89 s.lock.Unlock() 90 // new client, we are done 91 break 92 } 93 s.lock.Unlock() 94 95 // reusing client that may start closing while we wait for its lock 96 client.Lock() 97 if client.refcount > 0 { 98 // not closed, add our refcount 99 client.refcount++ 100 client.Unlock() 101 break 102 } 103 // client was closed while we waited for it's lock, discard and try again 104 client.Unlock() 105 client = nil 106 } 107 108 // TCP clients are cleaned up on downstream connection close. 109 if conf.Net == "tcp" { 110 return client, func() {} 111 } 112 113 return client, func() { 114 client.Lock() 115 defer client.Unlock() 116 client.refcount-- 117 if client.refcount == 0 { 118 // connection close must be completed while holding the client's lock to 119 // avoid a race where a new client dials using the same 5-tuple and gets a 120 // bind error. 121 // The client remains findable so that new users with the same key may wait 122 // for this closing to be done with. 123 client.close() 124 // Make client unreachable 125 // Must take s.lock for this. 126 s.lock.Lock() 127 delete(s.clients, key) 128 s.lock.Unlock() 129 } 130 } 131 } 132 133 type request struct { 134 ctx context.Context 135 msg *dns.Msg 136 ch chan sharedClientResponse 137 } 138 139 type sharedClientResponse struct { 140 msg *dns.Msg 141 rtt time.Duration 142 err error 143 } 144 145 // A SharedClient keeps state for concurrent transactions on the same upstream client/connection. 146 type SharedClient struct { 147 serverAddr string 148 149 *dns.Client 150 151 // requests is closed when the client needs to exit 152 requests chan request 153 // wg is waited on for the client finish exiting 154 wg sync.WaitGroup 155 156 lock.Mutex // protects the fields below 157 refcount int 158 conn *dns.Conn 159 } 160 161 func newSharedClient(conf *dns.Client, serverAddr string) *SharedClient { 162 return &SharedClient{ 163 refcount: 1, 164 serverAddr: serverAddr, 165 Client: conf, 166 requests: make(chan request), 167 } 168 } 169 170 // ExchangeShared dials a connection to the server on first invocation, and starts a handler 171 // goroutines to send and receive responses, distributing them to appropriate concurrent caller 172 // based on the DNS message Id. 173 func (c *SharedClient) ExchangeShared(m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) { 174 return c.ExchangeSharedContext(context.Background(), m) 175 } 176 177 // handler is started when the connection is dialed 178 func handler(wg *sync.WaitGroup, client *dns.Client, conn *dns.Conn, requests chan request) { 179 defer wg.Done() 180 181 responses := make(chan sharedClientResponse) 182 183 // receiverTrigger is used to wake up the receive loop after request(s) have been sent. It 184 // must be buffered to be able to send a trigger while the receive loop is not yet ready to 185 // receive the trigger, as we do not want to stall the sender when the receiver is blocking 186 // on the read operation. 187 receiverTrigger := make(chan struct{}, 1) 188 triggerReceiver := func() { 189 select { 190 case receiverTrigger <- struct{}{}: 191 default: 192 } 193 } 194 195 // Receive loop 196 wg.Add(1) 197 go func() { 198 defer wg.Done() 199 defer close(responses) 200 201 // No point trying to receive until the first request has been successfully sent, so 202 // wait for a trigger first. receiverTrigger is buffered, so this is safe 203 // to do, even if the sender sends the trigger before we are ready to receive here. 204 <-receiverTrigger 205 206 for { 207 // This will block but eventually return an i/o timeout, as we always set 208 // the timeouts before sending anything 209 r, err := conn.ReadMsg() 210 if err == nil { 211 responses <- sharedClientResponse{r, 0, nil} 212 continue // receive immediately again 213 } 214 215 // handler is not reading on the channel after closing. 216 // UDP connections return net.ErrClosed, while TCP/TLS connections are read 217 // via the io package, which return io.EOF. 218 if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { 219 return 220 } 221 222 // send error response to cancel all current requests. 223 responses <- sharedClientResponse{nil, 0, err} 224 225 // wait for a trigger from the handler after any errors. Re-reading in 226 // this condition could busy loop, e.g., if a read timeout occurred. 227 // receiverTrigger is buffered so that we catch the trigger that may 228 // have been sent while we sent the error response above. 229 _, ok := <-receiverTrigger 230 if !ok { 231 return // exit immediately when the trigger channel is closed 232 } 233 } 234 }() 235 236 type waiter struct { 237 ch chan sharedClientResponse 238 start time.Time 239 } 240 waitingResponses := make(map[uint16]waiter) 241 defer func() { 242 conn.Close() 243 close(receiverTrigger) 244 245 // Drain responses send by receive loop to allow it to exit. 246 // It may be repeatedly reading after an i/o timeout, for example. 247 for range responses { 248 } 249 250 for _, waiter := range waitingResponses { 251 waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed} 252 close(waiter.ch) 253 } 254 }() 255 256 for { 257 select { 258 case req, ok := <-requests: 259 if !ok { 260 // 'requests' is closed when SharedClient is recycled, which happens 261 // responses (or errors) have been received and there are no more 262 // requests to be sent. 263 return 264 } 265 start := time.Now() 266 267 // Check if we already have a request with the same id 268 // Due to birthday paradox and the fact that ID is uint16 269 // it's likely to happen with small number (~200) of concurrent requests 270 // which would result in goroutine leak as we would never close req.ch 271 if _, duplicate := waitingResponses[req.msg.Id]; duplicate { 272 for n := 0; n < 5; n++ { 273 // Try a new ID 274 id := dns.Id() 275 if _, duplicate = waitingResponses[id]; !duplicate { 276 req.msg.Id = id 277 break 278 } 279 } 280 if duplicate { 281 req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)} 282 close(req.ch) 283 continue 284 } 285 } 286 287 err := client.SendContext(req.ctx, req.msg, conn, start) 288 if err != nil { 289 req.ch <- sharedClientResponse{nil, 0, err} 290 close(req.ch) 291 } else { 292 waitingResponses[req.msg.Id] = waiter{req.ch, start} 293 294 // Wake up the receiver that may be waiting to receive again 295 triggerReceiver() 296 } 297 298 case resp, ok := <-responses: 299 if !ok { 300 // 'responses' is closed when the receive loop exits, so we quit as 301 // nothing can be received any more 302 return 303 } 304 if resp.err != nil { 305 // ReadMsg failed, but we cannot match it to a request, 306 // so complete all pending requests. 307 for _, waiter := range waitingResponses { 308 waiter.ch <- sharedClientResponse{nil, 0, resp.err} 309 close(waiter.ch) 310 } 311 waitingResponses = make(map[uint16]waiter) 312 } else if resp.msg != nil { 313 if waiter, ok := waitingResponses[resp.msg.Id]; ok { 314 delete(waitingResponses, resp.msg.Id) 315 resp.rtt = time.Since(waiter.start) 316 waiter.ch <- resp 317 close(waiter.ch) 318 } 319 } 320 } 321 } 322 } 323 324 func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) { 325 c.Lock() 326 if c.conn == nil { 327 c.conn, err = c.DialContext(ctx, c.serverAddr) 328 if err != nil { 329 c.Unlock() 330 return nil, 0, fmt.Errorf("failed to dial connection to %v: %w", c.serverAddr, err) 331 } 332 // Start handler for sending and receiving. 333 c.wg.Add(1) 334 go handler(&c.wg, c.Client, c.conn, c.requests) 335 } 336 c.Unlock() 337 338 // This request keeps 'c.requests' open; sending a request may hang indefinitely if 339 // the handler happens to quit at the same time. Use ctx.Done to avoid this. 340 timeout := getTimeoutForRequest(c.Client) 341 ctx, cancel := context.WithTimeout(ctx, timeout) 342 defer cancel() 343 respCh := make(chan sharedClientResponse) 344 select { 345 case c.requests <- request{ctx: ctx, msg: m, ch: respCh}: 346 case <-ctx.Done(): 347 return nil, 0, ctx.Err() 348 } 349 350 // Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh' 351 select { 352 case resp := <-respCh: 353 return resp.msg, resp.rtt, resp.err 354 // This is just fail-safe mechanism in case there is another similar issue 355 case <-time.After(time.Minute): 356 return nil, 0, fmt.Errorf("timeout waiting for response") 357 } 358 } 359 360 // close closes and waits for the close to finish. 361 // Must be called while holding client's lock. 362 func (c *SharedClient) close() { 363 close(c.requests) 364 c.wg.Wait() 365 c.conn = nil 366 } 367 368 // Return the appropriate timeout for a specific request 369 func getTimeoutForRequest(c *dns.Client) time.Duration { 370 wtimeout := c.WriteTimeout 371 if wtimeout == 0 { 372 // Some default timeout as seen in miekg/dns. 373 wtimeout = time.Second * 2 374 } 375 376 var requestTimeout time.Duration 377 if c.Timeout != 0 { 378 requestTimeout = c.Timeout 379 } else { 380 requestTimeout = wtimeout 381 } 382 // net.Dialer.Timeout has priority if smaller than the timeouts computed so 383 // far 384 if c.Dialer != nil && c.Dialer.Timeout != 0 { 385 if c.Dialer.Timeout < requestTimeout { 386 requestTimeout = c.Dialer.Timeout 387 } 388 } 389 return requestTimeout 390 }