github.com/castai/kvisor@v1.7.1-0.20240516114728-b3572a2607b5/pkg/containers/client.go (about)

     1  package containers
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/castai/kvisor/pkg/cgroup"
    11  	"github.com/castai/kvisor/pkg/logging"
    12  	"github.com/castai/kvisor/pkg/metrics"
    13  	containerdContainers "github.com/containerd/containerd/containers"
    14  	"github.com/samber/lo"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
var (
	// ErrContainerNotFound is returned when a cgroup cannot be resolved to a
	// container: the cgroup lookup failed or carries no container ID, the
	// container runtime is not supported, or the runtime reports NotFound
	// while listing the container's PIDs.
	ErrContainerNotFound = errors.New("container not found")
)
    22  
// ContainerCreatedListener is a callback invoked with a container after it
// has been added to the client's cache. The client runs these callbacks from
// a separate goroutine.
type ContainerCreatedListener func(c *Container)

// ContainerDeletedListener is a callback invoked with a container after it
// has been removed from the client's cache by CleanupCgroup. The callbacks
// run synchronously in the goroutine that called CleanupCgroup.
type ContainerDeletedListener func(c *Container)
    25  
// Container is a cache entry describing a container discovered through its
// cgroup. An entry created after a failed lookup has only Err set.
type Container struct {
	// ID is the container ID reported by the cgroup.
	ID string
	// Name is the value of the "io.kubernetes.container.name" label.
	Name string
	// CgroupID is the ID of the cgroup the container belongs to.
	CgroupID uint64
	// PodNamespace is the value of the "io.kubernetes.pod.namespace" label.
	PodNamespace string
	// PodUID is the value of the "io.kubernetes.pod.uid" label.
	PodUID string
	// PodName is the value of the "io.kubernetes.pod.name" label.
	PodName string
	// Cgroup is the cgroup the container was resolved from; nil for entries
	// that only record a lookup error.
	Cgroup *cgroup.Cgroup
	// PIDs holds the process IDs returned by the container client at the
	// time the entry was created.
	PIDs []uint32
	// Err is the lookup error cached for this cgroup, if any.
	Err error
}
    37  
// Client is generic container client.
//
// It keeps an in-memory cache of containers keyed by cgroup ID and notifies
// registered listeners when entries are added or removed.
type Client struct {
	log             *logging.Logger
	containerClient *containerClient
	cgroupClient    *cgroup.Client

	// containersByCgroup caches resolved containers (and failed lookups,
	// recorded as entries with Err set) by cgroup ID. Guarded by mu.
	containersByCgroup map[uint64]*Container
	mu                 sync.RWMutex

	// Listener slices are guarded by listenerMu.
	containerCreatedListeners []ContainerCreatedListener
	containerDeletedListeners []ContainerDeletedListener
	listenerMu                sync.RWMutex
}
    51  
    52  func NewClient(log *logging.Logger, cgroupClient *cgroup.Client, containerdSock string) (*Client, error) {
    53  	contClient, err := newContainerClient(containerdSock)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	return &Client{
    58  		log:                log.WithField("component", "cgroups"),
    59  		containerClient:    contClient,
    60  		cgroupClient:       cgroupClient,
    61  		containersByCgroup: map[uint64]*Container{},
    62  	}, nil
    63  }
    64  
    65  func (c *Client) ListContainers() []*Container {
    66  	c.mu.RLock()
    67  	defer c.mu.RUnlock()
    68  	return lo.Filter(lo.Values(c.containersByCgroup), func(item *Container, index int) bool {
    69  		return item.Err == nil && item.Cgroup != nil
    70  	})
    71  }
    72  
    73  func (c *Client) addContainerByCgroupID(ctx context.Context, cgroupID cgroup.ID) (cont *Container, rerrr error) {
    74  	defer func() {
    75  		if rerrr != nil {
    76  			// TODO: This is quick fix to prevent constant search for invalid containers.
    77  			// Check for some better error handling. For example container client network error could occur.
    78  			cont = &Container{
    79  				Err: rerrr,
    80  			}
    81  			c.mu.Lock()
    82  			c.containersByCgroup[cgroupID] = cont
    83  			c.mu.Unlock()
    84  		}
    85  	}()
    86  
    87  	cg, err := c.cgroupClient.GetCgroupForID(cgroupID)
    88  	// The found cgroup is not a container.
    89  	if err != nil || cg.ContainerID == "" {
    90  		return nil, ErrContainerNotFound
    91  	}
    92  
    93  	container, err := c.containerClient.client.ContainerService().Get(ctx, cg.ContainerID)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return c.addContainerWithCgroup(container, cg)
    99  }
   100  
   101  func (c *Client) addContainerWithCgroup(container containerdContainers.Container, cg *cgroup.Cgroup) (cont *Container, rerrr error) {
   102  	podNamespace := container.Labels["io.kubernetes.pod.namespace"]
   103  	containerName := container.Labels["io.kubernetes.container.name"]
   104  	podName := container.Labels["io.kubernetes.pod.name"]
   105  	podID := container.Labels["io.kubernetes.pod.uid"]
   106  
   107  	// Only containerd is supported right now.
   108  	// TODO: We also allow docker here, but support only docker shim. If container type docker we assume that it's still uses containerd.
   109  	if cg.ContainerRuntime != cgroup.ContainerdRuntime && cg.ContainerRuntime != cgroup.DockerRuntime {
   110  		return nil, ErrContainerNotFound
   111  	}
   112  
   113  	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
   114  	defer cancel()
   115  	pids, err := c.containerClient.getContainerPids(ctx, cg.ContainerID)
   116  	if err != nil {
   117  		if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
   118  			return nil, ErrContainerNotFound
   119  		}
   120  		return nil, fmt.Errorf("get container pids: %w", err)
   121  	}
   122  
   123  	cont = &Container{
   124  		ID:           cg.ContainerID,
   125  		Name:         containerName,
   126  		CgroupID:     cg.Id,
   127  		PodNamespace: podNamespace,
   128  		PodUID:       podID,
   129  		PodName:      podName,
   130  		Cgroup:       cg,
   131  		PIDs:         pids,
   132  	}
   133  
   134  	c.mu.Lock()
   135  	c.containersByCgroup[cg.Id] = cont
   136  	c.mu.Unlock()
   137  
   138  	c.log.Debugf("added container, id=%s pod=%s name=%s", container.ID, podName, containerName)
   139  
   140  	go c.fireContainerCreatedListeners(cont)
   141  
   142  	return cont, nil
   143  }
   144  
   145  func (c *Client) GetContainerForCgroup(ctx context.Context, cgroup uint64) (*Container, error) {
   146  	container, found, err := c.LookupContainerForCgroupInCache(cgroup)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	if !found {
   152  		metrics.AgentLoadContainerByCgroup.Inc()
   153  		return c.addContainerByCgroupID(ctx, cgroup)
   154  	}
   155  
   156  	return container, nil
   157  }
   158  
   159  func (c *Client) LookupContainerForCgroupInCache(cgroup uint64) (*Container, bool, error) {
   160  	c.mu.RLock()
   161  	container, found := c.containersByCgroup[cgroup]
   162  	c.mu.RUnlock()
   163  
   164  	if !found {
   165  		return nil, false, nil
   166  	}
   167  
   168  	if container.Err != nil {
   169  		return nil, true, container.Err
   170  	}
   171  
   172  	return container, true, nil
   173  }
   174  
   175  func (c *Client) CleanupCgroup(cgroup cgroup.ID) {
   176  	c.mu.Lock()
   177  	container := c.containersByCgroup[cgroup]
   178  	delete(c.containersByCgroup, cgroup)
   179  	c.mu.Unlock()
   180  
   181  	if container != nil {
   182  		c.fireContainerDeletedListeners(container)
   183  	}
   184  }
   185  
   186  func (c *Client) GetCgroupsInNamespace(namespace string) []uint64 {
   187  	c.mu.RLock()
   188  	defer c.mu.RUnlock()
   189  
   190  	var result []uint64
   191  
   192  	for cgroup, container := range c.containersByCgroup {
   193  		if container.PodNamespace == namespace {
   194  			result = append(result, cgroup)
   195  		}
   196  	}
   197  
   198  	return result
   199  }
   200  
   201  func (c *Client) RegisterContainerCreatedListener(l ContainerCreatedListener) {
   202  	c.listenerMu.Lock()
   203  	defer c.listenerMu.Unlock()
   204  
   205  	c.containerCreatedListeners = append(c.containerCreatedListeners, l)
   206  }
   207  
   208  func (c *Client) RegisterContainerDeletedListener(l ContainerDeletedListener) {
   209  	c.listenerMu.Lock()
   210  	defer c.listenerMu.Unlock()
   211  
   212  	c.containerDeletedListeners = append(c.containerDeletedListeners, l)
   213  }
   214  
   215  func (c *Client) fireContainerCreatedListeners(container *Container) {
   216  	c.listenerMu.RLock()
   217  	listeners := c.containerCreatedListeners
   218  	c.listenerMu.RUnlock()
   219  
   220  	for _, l := range listeners {
   221  		l(container)
   222  	}
   223  }
   224  
   225  func (c *Client) fireContainerDeletedListeners(container *Container) {
   226  	c.listenerMu.RLock()
   227  	listeners := c.containerDeletedListeners
   228  	c.listenerMu.RUnlock()
   229  
   230  	for _, l := range listeners {
   231  		l(container)
   232  	}
   233  }
   234  
   235  func (c *Client) GetCgroupCpuStats(cont *Container) (*cgroup.CPUStat, error) {
   236  	return cont.Cgroup.CpuStat()
   237  }
   238  
   239  func (c *Client) GetCgroupMemoryStats(cont *Container) (*cgroup.MemoryStat, error) {
   240  	return cont.Cgroup.MemoryStat()
   241  }