github.com/kata-containers/runtime@v0.0.0-20210505125100-04f29832a923/virtcontainers/pkg/nsenter/nsenter.go (about)

     1  // Copyright (c) 2018 Intel Corporation
     2  // Copyright 2015-2017 CNI authors
     3  //
     4  // SPDX-License-Identifier: Apache-2.0
     5  //
     6  
     7  package nsenter
     8  
     9  import (
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"runtime"
    14  	"strconv"
    15  	"sync"
    16  	"syscall"
    17  
    18  	"golang.org/x/sys/unix"
    19  )
    20  
    21  // Filesystems constants.
    22  const (
    23  	// https://github.com/torvalds/linux/blob/master/include/uapi/linux/magic.h
    24  	nsFSMagic   = 0x6e736673
    25  	procFSMagic = 0x9fa0
    26  
    27  	procRootPath = "/proc"
    28  	nsDirPath    = "ns"
    29  	taskDirPath  = "task"
    30  )
    31  
    32  // NSType defines a namespace type.
    33  type NSType string
    34  
    35  // List of namespace types.
    36  // Notice that neither "mnt" nor "user" are listed into this list.
    37  // Because Golang is multithreaded, we get some errors when trying
    38  // to switch to those namespaces, getting "invalid argument".
    39  // The solution is to reexec the current code so that it will call
    40  // into a C constructor, making sure the namespace can be entered
    41  // without multithreading issues.
    42  const (
    43  	NSTypeCGroup NSType = "cgroup"
    44  	NSTypeIPC    NSType = "ipc"
    45  	NSTypeNet    NSType = "net"
    46  	NSTypePID    NSType = "pid"
    47  	NSTypeUTS    NSType = "uts"
    48  )
    49  
    50  // CloneFlagsTable is exported so that consumers of this package don't need
    51  // to define this same table again.
    52  var CloneFlagsTable = make(map[NSType]int)
    53  
    54  // Namespace describes a namespace that will be entered.
    55  type Namespace struct {
    56  	Path string
    57  	PID  int
    58  	Type NSType
    59  }
    60  
    61  type nsPair struct {
    62  	targetNS *os.File
    63  	threadNS *os.File
    64  }
    65  
    66  func init() {
    67  	var ns = map[NSType]int{
    68  		NSTypeCGroup: unix.CLONE_NEWCGROUP,
    69  		NSTypeIPC:    unix.CLONE_NEWIPC,
    70  		NSTypeNet:    unix.CLONE_NEWNET,
    71  		NSTypePID:    unix.CLONE_NEWPID,
    72  		NSTypeUTS:    unix.CLONE_NEWUTS,
    73  	}
    74  
    75  	for k, v := range ns {
    76  		if _, err := os.Stat(fmt.Sprint("/proc/self/ns/", string(k))); err == nil {
    77  			CloneFlagsTable[k] = v
    78  		}
    79  	}
    80  }
    81  
    82  func getNSPathFromPID(pid int, nsType NSType) string {
    83  	return filepath.Join(procRootPath, strconv.Itoa(pid), nsDirPath, string(nsType))
    84  }
    85  
    86  func getCurrentThreadNSPath(nsType NSType) string {
    87  	return filepath.Join(procRootPath, strconv.Itoa(os.Getpid()),
    88  		taskDirPath, strconv.Itoa(unix.Gettid()), nsDirPath, string(nsType))
    89  }
    90  
    91  func setNS(nsFile *os.File, nsType NSType) error {
    92  	if nsFile == nil {
    93  		return fmt.Errorf("File handler cannot be nil")
    94  	}
    95  
    96  	nsFlag, exist := CloneFlagsTable[nsType]
    97  	if !exist {
    98  		return fmt.Errorf("Unknown namespace type %q", nsType)
    99  	}
   100  
   101  	if err := unix.Setns(int(nsFile.Fd()), nsFlag); err != nil {
   102  		return fmt.Errorf("Error switching to ns %v: %v", nsFile.Name(), err)
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  // getFileFromNS checks the provided file path actually matches a real
   109  // namespace filesystem, and then opens it to return a handler to this
   110  // file. This is needed since the system call setns() expects a file
   111  // descriptor to enter the given namespace.
   112  func getFileFromNS(nsPath string) (*os.File, error) {
   113  	stat := syscall.Statfs_t{}
   114  	if err := syscall.Statfs(nsPath, &stat); err != nil {
   115  		return nil, fmt.Errorf("failed to Statfs %q: %v", nsPath, err)
   116  	}
   117  
   118  	switch stat.Type {
   119  	case nsFSMagic, procFSMagic:
   120  		break
   121  	default:
   122  		return nil, fmt.Errorf("unknown FS magic on %q: %x", nsPath, stat.Type)
   123  	}
   124  
   125  	file, err := os.Open(nsPath)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	return file, nil
   131  }
   132  
   133  // NsEnter executes the passed closure under the given namespace,
   134  // restoring the original namespace afterwards.
   135  func NsEnter(nsList []Namespace, toRun func() error) error {
   136  	targetNSList := make(map[NSType]*nsPair)
   137  
   138  	// Open all targeted namespaces.
   139  	for _, ns := range nsList {
   140  		targetNSPath := ns.Path
   141  		if targetNSPath == "" {
   142  			targetNSPath = getNSPathFromPID(ns.PID, ns.Type)
   143  		}
   144  
   145  		targetNS, err := getFileFromNS(targetNSPath)
   146  		if err != nil {
   147  			return fmt.Errorf("failed to open target ns: %v", err)
   148  		}
   149  		defer targetNS.Close()
   150  
   151  		targetNSList[ns.Type] = &nsPair{
   152  			targetNS: targetNS,
   153  		}
   154  	}
   155  
   156  	containedCall := func() error {
   157  		for nsType := range targetNSList {
   158  			threadNS, err := getFileFromNS(getCurrentThreadNSPath(nsType))
   159  			if err != nil {
   160  				return fmt.Errorf("failed to open current ns: %v", err)
   161  			}
   162  			defer threadNS.Close()
   163  
   164  			targetNSList[nsType].threadNS = threadNS
   165  		}
   166  
   167  		// Switch to namespaces all at once.
   168  		for nsType, pair := range targetNSList {
   169  			// Switch to targeted namespace.
   170  			if err := setNS(pair.targetNS, nsType); err != nil {
   171  				return fmt.Errorf("error switching to ns %v: %v", pair.targetNS.Name(), err)
   172  			}
   173  			// Switch back to initial namespace after closure return.
   174  			defer setNS(pair.threadNS, nsType)
   175  		}
   176  
   177  		return toRun()
   178  	}
   179  
   180  	var wg sync.WaitGroup
   181  	wg.Add(1)
   182  
   183  	var innerError error
   184  	go func() {
   185  		defer wg.Done()
   186  		runtime.LockOSThread()
   187  		defer runtime.UnlockOSThread()
   188  		innerError = containedCall()
   189  	}()
   190  	wg.Wait()
   191  
   192  	return innerError
   193  }