github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/rootd/dns/connpool_linux.go (about)

     1  package dns
     2  
     3  import (
     4  	"container/heap"
     5  	"context"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/miekg/dns"
    11  )
    12  
// ConnPool is a fixed set of UDP connections to a single DNS server that are
// shared between callers of Exchange. A single coordinator goroutine (see
// coordinate) decides which waiting caller receives which idle connection.
type ConnPool struct {
	// items holds every pooled connection. The value is true while the
	// connection is handed out to a caller and false while it is idle.
	// After NewConnPool returns, only coordinate changes the values.
	items      map[*dns.Conn]bool
	// newArrival carries callers waiting for a connection from
	// getConnection to coordinate. It is unbuffered.
	newArrival chan *waitingClient
	// finished carries connections handed back by releaseConnection to
	// coordinate, which marks them idle again. It is unbuffered.
	finished   chan *dns.Conn
	// clients is the queue of callers still waiting for a connection,
	// maintained with container/heap by coordinate.
	clients    clientQueue
	// cancel stops the coordinate goroutine; it is called by Close.
	cancel     context.CancelFunc
	// remoteAddr is the DNS server host exactly as passed to NewConnPool,
	// without a port (":53" is appended only when dialing).
	remoteAddr string
}
    21  
    22  func NewConnPool(addr string, poolSize int) (*ConnPool, error) {
    23  	cCtx, cCancel := context.WithCancel(context.Background())
    24  	pool := &ConnPool{
    25  		items:      make(map[*dns.Conn]bool, poolSize),
    26  		newArrival: make(chan *waitingClient),
    27  		finished:   make(chan *dns.Conn),
    28  		cancel:     cCancel,
    29  		remoteAddr: addr,
    30  	}
    31  	heap.Init(&pool.clients)
    32  	for i := 0; i < poolSize; i++ {
    33  		conn, err := dns.Dial("udp", net.JoinHostPort(addr, "53"))
    34  		if err != nil {
    35  			return nil, fmt.Errorf("unable to create DNS conn to %s: %w", addr, err)
    36  		}
    37  		pool.items[conn] = false
    38  	}
    39  	go pool.coordinate(cCtx)
    40  	return pool, nil
    41  }
    42  
    43  func (cp *ConnPool) LocalAddrs() []*net.UDPAddr {
    44  	retval := make([]*net.UDPAddr, len(cp.items))
    45  	i := 0
    46  	for conn := range cp.items {
    47  		retval[i] = conn.LocalAddr().(*net.UDPAddr)
    48  		i++
    49  	}
    50  	return retval
    51  }
    52  
    53  func (cp *ConnPool) RemoteAddr() string {
    54  	return cp.remoteAddr
    55  }
    56  
    57  func (cp *ConnPool) Exchange(ctx context.Context, client *dns.Client, msg *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
    58  	conn, err := cp.getConnection(ctx)
    59  	if err != nil {
    60  		return nil, time.Duration(0), err
    61  	}
    62  	defer cp.releaseConnection(conn)
    63  	return client.ExchangeWithConn(msg, conn)
    64  }
    65  
    66  func (cp *ConnPool) Close() {
    67  	cp.cancel()
    68  	for conn := range cp.items {
    69  		conn.Close()
    70  	}
    71  }
    72  
    73  func (cp *ConnPool) coordinate(ctx context.Context) {
    74  	for {
    75  		select {
    76  		case <-ctx.Done():
    77  			return
    78  		case client := <-cp.newArrival:
    79  			heap.Push(&cp.clients, client)
    80  		case conn := <-cp.finished:
    81  			cp.items[conn] = false
    82  		}
    83  		for conn, inUse := range cp.items {
    84  			if !inUse && len(cp.clients) > 0 {
    85  				cp.items[conn] = true
    86  				client := heap.Pop(&cp.clients).(*waitingClient)
    87  				select {
    88  				case client.returnCh <- conn:
    89  				case <-client.doneCh:
    90  					cp.items[conn] = false
    91  				case <-ctx.Done():
    92  					return
    93  				}
    94  			}
    95  		}
    96  	}
    97  }
    98  
    99  func (cp *ConnPool) getConnection(ctx context.Context) (*dns.Conn, error) {
   100  	client := &waitingClient{
   101  		arrivalTime: time.Now(),
   102  		returnCh:    make(chan *dns.Conn),
   103  		doneCh:      ctx.Done(),
   104  	}
   105  	select {
   106  	case cp.newArrival <- client:
   107  	case <-ctx.Done():
   108  		return nil, ctx.Err()
   109  	}
   110  	select {
   111  	case conn := <-client.returnCh:
   112  		return conn, nil
   113  	case <-ctx.Done():
   114  		return nil, ctx.Err()
   115  	}
   116  }
   117  
   118  func (cp *ConnPool) releaseConnection(conn *dns.Conn) {
   119  	cp.finished <- conn
   120  }