github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/userd/trafficmgr/workloads.go (about)

     1  package trafficmgr
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"sync"
     8  	"time"
     9  
    10  	apps "k8s.io/api/apps/v1"
    11  	core "k8s.io/api/core/v1"
    12  	meta "k8s.io/apimachinery/pkg/apis/meta/v1"
    13  	"k8s.io/apimachinery/pkg/labels"
    14  	"k8s.io/apimachinery/pkg/runtime"
    15  	"k8s.io/apimachinery/pkg/util/intstr"
    16  	"k8s.io/client-go/tools/cache"
    17  
    18  	"github.com/datawire/dlib/dlog"
    19  	"github.com/datawire/k8sapi/pkg/k8sapi"
    20  )
    21  
// workloadsAndServicesWatcher keeps one namespacedWASWatcher per watched namespace.
//
// The embedded Mutex guards nsWatchers and is also installed as the Locker of cond
// by newWASWatcher. The nsListeners callbacks are registered as state listeners on
// the service watcher of every namespace watcher that gets added, and a pointer to
// cond is handed to every watcher that is created for a namespace.
type workloadsAndServicesWatcher struct {
	sync.Mutex
	nsWatchers  map[string]*namespacedWASWatcher
	nsListeners []func()
	cond        sync.Cond
}
    28  
    29  const (
    30  	deployments  = 0
    31  	replicasets  = 1
    32  	statefulsets = 2
    33  )
    34  
// namespacedWASWatcher watches Workloads And Services (WAS) for a namespace.
//
// svcWatcher watches the Services of the namespace. wlWatchers holds one watcher
// per workload kind and is indexed by the deployments, replicasets, and
// statefulsets constants.
type namespacedWASWatcher struct {
	svcWatcher *k8sapi.Watcher[*core.Service]
	wlWatchers [3]*k8sapi.Watcher[runtime.Object]
}
    40  
    41  // svcEquals compare only the Service fields that are of interest to Telepresence. They are
    42  //
    43  //   - UID
    44  //   - Name
    45  //   - Namespace
    46  //   - Spec.Ports
    47  //   - Spec.Type
    48  func svcEquals(a, b *core.Service) bool {
    49  	aPorts := a.Spec.Ports
    50  	bPorts := b.Spec.Ports
    51  	if len(aPorts) != len(bPorts) {
    52  		return false
    53  	}
    54  	if a.UID != b.UID || a.Name != b.Name || a.Namespace != b.Namespace || a.Spec.Type != b.Spec.Type {
    55  		return false
    56  	}
    57  nextMP:
    58  	// order is not significant (nor can it be trusted) when comparing
    59  	for _, mp := range aPorts {
    60  		for _, op := range bPorts {
    61  			if mp == op {
    62  				continue nextMP
    63  			}
    64  		}
    65  		return false
    66  	}
    67  	return true
    68  }
    69  
    70  // workloadEquals compare only the workload (Deployment, ResourceSet, or StatefulSet) fields that are of interest to Telepresence. They are
    71  //
    72  //   - UID
    73  //   - Name
    74  //   - Namespace
    75  //   - Spec.Template:
    76  //   - Labels
    77  //   - Containers (must contain an equal number of equally named containers with equal ports)
    78  func workloadEquals(oa, ob runtime.Object) bool {
    79  	a, err := k8sapi.WrapWorkload(oa)
    80  	if err != nil {
    81  		// This should definitely never happen
    82  		panic(err)
    83  	}
    84  	b, err := k8sapi.WrapWorkload(ob)
    85  	if err != nil {
    86  		// This should definitely never happen
    87  		panic(err)
    88  	}
    89  	if a.GetUID() != b.GetUID() || a.GetName() != b.GetName() || a.GetNamespace() != b.GetNamespace() {
    90  		return false
    91  	}
    92  
    93  	aSpec := a.GetPodTemplate()
    94  	bSpec := b.GetPodTemplate()
    95  	if !labels.Equals(aSpec.Labels, bSpec.Labels) {
    96  		return false
    97  	}
    98  	aPod := aSpec.Spec
    99  	bPod := bSpec.Spec
   100  	if len(aPod.Containers) != len(bPod.Containers) {
   101  		return false
   102  	}
   103  	makeContainerMap := func(cs []core.Container) map[string]*core.Container {
   104  		m := make(map[string]*core.Container, len(cs))
   105  		for i := range cs {
   106  			c := &cs[i]
   107  			m[c.Name] = c
   108  		}
   109  		return m
   110  	}
   111  
   112  	portsEqual := func(a, b []core.ContainerPort) bool {
   113  		if len(a) != len(b) {
   114  			return false
   115  		}
   116  	nextAP:
   117  		for _, ap := range a {
   118  			for _, bp := range b {
   119  				if ap == bp {
   120  					continue nextAP
   121  				}
   122  			}
   123  			return false
   124  		}
   125  		return true
   126  	}
   127  
   128  	am := makeContainerMap(aPod.Containers)
   129  	bm := makeContainerMap(bPod.Containers)
   130  	for n, ac := range am {
   131  		bc, ok := bm[n]
   132  		if !ok {
   133  			return false
   134  		}
   135  		if !portsEqual(ac.Ports, bc.Ports) {
   136  			return false
   137  		}
   138  	}
   139  	return true
   140  }
   141  
   142  func newNamespaceWatcher(c context.Context, namespace string, cond *sync.Cond) *namespacedWASWatcher {
   143  	dlog.Debugf(c, "newNamespaceWatcher %s", namespace)
   144  	ki := k8sapi.GetK8sInterface(c)
   145  	appsGetter := ki.AppsV1().RESTClient()
   146  	w := &namespacedWASWatcher{
   147  		svcWatcher: k8sapi.NewWatcher("services", ki.CoreV1().RESTClient(), cond, k8sapi.WithEquals(svcEquals), k8sapi.WithNamespace[*core.Service](namespace)),
   148  		wlWatchers: [3]*k8sapi.Watcher[runtime.Object]{
   149  			k8sapi.NewWatcher("deployments", appsGetter, cond, k8sapi.WithEquals(workloadEquals), k8sapi.WithNamespace[runtime.Object](namespace)),
   150  			k8sapi.NewWatcher("replicasets", appsGetter, cond, k8sapi.WithEquals(workloadEquals), k8sapi.WithNamespace[runtime.Object](namespace)),
   151  			k8sapi.NewWatcher("statefulsets", appsGetter, cond, k8sapi.WithEquals(workloadEquals), k8sapi.WithNamespace[runtime.Object](namespace)),
   152  		},
   153  	}
   154  	return w
   155  }
   156  
   157  func (nw *namespacedWASWatcher) cancel() {
   158  	nw.svcWatcher.Cancel()
   159  	for _, w := range nw.wlWatchers {
   160  		w.Cancel()
   161  	}
   162  }
   163  
   164  func (nw *namespacedWASWatcher) hasSynced() bool {
   165  	return nw.svcWatcher.HasSynced() &&
   166  		nw.wlWatchers[0].HasSynced() &&
   167  		nw.wlWatchers[1].HasSynced() &&
   168  		nw.wlWatchers[2].HasSynced()
   169  }
   170  
   171  func newWASWatcher() *workloadsAndServicesWatcher {
   172  	w := &workloadsAndServicesWatcher{
   173  		nsWatchers: make(map[string]*namespacedWASWatcher),
   174  	}
   175  	w.cond.L = &w.Mutex
   176  	return w
   177  }
   178  
   179  // eachService iterates over the services in the current snapshot. Unless namespace
   180  // is the empty string, the iteration is limited to the services matching that namespace.
   181  // The traffic-manager service is excluded.
   182  func (w *workloadsAndServicesWatcher) eachService(c context.Context, tmns string, namespaces []string, f func(*core.Service)) {
   183  	if len(namespaces) != 1 {
   184  		// Produce workloads in a predictable order
   185  		nss := make([]string, len(namespaces))
   186  		copy(nss, namespaces)
   187  		sort.Strings(nss)
   188  		for _, n := range nss {
   189  			w.eachService(c, tmns, []string{n}, f)
   190  		}
   191  	} else {
   192  		ns := namespaces[0]
   193  		w.Lock()
   194  		nw, ok := w.nsWatchers[ns]
   195  		w.Unlock()
   196  		if ok {
   197  			svcs, err := nw.svcWatcher.List(c)
   198  			if err != nil {
   199  				dlog.Errorf(c, "error listing services: %s", err)
   200  				return
   201  			}
   202  			for _, svc := range svcs {
   203  				// If this is our traffic-manager namespace, then exclude the traffic-manager service.
   204  				if !(ns == tmns && svc.Labels["app"] == "traffic-manager" && svc.Labels["telepresence"] == "manager") {
   205  					f(svc)
   206  				}
   207  			}
   208  		}
   209  	}
   210  }
   211  
   212  func (w *workloadsAndServicesWatcher) waitForSync(c context.Context) {
   213  	hss := make([]cache.InformerSynced, len(w.nsWatchers))
   214  	w.Lock()
   215  	i := 0
   216  	for _, nw := range w.nsWatchers {
   217  		hss[i] = nw.hasSynced
   218  		i++
   219  	}
   220  	w.Unlock()
   221  
   222  	hasSynced := true
   223  	for _, hs := range hss {
   224  		if !hs() {
   225  			hasSynced = false
   226  			break
   227  		}
   228  	}
   229  	if !hasSynced {
   230  		// Waiting for cache sync will sometimes block, so a timeout is necessary here
   231  		c, cancel := context.WithTimeout(c, 5*time.Second)
   232  		defer cancel()
   233  		cache.WaitForCacheSync(c.Done(), hss...)
   234  	}
   235  }
   236  
// subscribe returns the channel obtained from k8sapi.Subscribe for the condition
// variable that this watcher passes to every namespace watcher that it creates.
// NOTE(review): presumably the channel receives a value whenever relevant information
// has changed in the current snapshot, for as long as c isn't done. Confirm against
// k8sapi.Subscribe.
func (w *workloadsAndServicesWatcher) subscribe(c context.Context) <-chan struct{} {
	return k8sapi.Subscribe(c, &w.cond)
}
   242  
   243  // setNamespacesToWatch starts new watchers or kills old ones to make the current
   244  // set of watchers reflect the nss argument.
   245  func (w *workloadsAndServicesWatcher) setNamespacesToWatch(c context.Context, nss []string) {
   246  	var adds []string
   247  	desired := make(map[string]struct{})
   248  
   249  	w.Lock()
   250  	for _, ns := range nss {
   251  		desired[ns] = struct{}{}
   252  		if _, ok := w.nsWatchers[ns]; !ok {
   253  			adds = append(adds, ns)
   254  		}
   255  	}
   256  	for ns, nw := range w.nsWatchers {
   257  		if _, ok := desired[ns]; !ok {
   258  			delete(w.nsWatchers, ns)
   259  			nw.cancel()
   260  		}
   261  	}
   262  	for _, ns := range adds {
   263  		w.addNSLocked(c, ns)
   264  	}
   265  	w.Unlock()
   266  }
   267  
   268  func (w *workloadsAndServicesWatcher) addNSLocked(c context.Context, ns string) *namespacedWASWatcher {
   269  	nw := newNamespaceWatcher(c, ns, &w.cond)
   270  	w.nsWatchers[ns] = nw
   271  	for _, l := range w.nsListeners {
   272  		nw.svcWatcher.AddStateListener(&k8sapi.StateListener{Cb: l})
   273  	}
   274  	return nw
   275  }
   276  
   277  func (w *workloadsAndServicesWatcher) ensureStarted(c context.Context, ns string, cb func(bool)) {
   278  	w.Lock()
   279  	defer w.Unlock()
   280  	nw, ok := w.nsWatchers[ns]
   281  	if !ok {
   282  		nw = w.addNSLocked(c, ns)
   283  	}
   284  	// Starting the svcWatcher will set it to active and also trigger its state listener
   285  	// which means a) that the set of active namespaces will change, and b) that the
   286  	// WatchAgentsNS will restart with that namespace included.
   287  	err := nw.svcWatcher.EnsureStarted(c, cb)
   288  	if err != nil {
   289  		dlog.Errorf(c, "error starting service watchers: %s", err)
   290  	}
   291  }
   292  
   293  func (w *workloadsAndServicesWatcher) findMatchingWorkloads(c context.Context, svc *core.Service) ([]k8sapi.Workload, error) {
   294  	w.Lock()
   295  	nw := w.nsWatchers[svc.Namespace]
   296  	w.Unlock()
   297  	if nw == nil {
   298  		// Extremely odd, given that the service originated from a namespace watcher
   299  		return nil, fmt.Errorf("no watcher found for namespace %q", svc.Namespace)
   300  	}
   301  	return nw.findMatchingWorkloads(c, svc)
   302  }
   303  
   304  func (nw *namespacedWASWatcher) findMatchingWorkloads(c context.Context, svc *core.Service) ([]k8sapi.Workload, error) {
   305  	ps := svc.Spec.Ports
   306  	targetPortNames := make([]string, 0, len(ps))
   307  	for i := range ps {
   308  		tp := ps[i].TargetPort
   309  		if tp.Type == intstr.String {
   310  			targetPortNames = append(targetPortNames, tp.StrVal)
   311  		} else {
   312  			if tp.IntVal == 0 {
   313  				// targetPort is not specified, so it defaults to the port name
   314  				targetPortNames = append(targetPortNames, ps[i].Name)
   315  			} else {
   316  				// Unless all target ports are named, we cannot really use this as a filter.
   317  				// A numeric target port will map to any container, and containers don't
   318  				// have to expose numbered ports in order to use them.
   319  				targetPortNames = nil
   320  				break
   321  			}
   322  		}
   323  	}
   324  
   325  	var selector labels.Selector
   326  	if sm := svc.Spec.Selector; len(sm) > 0 {
   327  		selector = labels.SelectorFromSet(sm)
   328  	} else {
   329  		// There will be no matching workloads for this service
   330  		return nil, nil
   331  	}
   332  
   333  	var allWls []k8sapi.Workload
   334  	for i, wlw := range nw.wlWatchers {
   335  		wls, err := wlw.List(c)
   336  		if err != nil {
   337  			return nil, err
   338  		}
   339  		for _, o := range wls {
   340  			var wl k8sapi.Workload
   341  			switch i {
   342  			case deployments:
   343  				wl = k8sapi.Deployment(o.(*apps.Deployment))
   344  			case replicasets:
   345  				wl = k8sapi.ReplicaSet(o.(*apps.ReplicaSet))
   346  			case statefulsets:
   347  				wl = k8sapi.StatefulSet(o.(*apps.StatefulSet))
   348  			}
   349  			if selector.Matches(labels.Set(wl.GetPodTemplate().Labels)) {
   350  				owl, err := nw.maybeReplaceWithOwner(c, wl)
   351  				if err != nil {
   352  					return nil, err
   353  				}
   354  				allWls = append(allWls, owl)
   355  			}
   356  		}
   357  	}
   358  
   359  	// Prefer entries with matching ports. I.e. strip all non-matching if matching entries
   360  	// are found.
   361  	if pfWls := filterByNamedTargetPort(c, targetPortNames, allWls); len(pfWls) > 0 {
   362  		allWls = pfWls
   363  	}
   364  	return allWls, nil
   365  }
   366  
   367  func (nw *namespacedWASWatcher) maybeReplaceWithOwner(c context.Context, wl k8sapi.Workload) (k8sapi.Workload, error) {
   368  	var err error
   369  	for _, or := range wl.GetOwnerReferences() {
   370  		if or.Controller != nil && *or.Controller && or.Kind == "Deployment" {
   371  			// Chances are that the owner's labels doesn't match, but we really want the owner anyway.
   372  			wl, err = nw.replaceWithOwner(c, wl, or.Kind, or.Name)
   373  			break
   374  		}
   375  	}
   376  	return wl, err
   377  }
   378  
   379  func (nw *namespacedWASWatcher) replaceWithOwner(c context.Context, wl k8sapi.Workload, kind, name string) (k8sapi.Workload, error) {
   380  	od, found, err := nw.wlWatchers[deployments].Get(c, &apps.Deployment{
   381  		ObjectMeta: meta.ObjectMeta{
   382  			Name:      name,
   383  			Namespace: wl.GetNamespace(),
   384  		},
   385  	})
   386  	switch {
   387  	case err != nil:
   388  		return nil, fmt.Errorf("get %s owner %s for %s %s.%s: %v",
   389  			kind, name, wl.GetKind(), wl.GetName(), wl.GetNamespace(), err)
   390  	case found:
   391  		dlog.Debugf(c, "replacing %s %s.%s, with owner %s %s", wl.GetKind(), wl.GetName(), wl.GetNamespace(), kind, name)
   392  		return k8sapi.Deployment(od.(*apps.Deployment)), nil
   393  	default:
   394  		return nil, fmt.Errorf("get %s owner %s for %s %s.%s: not found", kind, name, wl.GetKind(), wl.GetName(), wl.GetNamespace())
   395  	}
   396  }
   397  
   398  func filterByNamedTargetPort(c context.Context, targetPortNames []string, wls []k8sapi.Workload) []k8sapi.Workload {
   399  	if len(targetPortNames) == 0 {
   400  		// service ports are not all named
   401  		return wls
   402  	}
   403  	var filtered []k8sapi.Workload
   404  nextWL:
   405  	for _, wl := range wls {
   406  		cs := wl.GetPodTemplate().Spec.Containers
   407  		for ci := range cs {
   408  			ps := cs[ci].Ports
   409  			for pi := range ps {
   410  				name := ps[pi].Name
   411  				for _, tpn := range targetPortNames {
   412  					if name == tpn {
   413  						filtered = append(filtered, wl)
   414  						continue nextWL
   415  					}
   416  				}
   417  			}
   418  		}
   419  		dlog.Debugf(c, "skipping %s %s.%s, it has no matching ports", wl.GetKind(), wl.GetName(), wl.GetNamespace())
   420  	}
   421  	return filtered
   422  }