github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/agentpf/clients.go (about)

     1  package agentpf
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/puzpuzpuz/xsync/v3"
    14  	"go.opentelemetry.io/otel"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/status"
    18  
    19  	"github.com/datawire/dlib/dlog"
    20  	"github.com/datawire/dlib/dtime"
    21  	"github.com/telepresenceio/telepresence/rpc/v2/agent"
    22  	"github.com/telepresenceio/telepresence/rpc/v2/manager"
    23  	"github.com/telepresenceio/telepresence/v2/pkg/client/k8sclient"
    24  	"github.com/telepresenceio/telepresence/v2/pkg/dnet"
    25  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    26  	"github.com/telepresenceio/telepresence/v2/pkg/tunnel"
    27  )
    28  
// client holds the connection state for one traffic-agent pod.
type client struct {
	// Mutex protects the following fields (the rest is immutable)
	//   info.Intercepted
	//   cli
	//   cancelClient
	//   cancelDialWatch
	// cli and cancelClient are both safe to use without a mutex once the ready channel is closed.
	sync.Mutex

	// cli is the gRPC client for the agent. It is nil until connect has succeeded.
	cli agent.AgentClient

	// session is passed to the agent when watching dials.
	session *manager.SessionInfo

	// info describes the agent pod (name, namespace, IP, API port, intercepted state).
	info *manager.AgentPodInfo

	// ready receives an error when connect fails and is closed when connect returns.
	ready chan error

	// cancelClient closes the connection to the agent. It is set by connect.
	cancelClient context.CancelFunc

	// cancelDialWatch stops the dial watcher and clears info.Intercepted. It is nil
	// when no dial watcher is running.
	cancelDialWatch context.CancelFunc

	// tunnelCount is the number of tunnels created by Tunnel whose context isn't done yet.
	// It is modified using atomic operations.
	tunnelCount int32
}
    45  
    46  func (ac *client) String() string {
    47  	if ac == nil {
    48  		return "<nil>"
    49  	}
    50  	ai := ac.info
    51  	return fmt.Sprintf("%s.%s:%d", ai.PodName, ai.Namespace, ai.ApiPort)
    52  }
    53  
    54  func (ac *client) Tunnel(ctx context.Context, opts ...grpc.CallOption) (tunnel.Client, error) {
    55  	select {
    56  	case err, ok := <-ac.ready:
    57  		if ok {
    58  			return nil, err
    59  		}
    60  		// ready channel is closed. We are ready to go.
    61  	case <-ctx.Done():
    62  		return nil, ctx.Err()
    63  	}
    64  	tc, err := ac.cli.Tunnel(ctx, opts...)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	atomic.AddInt32(&ac.tunnelCount, 1)
    69  	dlog.Tracef(ctx, "%s(%s) have %d active tunnels", ac, net.IP(ac.info.PodIp), ac.tunnelCount)
    70  	go func() {
    71  		<-ctx.Done()
    72  		atomic.AddInt32(&ac.tunnelCount, -1)
    73  		dlog.Tracef(ctx, "%s(%s) have %d active tunnels", ac, net.IP(ac.info.PodIp), ac.tunnelCount)
    74  	}()
    75  	return tc, nil
    76  }
    77  
    78  func (ac *client) connect(ctx context.Context, deleteMe func()) {
    79  	defer close(ac.ready)
    80  	pfDialer := dnet.GetPortForwardDialer(ctx)
    81  	if pfDialer == nil {
    82  		return
    83  	}
    84  
    85  	dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
    86  	defer dialCancel()
    87  
    88  	conn, cli, _, err := k8sclient.ConnectToAgent(dialCtx, pfDialer.Dial, ac.info.PodName, ac.info.Namespace, uint16(ac.info.ApiPort))
    89  	if err != nil {
    90  		deleteMe()
    91  		ac.ready <- err
    92  		return
    93  	}
    94  
    95  	ac.Lock()
    96  	ac.cli = cli
    97  	ac.cancelClient = func() {
    98  		conn.Close()
    99  	}
   100  	intercepted := ac.info.Intercepted
   101  	ac.Unlock()
   102  	if intercepted {
   103  		if err = ac.startDialWatcherReady(ctx); err != nil {
   104  			deleteMe()
   105  			ac.ready <- err
   106  		}
   107  	}
   108  }
   109  
   110  func (ac *client) busy() bool {
   111  	ac.Lock()
   112  	bzy := ac.cli == nil || ac.info.Intercepted || atomic.LoadInt32(&ac.tunnelCount) > 0
   113  	ac.Unlock()
   114  	return bzy
   115  }
   116  
   117  func (ac *client) intercepted() bool {
   118  	ac.Lock()
   119  	ret := ac.info.Intercepted
   120  	ac.Unlock()
   121  	return ret
   122  }
   123  
   124  func (ac *client) cancel() {
   125  	ac.Lock()
   126  	cc := ac.cancelClient
   127  	cdw := ac.cancelDialWatch
   128  	ac.Unlock()
   129  	if cc != nil {
   130  		cc()
   131  	}
   132  	if cdw != nil {
   133  		cdw()
   134  	}
   135  }
   136  
   137  func (ac *client) setIntercepted(ctx context.Context, k string, status bool) {
   138  	ac.Lock()
   139  	aci := ac.info.Intercepted
   140  	ac.Unlock()
   141  	if status {
   142  		if aci {
   143  			return
   144  		}
   145  		dlog.Debugf(ctx, "Agent %s changed to intercepted", k)
   146  		if err := ac.startDialWatcher(ctx); err != nil {
   147  			dlog.Errorf(ctx, "failed to start client watcher for %s: %v", k, err)
   148  		}
   149  		// This agent is now intercepting. Start a dial watcher.
   150  	} else {
   151  		if !aci {
   152  			return
   153  		}
   154  
   155  		// This agent is no longer intercepting. Stop the dial watcher
   156  		dlog.Debugf(ctx, "Agent %s changed to not intercepted", k)
   157  		ac.Lock()
   158  		cdw := ac.cancelDialWatch
   159  		ac.Unlock()
   160  		if cdw != nil {
   161  			cdw()
   162  		}
   163  	}
   164  }
   165  
   166  func (ac *client) startDialWatcher(ctx context.Context) error {
   167  	// Not called from the startup go routine, so wait for that routine to finish
   168  	select {
   169  	case err, ok := <-ac.ready:
   170  		if ok {
   171  			return err
   172  		}
   173  		// ready channel is closed. We are ready to go.
   174  	case <-ctx.Done():
   175  		return ctx.Err()
   176  	}
   177  	return ac.startDialWatcherReady(ctx)
   178  }
   179  
   180  func (ac *client) startDialWatcherReady(ctx context.Context) error {
   181  	ctx, cancel := context.WithCancel(ctx)
   182  
   183  	// Create the dial watcher
   184  	dlog.Debugf(ctx, "watching dials from agent pod %s", ac)
   185  	watcher, err := ac.cli.WatchDial(ctx, ac.session)
   186  	if err != nil {
   187  		cancel()
   188  		return err
   189  	}
   190  
   191  	ac.Lock()
   192  	ac.info.Intercepted = true
   193  	ac.cancelDialWatch = func() {
   194  		ac.Lock()
   195  		ac.info.Intercepted = false
   196  		ac.cancelDialWatch = nil
   197  		ac.Unlock()
   198  		cancel()
   199  	}
   200  	ac.Unlock()
   201  
   202  	go func() {
   203  		err := tunnel.DialWaitLoop(ctx, tunnel.AgentProvider(ac.cli), watcher, ac.session.SessionId)
   204  		if err != nil {
   205  			dlog.Error(ctx, err)
   206  		}
   207  	}()
   208  	return nil
   209  }
   210  
// Clients manages the set of connections to traffic-agents.
type Clients interface {
	// GetClient returns a tunnel.Provider for the given IP, or nil if no agent is known.
	GetClient(net.IP) tunnel.Provider

	// WatchAgentPods watches the agent pods reported by the traffic-manager and keeps
	// the set of clients up to date until the watch ends.
	WatchAgentPods(ctx context.Context, rmc manager.ManagerClient) error

	// WaitForIP waits, at most for the given timeout, until an agent with the given
	// pod IP is known.
	WaitForIP(ctx context.Context, timeout time.Duration, ip net.IP) error

	// WaitForWorkload waits, at most for the given timeout, until an agent that belongs
	// to the named workload is known.
	WaitForWorkload(ctx context.Context, timeout time.Duration, name string) error

	// GetWorkloadClient returns a tunnel.Provider for an agent that belongs to the given
	// workload, or nil if there is no such agent.
	GetWorkloadClient(workload string) (ag tunnel.Provider)

	// SetProxyVia registers the workload as a proxy-via workload. Clients for agents that
	// belong to such a workload are created, and are not terminated for being idle.
	SetProxyVia(workload string)
}
   219  
// clients is the implementation of the Clients interface.
type clients struct {
	// session is the session used when calling the traffic-manager and the agents.
	session *manager.SessionInfo

	// clients maps "<pod name>.<namespace>" to the client for that agent pod.
	clients *xsync.MapOf[string, *client]

	// ipWaiters holds channels that are closed when an agent with the given pod IP has a client.
	ipWaiters *xsync.MapOf[iputil.IPKey, chan struct{}]

	// wlWaiters holds channels that are closed when an agent for the given workload has a client.
	wlWaiters *xsync.MapOf[string, chan struct{}]

	// proxyVias is the set of workload names registered using SetProxyVia.
	proxyVias *xsync.MapOf[string, struct{}]

	// disabled is set when WatchAgentPods ends, or when the agents don't report a pod name.
	// WaitForIP and WaitForWorkload return immediately when it is set.
	disabled atomic.Bool
}
   228  
   229  func NewClients(session *manager.SessionInfo) Clients {
   230  	return &clients{
   231  		session:   session,
   232  		clients:   xsync.NewMapOf[string, *client](),
   233  		ipWaiters: xsync.NewMapOf[iputil.IPKey, chan struct{}](),
   234  		wlWaiters: xsync.NewMapOf[string, chan struct{}](),
   235  		proxyVias: xsync.NewMapOf[string, struct{}](),
   236  	}
   237  }
   238  
   239  // GetClient returns tunnel.Provider that opens a tunnel to a known traffic-agent.
   240  // The traffic-agent is chosen using the following rules in the order mentioned:
   241  //
   242  //  1. agent has a pod_ip that matches the given ip
   243  //  2. agent is currently intercepted by this client
   244  //  3. any agent
   245  //
   246  // The function returns nil when there are no agents in the connected namespace.
   247  func (s *clients) GetClient(ip net.IP) (pvd tunnel.Provider) {
   248  	var primary, secondary, ternary tunnel.Provider
   249  	s.clients.Range(func(_ string, c *client) bool {
   250  		switch {
   251  		case ip.Equal(c.info.PodIp):
   252  			primary = c
   253  		case c.intercepted():
   254  			secondary = c
   255  		default:
   256  			ternary = c
   257  		}
   258  		return primary == nil
   259  	})
   260  	switch {
   261  	case primary != nil:
   262  		pvd = primary
   263  	case secondary != nil:
   264  		pvd = secondary
   265  	default:
   266  		pvd = ternary
   267  	}
   268  	return pvd
   269  }
   270  
   271  // GetWorkloadClient returns tunnel.Provider that opens a tunnel to a traffic-agent that
   272  // belongs to a pod created for the given workload.
   273  //
   274  // The function returns nil when there are no agents for the given workload in the connected namespace.
   275  func (s *clients) GetWorkloadClient(workload string) (pvd tunnel.Provider) {
   276  	s.clients.Range(func(_ string, ac *client) bool {
   277  		if ac.info.WorkloadName == workload {
   278  			pvd = ac
   279  			return false
   280  		}
   281  		return true
   282  	})
   283  	return
   284  }
   285  
   286  func (s *clients) SetProxyVia(workload string) {
   287  	s.proxyVias.Store(workload, struct{}{})
   288  }
   289  
   290  func (s *clients) isProxyVIA(info *manager.AgentPodInfo) bool {
   291  	_, isPV := s.proxyVias.Load(info.WorkloadName)
   292  	return isPV
   293  }
   294  
   295  func (s *clients) hasWaiterFor(info *manager.AgentPodInfo) bool {
   296  	if _, isW := s.ipWaiters.Load(iputil.IPKey(info.PodIp)); isW {
   297  		return true
   298  	}
   299  	if _, isW := s.wlWaiters.Load(info.WorkloadName); isW {
   300  		return true
   301  	}
   302  	return false
   303  }
   304  
   305  func (s *clients) WatchAgentPods(ctx context.Context, rmc manager.ManagerClient) error {
   306  	dlog.Debug(ctx, "WatchAgentPods starting")
   307  	defer func() {
   308  		dlog.Debugf(ctx, "WatchAgentPods ending with %d clients still active", s.clients.Size())
   309  		s.clients.Range(func(_ string, ac *client) bool {
   310  			ac.cancel()
   311  			return true
   312  		})
   313  		s.disabled.Store(true)
   314  	}()
   315  	backoff := 100 * time.Millisecond
   316  
   317  outer:
   318  	for ctx.Err() == nil {
   319  		as, err := rmc.WatchAgentPods(ctx, s.session)
   320  		switch status.Code(err) {
   321  		case codes.OK:
   322  		case codes.Unavailable:
   323  			dtime.SleepWithContext(ctx, backoff)
   324  			backoff *= 2
   325  			if backoff > 15*time.Second {
   326  				backoff = 15 * time.Second
   327  			}
   328  			continue outer
   329  		case codes.Unimplemented:
   330  			dlog.Debug(ctx, "traffic-manager does not implement WatchAgentPods")
   331  			return nil
   332  		default:
   333  			err = fmt.Errorf("error when calling WatchAgents: %w", err)
   334  			dlog.Warn(ctx, err)
   335  			return err
   336  		}
   337  
   338  		for ctx.Err() == nil {
   339  			ais, err := as.Recv()
   340  			if errors.Is(err, io.EOF) {
   341  				return nil
   342  			}
   343  			switch status.Code(err) {
   344  			case codes.OK:
   345  				ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "AgentClientUpdate")
   346  				err = s.updateClients(ctx, ais.Agents)
   347  				span.End()
   348  				if err != nil {
   349  					return err
   350  				}
   351  			case codes.Unavailable:
   352  				dtime.SleepWithContext(ctx, backoff)
   353  				backoff *= 2
   354  				if backoff > 15*time.Second {
   355  					backoff = 15 * time.Second
   356  				}
   357  				continue outer
   358  			case codes.Unimplemented:
   359  				dlog.Debug(ctx, "traffic-manager does not implement WatchAgentPods")
   360  				return nil
   361  			default:
   362  				return err
   363  			}
   364  		}
   365  	}
   366  	return nil
   367  }
   368  
   369  func (s *clients) notifyWaiters() {
   370  	s.clients.Range(func(name string, ac *client) bool {
   371  		if waiter, ok := s.ipWaiters.LoadAndDelete(iputil.IPKey(ac.info.PodIp)); ok {
   372  			close(waiter)
   373  		}
   374  		if waiter, ok := s.wlWaiters.LoadAndDelete(ac.info.WorkloadName); ok {
   375  			close(waiter)
   376  		}
   377  		return true
   378  	})
   379  }
   380  
   381  func (s *clients) waitWithTimeout(ctx context.Context, timeout time.Duration, waitOn <-chan struct{}) error {
   382  	s.notifyWaiters()
   383  	ctx, cancel := context.WithTimeout(ctx, timeout)
   384  	defer cancel()
   385  	select {
   386  	case <-waitOn:
   387  		return nil
   388  	case <-ctx.Done():
   389  		return ctx.Err()
   390  	}
   391  }
   392  
   393  func (s *clients) WaitForIP(ctx context.Context, timeout time.Duration, ip net.IP) error {
   394  	if s.disabled.Load() {
   395  		return nil
   396  	}
   397  	waitOn, ok := s.ipWaiters.Compute(iputil.IPKey(ip), func(oldValue chan struct{}, loaded bool) (chan struct{}, bool) {
   398  		if loaded {
   399  			return oldValue, false
   400  		}
   401  		found := false
   402  		s.clients.Range(func(k string, ac *client) bool {
   403  			if ip.Equal(ac.info.PodIp) {
   404  				found = true
   405  				return false
   406  			}
   407  			return true
   408  		})
   409  		if found {
   410  			return nil, true
   411  		}
   412  		return make(chan struct{}), false
   413  	})
   414  	if ok {
   415  		return s.waitWithTimeout(ctx, timeout, waitOn)
   416  	}
   417  	// No chan created because the agent already exists
   418  	return nil
   419  }
   420  
   421  func (s *clients) WaitForWorkload(ctx context.Context, timeout time.Duration, name string) error {
   422  	if s.disabled.Load() {
   423  		return nil
   424  	}
   425  
   426  	// Create a channel to subscribe to, but only if the agent doesn't already exist.
   427  	waitOn, ok := s.wlWaiters.Compute(name, func(oldValue chan struct{}, loaded bool) (chan struct{}, bool) {
   428  		if loaded {
   429  			return oldValue, false
   430  		}
   431  		found := false
   432  		s.clients.Range(func(k string, ac *client) bool {
   433  			if ac.info.WorkloadName == name {
   434  				found = true
   435  				return false
   436  			}
   437  			return true
   438  		})
   439  		if found {
   440  			return nil, true
   441  		}
   442  		return make(chan struct{}), false
   443  	})
   444  	if ok {
   445  		return s.waitWithTimeout(ctx, timeout, waitOn)
   446  	}
   447  	// No chan created because the agent already exists
   448  	return nil
   449  }
   450  
   451  func (s *clients) updateClients(ctx context.Context, ais []*manager.AgentPodInfo) error {
   452  	defer s.notifyWaiters()
   453  
   454  	var aim map[string]*manager.AgentPodInfo
   455  	if len(ais) > 0 {
   456  		aim = make(map[string]*manager.AgentPodInfo, len(ais))
   457  		for _, ai := range ais {
   458  			if ai.PodName != "" {
   459  				aim[ai.PodName+"."+ai.Namespace] = ai
   460  			}
   461  		}
   462  		if len(aim) == 0 {
   463  			// The current traffic-manager injects old style clients that doesn't report a pod name.
   464  			dlog.Debugf(ctx, "disabling, because traffic-agent doesn't report pod name")
   465  			s.disabled.Store(true)
   466  			return nil
   467  		}
   468  	}
   469  
   470  	deleteClient := func(k string) {
   471  		s.clients.Compute(k, func(oldValue *client, loaded bool) (*client, bool) {
   472  			if loaded {
   473  				dlog.Debugf(ctx, "Deleting agent %s", k)
   474  				oldValue.cancel()
   475  			}
   476  			return nil, true
   477  		})
   478  	}
   479  
   480  	// Cancel clients that no longer exist.
   481  	s.clients.Range(func(k string, _ *client) bool {
   482  		if _, ok := aim[k]; !ok {
   483  			deleteClient(k)
   484  		}
   485  		return true
   486  	})
   487  
   488  	// Refresh current clients
   489  	for k, ai := range aim {
   490  		if ac, ok := s.clients.Load(k); ok {
   491  			ac.setIntercepted(ctx, k, ai.Intercepted)
   492  		}
   493  	}
   494  
   495  	addClient := func(k string, ai *manager.AgentPodInfo) {
   496  		_, _ = s.clients.Compute(k, func(oldValue *client, loaded bool) (*client, bool) {
   497  			if loaded {
   498  				return oldValue, false
   499  			}
   500  			ac := &client{
   501  				ready:   make(chan error),
   502  				session: s.session,
   503  				info:    ai,
   504  			}
   505  			go ac.connect(ctx, func() {
   506  				deleteClient(k)
   507  			})
   508  			return ac, false
   509  		})
   510  	}
   511  
   512  	// Add clients for newly arrived intercepts
   513  	for k, ai := range aim {
   514  		if ai.Intercepted || s.isProxyVIA(ai) || s.hasWaiterFor(ai) {
   515  			addClient(k, ai)
   516  		}
   517  	}
   518  
   519  	// Terminate all non-intercepting idle agents except the last one.
   520  	s.clients.Range(func(k string, ac *client) bool {
   521  		if s.clients.Size() <= 1 {
   522  			return false
   523  		}
   524  		if !ac.busy() && !s.isProxyVIA(ac.info) && !s.hasWaiterFor(ac.info) {
   525  			deleteClient(k)
   526  		}
   527  		return true
   528  	})
   529  
   530  	// Ensure that we have at least one client (if at least one agent exists)
   531  	if s.clients.Size() == 0 && len(aim) > 0 {
   532  		var ai *manager.AgentPodInfo
   533  		for _, ai = range aim {
   534  			break
   535  		}
   536  		k := ai.PodName + "." + ai.Namespace
   537  		addClient(k, ai)
   538  	}
   539  	return nil
   540  }