github.com/castai/kvisor@v1.7.1-0.20240516114728-b3572a2607b5/pkg/proc/proc.go (about)

     1  package proc
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io/fs"
     8  	"os"
     9  	"path"
    10  	"strconv"
    11  	"syscall"
    12  
    13  	"github.com/samber/lo"
    14  )
    15  
    16  // Path to proc filesystem.
    17  const Path = "/proc"
    18  
    19  func GetFS() ProcFS {
    20  	// DirFS guarantees to return a fs.StatFS, fs.ReadFileFS and fs.ReadDirFS implementation, hence we can simply cast it here
    21  	return os.DirFS("/proc").(ProcFS)
    22  }
    23  
    24  type PID = uint32
    25  type NamespaceID = uint64
    26  
    27  type NamespaceType string
    28  
    29  const (
    30  	PIDNamespace   NamespaceType = "pid"
    31  	MountNamespace NamespaceType = "mnt"
    32  )
    33  
    34  type ProcFS interface {
    35  	fs.ReadDirFS
    36  	fs.ReadFileFS
    37  	fs.StatFS
    38  }
    39  
    40  var (
    41  	ErrCannotGetPIDNSInode            = errors.New("cannot get pidns inode")
    42  	ErrParseStatFileInvalidCommFormat = errors.New("cannot parse stat file, invalid comm format")
    43  	ErrParseStatFileNotEnoughFields   = errors.New("cannot parse stat file, not enough fields")
    44  )
    45  
    46  type Proc struct {
    47  	procFS ProcFS
    48  }
    49  
    50  func New() *Proc {
    51  	return &Proc{
    52  		procFS: GetFS(),
    53  	}
    54  }
    55  
    56  // HostPath returns full file path on the host file system using procfs, eg: /proc/1/root/<my-path>
    57  func HostPath(p string) string {
    58  	return path.Join(Path, strconv.Itoa(1), p)
    59  }
    60  
    61  func (p *Proc) GetCurrentPIDNSID() (NamespaceID, error) {
    62  	return p.GetNSForPID(1, PIDNamespace)
    63  }
    64  
    65  func (p *Proc) LoadMountNSOldestProcesses() (map[NamespaceID]PID, error) {
    66  	files, err := p.procFS.ReadDir(".")
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	type processInfo struct {
    72  		pid PID
    73  		age uint64
    74  	}
    75  
    76  	namespaceMap := map[NamespaceID]processInfo{}
    77  
    78  	for _, f := range files {
    79  		pid, err := parsePIDFromString(f.Name())
    80  		if err != nil {
    81  			continue
    82  		}
    83  
    84  		mntNS, err := p.GetNSForPID(pid, MountNamespace)
    85  		if err != nil {
    86  			continue
    87  		}
    88  
    89  		processAge, err := p.GetProcessStartTime(pid)
    90  		if err != nil {
    91  			continue
    92  		}
    93  
    94  		current, found := namespaceMap[mntNS]
    95  		if found {
    96  			if current.age < processAge {
    97  				continue
    98  			}
    99  		}
   100  
   101  		namespaceMap[mntNS] = processInfo{
   102  			pid: pid,
   103  			age: processAge,
   104  		}
   105  	}
   106  
   107  	return lo.MapValues(namespaceMap, func(value processInfo, key NamespaceID) PID {
   108  		return value.pid
   109  	}), nil
   110  }
   111  
   112  func parsePIDFromString(pidStr string) (PID, error) {
   113  	pid, err := strconv.ParseUint(pidStr, 10, 32)
   114  	if err != nil {
   115  		return 0, err
   116  	}
   117  
   118  	return PID(pid), nil
   119  }
   120  
   121  func (p *Proc) GetNSForPID(pid PID, ns NamespaceType) (NamespaceID, error) {
   122  	info, err := p.procFS.Stat(fmt.Sprintf("%d/ns/%s", pid, ns))
   123  	if err != nil {
   124  		return 0, err
   125  	}
   126  	stat, ok := info.Sys().(*syscall.Stat_t)
   127  	if !ok {
   128  		return 0, ErrCannotGetPIDNSInode
   129  	}
   130  
   131  	return NamespaceID(stat.Ino), nil
   132  }
   133  
   134  // GetProcessStartTime parses the /proc/<pid>/stat file to determine the start time of the process after system boot.
   135  func (p *Proc) GetProcessStartTime(pid PID) (uint64, error) {
   136  	data, err := p.procFS.ReadFile(fmt.Sprintf("%d/stat", pid))
   137  	if err != nil {
   138  		return 0, err
   139  	}
   140  
   141  	commEndIndex := bytes.Index(data, []byte{')', ' '})
   142  	if commEndIndex < 0 {
   143  		return 0, ErrParseStatFileInvalidCommFormat
   144  	}
   145  
   146  	fields := bytes.Split(data[commEndIndex+2:], []byte{' '})
   147  	// According to https://man7.org/linux/man-pages/man5/proc.5.html , the start time is the 22 field. Since we cut
   148  	// out `comm` (2 field) we need to adjust the index. The -1 is to adjust for zero being the first elements in slices.
   149  	adjustedStartTimeIdx := 22 - 2 - 1
   150  
   151  	if len(fields) < adjustedStartTimeIdx {
   152  		return 0, ErrParseStatFileNotEnoughFields
   153  	}
   154  
   155  	return strconv.ParseUint(string(fields[adjustedStartTimeIdx]), 10, 64)
   156  }