github.com/cilium/cilium@v1.16.2/pkg/fqdn/dnsproxy/shared_client.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package dnsproxy
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  
    14  	"github.com/cilium/dns"
    15  
    16  	"github.com/cilium/cilium/pkg/lock"
    17  	"github.com/cilium/cilium/pkg/time"
    18  )
    19  
// SharedClients holds a set of SharedClient instances, keyed by a caller-supplied string
// (presumably built with sharedClientKey — confirm at call sites).
type SharedClients struct {
	// SharedClient's lock must not be taken while this lock is held!
	// (The opposite order is allowed: the UDP closer returned by GetSharedClient takes this
	// lock while holding the SharedClient's lock.)
	lock lock.Mutex
	// clients are created and destroyed on demand, hence 'lock' needs to be taken for any
	// access to this map.
	clients map[string]*SharedClient
}
    27  
    28  func NewSharedClients() *SharedClients {
    29  	return &SharedClients{
    30  		clients: make(map[string]*SharedClient),
    31  	}
    32  }
    33  
    34  // sharedClientKey returns an identifier for this five-tuple used to find a shared client.
    35  func sharedClientKey(protocol, client, server string) string {
    36  	return protocol + "-" + client + "-" + server
    37  }
    38  
    39  func (s *SharedClients) Exchange(key string, conf *dns.Client, m *dns.Msg, serverAddrStr string) (r *dns.Msg, rtt time.Duration, closer func(), err error) {
    40  	return s.ExchangeContext(context.Background(), key, conf, m, serverAddrStr)
    41  }
    42  
    43  func (s *SharedClients) ExchangeContext(ctx context.Context, key string, conf *dns.Client, m *dns.Msg, serverAddrStr string) (r *dns.Msg, rtt time.Duration, closer func(), err error) {
    44  	client, closer := s.GetSharedClient(key, conf, serverAddrStr)
    45  	r, rtt, err = client.ExchangeSharedContext(ctx, m)
    46  	return r, rtt, closer, err
    47  }
    48  
    49  // Called by wrapped TCP connection once the downstream connection breaks/is
    50  // closed. Used as a signal to close the upstream connection.
    51  func (s *SharedClients) ShutdownTCPClient(key string) {
    52  	// lock for s.clients access
    53  	s.lock.Lock()
    54  	client := s.clients[key]
    55  	if client == nil {
    56  		s.lock.Unlock()
    57  		return
    58  	}
    59  	delete(s.clients, key)
    60  	s.lock.Unlock()
    61  
    62  	// The reference counting in the shared client is not authoritative for TCP.
    63  	// It is increased on acquiring a reference to it, but never decreased,
    64  	// since we don't want it to hit zero while the downstream connection is
    65  	// open. Hence we also don't check for it being zero when closing.
    66  	client.Lock()
    67  	defer client.Unlock()
    68  	client.close()
    69  }
    70  
// GetSharedClient gets or creates an instance of SharedClient keyed with 'key'. If 'key' is an
// empty string, a new client is always created and it is not actually shared. The returned 'closer'
// must be called once the client is no longer needed. Conversely, the returned 'client' must not be
// used after the closer is called.
//
// When an existing client is reused, 'conf' and 'serverAddrStr' are ignored; they only apply
// when a new client is created.
//
// Closer semantics:
//   - key == "": the closer closes the unshared client directly.
//   - conf.Net == "tcp": the closer is a no-op; the client is closed by ShutdownTCPClient.
//   - otherwise: the closer drops one reference (call it exactly once). When the last reference
//     is released, it closes the client and removes it from 's'.
//
// Lock ordering: s.lock is released before the client's lock is taken here. The closer takes
// s.lock while holding the client's lock.
func (s *SharedClients) GetSharedClient(key string, conf *dns.Client, serverAddrStr string) (client *SharedClient, closer func()) {
	if key == "" {
		// Simplified case when the client is actually not shared
		client = newSharedClient(conf, serverAddrStr)
		return client, client.close
	}
	for {
		// lock for s.clients access
		s.lock.Lock()
		// locate client to re-use if possible.
		client = s.clients[key]
		if client == nil {
			// newSharedClient starts with refcount 1, which is our reference.
			client = newSharedClient(conf, serverAddrStr)
			s.clients[key] = client
			s.lock.Unlock()
			// new client, we are done
			break
		}
		s.lock.Unlock()

		// reusing client that may start closing while we wait for its lock
		client.Lock()
		if client.refcount > 0 {
			// not closed, add our refcount
			client.refcount++
			client.Unlock()
			break
		}
		// client was closed while we waited for its lock, discard and try again;
		// by now the closer has removed it from s.clients.
		client.Unlock()
		client = nil
	}

	// TCP clients are cleaned up on downstream connection close (ShutdownTCPClient),
	// so their refcount is never decremented here.
	if conf.Net == "tcp" {
		return client, func() {}
	}

	return client, func() {
		client.Lock()
		defer client.Unlock()
		client.refcount--
		if client.refcount == 0 {
			// connection close must be completed while holding the client's lock to
			// avoid a race where a new client dials using the same 5-tuple and gets a
			// bind error.
			// The client remains findable so that new users with the same key may wait
			// for this closing to be done with.
			client.close()
			// Make client unreachable
			// Must take s.lock for this.
			s.lock.Lock()
			delete(s.clients, key)
			s.lock.Unlock()
		}
	}
}
   132  
   133  type request struct {
   134  	ctx context.Context
   135  	msg *dns.Msg
   136  	ch  chan sharedClientResponse
   137  }
   138  
   139  type sharedClientResponse struct {
   140  	msg *dns.Msg
   141  	rtt time.Duration
   142  	err error
   143  }
   144  
// A SharedClient keeps state for concurrent transactions on the same upstream client/connection.
//
// Callers hand queries to a single handler goroutine over 'requests'. The handler writes them
// to 'conn' and dispatches replies back to the waiting callers by DNS message Id.
//
// The embedded mutex guards 'refcount' and 'conn':
//   - refcount counts references handed out by GetSharedClient. GetSharedClient treats zero as
//     "closed".
//   - conn is the upstream connection. It is nil until the first exchange dials it, and is reset
//     to nil by close.
type SharedClient struct {
	// serverAddr is the upstream address dialed on first use.
	serverAddr string

	// Client supplies the DNS client configuration used for dialing and sending.
	*dns.Client

	// requests is closed when the client needs to exit
	requests chan request
	// wg is waited on for the client to finish exiting
	wg sync.WaitGroup

	lock.Mutex // protects the fields below
	refcount   int
	conn       *dns.Conn
}
   160  
   161  func newSharedClient(conf *dns.Client, serverAddr string) *SharedClient {
   162  	return &SharedClient{
   163  		refcount:   1,
   164  		serverAddr: serverAddr,
   165  		Client:     conf,
   166  		requests:   make(chan request),
   167  	}
   168  }
   169  
   170  // ExchangeShared dials a connection to the server on first invocation, and starts a handler
   171  // goroutines to send and receive responses, distributing them to appropriate concurrent caller
   172  // based on the DNS message Id.
   173  func (c *SharedClient) ExchangeShared(m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
   174  	return c.ExchangeSharedContext(context.Background(), m)
   175  }
   176  
// handler is started when the connection is dialed, and owns 'conn' from then on.
//
// The loop at the bottom takes requests from 'requests', writes them to 'conn' with
// client.SendContext, and records each caller's channel keyed by DNS message Id. A separate
// receive goroutine reads messages from 'conn' and passes them back over 'responses', where
// they are matched to the waiting caller by Id.
//
// Every request's channel receives exactly one sharedClientResponse and is then closed.
// handler returns when 'requests' is closed or when the receive goroutine exits. On return it
// closes 'conn', stops the receive goroutine, and fails still-pending requests with
// net.ErrClosed. wg.Done is called both for handler itself and for the receive goroutine, so
// waiting on wg waits for both.
func handler(wg *sync.WaitGroup, client *dns.Client, conn *dns.Conn, requests chan request) {
	defer wg.Done()

	responses := make(chan sharedClientResponse)

	// receiverTrigger is used to wake up the receive loop after request(s) have been sent. It
	// must be buffered to be able to send a trigger while the receive loop is not yet ready to
	// receive the trigger, as we do not want to stall the sender when the receiver is blocking
	// on the read operation.
	receiverTrigger := make(chan struct{}, 1)
	// triggerReceiver never blocks: if a trigger is already pending, another one is redundant.
	triggerReceiver := func() {
		select {
		case receiverTrigger <- struct{}{}:
		default:
		}
	}

	// Receive loop
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer close(responses)

		// No point trying to receive until the first request has been successfully sent, so
		// wait for a trigger first. receiverTrigger is buffered, so this is safe
		// to do, even if the sender sends the trigger before we are ready to receive here.
		// If the handler exits first, closing receiverTrigger also releases this wait.
		<-receiverTrigger

		for {
			// This will block but eventually return an i/o timeout, as we always set
			// the timeouts before sending anything
			r, err := conn.ReadMsg()
			if err == nil {
				responses <- sharedClientResponse{r, 0, nil}
				continue // receive immediately again
			}

			// The connection was closed by the handler, which stops reading 'responses'
			// (apart from draining) after that, so exit.
			// UDP connections return net.ErrClosed, while TCP/TLS connections are read
			// via the io package, which return io.EOF.
			if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
				return
			}

			// send error response to cancel all current requests.
			responses <- sharedClientResponse{nil, 0, err}

			// wait for a trigger from the handler after any errors. Re-reading in
			// this condition could busy loop, e.g., if a read timeout occurred.
			// receiverTrigger is buffered so that we catch the trigger that may
			// have been sent while we sent the error response above.
			_, ok := <-receiverTrigger
			if !ok {
				return // exit immediately when the trigger channel is closed
			}
		}
	}()

	// waiter is a caller waiting for the response to a query that has been sent.
	type waiter struct {
		ch    chan sharedClientResponse
		start time.Time
	}
	waitingResponses := make(map[uint16]waiter)
	defer func() {
		// Closing conn unblocks a pending ReadMsg; closing receiverTrigger unblocks a
		// receive loop waiting for a trigger.
		conn.Close()
		close(receiverTrigger)

		// Drain responses sent by the receive loop to allow it to exit.
		// It may be repeatedly reading after an i/o timeout, for example.
		for range responses {
		}

		// Nothing more can be received: fail every request still waiting for a response.
		for _, waiter := range waitingResponses {
			waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
			close(waiter.ch)
		}
	}()

	for {
		select {
		case req, ok := <-requests:
			if !ok {
				// 'requests' is closed when SharedClient is recycled, which happens once
				// responses (or errors) have been received and there are no more
				// requests to be sent.
				return
			}
			start := time.Now()

			// Check if we already have a request with the same id
			// Due to birthday paradox and the fact that ID is uint16
			// it's likely to happen with small number (~200) of concurrent requests
			// which would result in goroutine leak as we would never close req.ch
			if _, duplicate := waitingResponses[req.msg.Id]; duplicate {
				for n := 0; n < 5; n++ {
					// Try a new ID
					id := dns.Id()
					if _, duplicate = waitingResponses[id]; !duplicate {
						req.msg.Id = id
						break
					}
				}
				if duplicate {
					req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)}
					close(req.ch)
					continue
				}
			}

			err := client.SendContext(req.ctx, req.msg, conn, start)
			if err != nil {
				req.ch <- sharedClientResponse{nil, 0, err}
				close(req.ch)
			} else {
				waitingResponses[req.msg.Id] = waiter{req.ch, start}

				// Wake up the receiver that may be waiting to receive again
				triggerReceiver()
			}

		case resp, ok := <-responses:
			if !ok {
				// 'responses' is closed when the receive loop exits, so we quit as
				// nothing can be received any more
				return
			}
			if resp.err != nil {
				// ReadMsg failed, but we cannot match it to a request,
				// so complete all pending requests.
				for _, waiter := range waitingResponses {
					waiter.ch <- sharedClientResponse{nil, 0, resp.err}
					close(waiter.ch)
				}
				waitingResponses = make(map[uint16]waiter)
			} else if resp.msg != nil {
				// Responses whose Id matches no pending request are dropped.
				if waiter, ok := waitingResponses[resp.msg.Id]; ok {
					delete(waitingResponses, resp.msg.Id)
					resp.rtt = time.Since(waiter.start)
					waiter.ch <- resp
					close(waiter.ch)
				}
			}
		}
	}
}
   323  
   324  func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
   325  	c.Lock()
   326  	if c.conn == nil {
   327  		c.conn, err = c.DialContext(ctx, c.serverAddr)
   328  		if err != nil {
   329  			c.Unlock()
   330  			return nil, 0, fmt.Errorf("failed to dial connection to %v: %w", c.serverAddr, err)
   331  		}
   332  		// Start handler for sending and receiving.
   333  		c.wg.Add(1)
   334  		go handler(&c.wg, c.Client, c.conn, c.requests)
   335  	}
   336  	c.Unlock()
   337  
   338  	// This request keeps 'c.requests' open; sending a request may hang indefinitely if
   339  	// the handler happens to quit at the same time. Use ctx.Done to avoid this.
   340  	timeout := getTimeoutForRequest(c.Client)
   341  	ctx, cancel := context.WithTimeout(ctx, timeout)
   342  	defer cancel()
   343  	respCh := make(chan sharedClientResponse)
   344  	select {
   345  	case c.requests <- request{ctx: ctx, msg: m, ch: respCh}:
   346  	case <-ctx.Done():
   347  		return nil, 0, ctx.Err()
   348  	}
   349  
   350  	// Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh'
   351  	select {
   352  	case resp := <-respCh:
   353  		return resp.msg, resp.rtt, resp.err
   354  	// This is just fail-safe mechanism in case there is another similar issue
   355  	case <-time.After(time.Minute):
   356  		return nil, 0, fmt.Errorf("timeout waiting for response")
   357  	}
   358  }
   359  
// close closes and waits for the close to finish.
// Must be called while holding client's lock.
//
// Closing 'requests' makes the handler goroutine (if one was started) return. On return it
// closes the connection and fails any pending requests. wg.Wait then blocks until the handler
// and its receive goroutine have both exited. Resetting conn afterwards drops the closed
// connection.
//
// close must be called at most once per client: a second call would panic when closing the
// already closed 'requests' channel.
func (c *SharedClient) close() {
	close(c.requests)
	c.wg.Wait()
	c.conn = nil
}
   367  
   368  // Return the appropriate timeout for a specific request
   369  func getTimeoutForRequest(c *dns.Client) time.Duration {
   370  	wtimeout := c.WriteTimeout
   371  	if wtimeout == 0 {
   372  		// Some default timeout as seen in miekg/dns.
   373  		wtimeout = time.Second * 2
   374  	}
   375  
   376  	var requestTimeout time.Duration
   377  	if c.Timeout != 0 {
   378  		requestTimeout = c.Timeout
   379  	} else {
   380  		requestTimeout = wtimeout
   381  	}
   382  	// net.Dialer.Timeout has priority if smaller than the timeouts computed so
   383  	// far
   384  	if c.Dialer != nil && c.Dialer.Timeout != 0 {
   385  		if c.Dialer.Timeout < requestTimeout {
   386  			requestTimeout = c.Dialer.Timeout
   387  		}
   388  	}
   389  	return requestTimeout
   390  }