github.com/k8snetworkplumbingwg/sriov-network-operator@v1.2.1-0.20240408194816-2d2e5a45d453/test/util/netns/netns.go (about)

     1  package netns
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path/filepath"
     7  	"runtime"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/vishvananda/netlink"
    13  	"github.com/vishvananda/netns"
    14  )
    15  
    16  const (
    17  	sysBusPci       = "/sys/bus/pci/devices"
    18  	sriovNumVfsFile = "sriov_numvfs"
    19  )
    20  
    21  // SetPfVfLinkNetNs requires physical function (PF) PCI address and a string to a network namespace in which to add
    22  // any associated virtual functions (VF). VFs must be attached to kernel driver to provide links. Attaching VFs to
    23  // vfio-pci driver is not supported. PF is set to target network namespace.
    24  // Polling interval is required and this period will determine how often the VFs Links are checked to ensure they are in
    25  // the target network namespace. Two channels are required - one for informing the func to end and one to inform
    26  // the caller of an error or if the function has ended. It is this func responsibility to cleanup the done channel and
    27  // callers responsibility to cleanup quit channel.
    28  func SetPfVfLinkNetNs(pfPciAddr, netNsPath string, pollInterval time.Duration, quitCh chan bool, doneCh chan error) {
    29  	runtime.LockOSThread()
    30  	defer runtime.UnlockOSThread()
    31  	var errL []error
    32  	//merge any errors found during execution, send through channel, and close it.
    33  	defer func() {
    34  		if len(errL) != 0 {
    35  			errMsg := ""
    36  			for i, err := range errL {
    37  				errMsg = errMsg + fmt.Sprintf("error %d) '%s'", i+1, err.Error())
    38  			}
    39  			doneCh <- fmt.Errorf("%s", errMsg)
    40  		}
    41  		close(doneCh)
    42  	}()
    43  
    44  	if pfPciAddr == "" || netNsPath == "" {
    45  		errL = append(errL, fmt.Errorf("SetPfVfLinkNetNs(): specify PF PCI address and/or netns path '%s' & '%s'",
    46  			pfPciAddr, netNsPath))
    47  		return
    48  	}
    49  
    50  	pfPciDir := filepath.Join(sysBusPci, pfPciAddr)
    51  	if _, err := os.Lstat(pfPciDir); err != nil {
    52  		errL = append(errL, fmt.Errorf("SetPfVfLinkNetNs(): failed to find PCI device at '%s': '%s'", pfPciDir,
    53  			err.Error()))
    54  		return
    55  	}
    56  
    57  	targetNetNs, err := netns.GetFromPath(netNsPath)
    58  	if err != nil {
    59  		errL = append(errL, fmt.Errorf("SetPfVfLinkNetNs(): failed to get target network namespace: '%s'",
    60  			err.Error()))
    61  		return
    62  	}
    63  	defer targetNetNs.Close()
    64  	// if failure to set PF netns, emit error, assume its in correct netns and continue
    65  	if err = setLinkNetNs(pfPciAddr, targetNetNs); err != nil {
    66  		errL = append(errL, fmt.Errorf("SetPfVfLinkNetNs(): unable to set physical function '%s' network namespace: '%s'", pfPciAddr,
    67  			err.Error()))
    68  	}
    69  	ticker := time.NewTicker(pollInterval)
    70  	defer ticker.Stop()
    71  
    72  	for {
    73  		select {
    74  		case <-quitCh:
    75  			return
    76  		case <-ticker.C:
    77  			if err := setVfNetNs(pfPciAddr, targetNetNs); err != nil {
    78  				//save errors for returning but continue
    79  				errL = append(errL, err)
    80  			}
    81  		}
    82  	}
    83  }
    84  
    85  // setVfNetNs requires physical function (PF) PCI address and a handle to a network namespace in which to add
    86  // any associated virtual functions (VF). If no VFs are found, no error is returned.
    87  func setVfNetNs(pfPciAddr string, targetNetNs netns.NsHandle) error {
    88  	var (
    89  		err      error
    90  		link     netlink.Link
    91  		vfNetDir string
    92  	)
    93  	numVfsFile := filepath.Join(sysBusPci, pfPciAddr, sriovNumVfsFile)
    94  	if _, err := os.Lstat(numVfsFile); err != nil {
    95  		return fmt.Errorf("setVfNetNs(): unable to open '%s' from device with PCI address '%s': '%s", numVfsFile, pfPciAddr,
    96  			err.Error())
    97  	}
    98  
    99  	data, err := os.ReadFile(numVfsFile)
   100  	if err != nil {
   101  		return fmt.Errorf("setVfNetNs(): failed to read '%s' from device with PCI address  '%s': '%s", numVfsFile, pfPciAddr,
   102  			err.Error())
   103  	}
   104  
   105  	if len(data) == 0 {
   106  		return fmt.Errorf("setVfNetNs(): no data in file '%s'", numVfsFile)
   107  	}
   108  
   109  	vfTotal, err := strconv.Atoi(strings.TrimSpace(string(data)))
   110  	if err != nil {
   111  		return fmt.Errorf("setVfNetNs(): unable to convert file '%s' to integer: '%s'", numVfsFile, err.Error())
   112  	}
   113  
   114  	for vfNo := 0; vfNo <= vfTotal; vfNo++ {
   115  		vfNetDir = filepath.Join(sysBusPci, pfPciAddr, fmt.Sprintf("virtfn%d", vfNo), "net")
   116  		_, err = os.Lstat(vfNetDir)
   117  		if err != nil {
   118  			continue
   119  		}
   120  
   121  		fInfos, err := os.ReadDir(vfNetDir)
   122  		if err != nil {
   123  			return fmt.Errorf("setVfNetNs(): failed to read '%s': '%s'", vfNetDir, err.Error())
   124  		}
   125  		if len(fInfos) == 0 {
   126  			continue
   127  		}
   128  
   129  		for _, f := range fInfos {
   130  			link, err = netlink.LinkByName(f.Name())
   131  			if err != nil {
   132  				return fmt.Errorf("setVfNetNs(): unable to get link with name '%s' from directory '%s': '%s'", f.Name(),
   133  					vfNetDir, err.Error())
   134  			}
   135  			if err = netlink.LinkSetNsFd(link, int(targetNetNs)); err != nil {
   136  				return fmt.Errorf("setVfNetNs(): failed to set link '%s' netns: '%s'", f.Name(), err.Error())
   137  			}
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  // setLinkNetNs attempts to create a link object from PCI address and change the network namespace. Arg pciAddr must have
   144  // an associated interface name in the current network namespace or an error is thrown.
   145  func setLinkNetNs(pciAddr string, targetNetNs netns.NsHandle) error {
   146  	var err error
   147  	netDir := filepath.Join(sysBusPci, pciAddr, "net")
   148  	if _, err := os.Lstat(netDir); err != nil {
   149  		return fmt.Errorf("setLinkNetNs(): unable to find directory '%s': '%s'", netDir, err.Error())
   150  	}
   151  	fInfos, err := os.ReadDir(netDir)
   152  	if err != nil {
   153  		return fmt.Errorf("setLinkNetNs(): failed to read '%s': '%s'", netDir, err.Error())
   154  	}
   155  	if len(fInfos) == 0 {
   156  		return fmt.Errorf("setLinkNetNs(): no files found at directory '%s'", netDir)
   157  	}
   158  	link, err := netlink.LinkByName(fInfos[0].Name())
   159  	if err != nil {
   160  		return fmt.Errorf("setLinkNetNs(): failed to create link object by name '%s' from files at '%s': '%s'", fInfos[0].Name(),
   161  			netDir, err.Error())
   162  	}
   163  	err = netlink.LinkSetNsFd(link, int(targetNetNs))
   164  	if err != nil {
   165  		return fmt.Errorf("setLinkNetNs(): failed to set link '%s' netns: '%s'", link.Attrs().Name, err.Error())
   166  	}
   167  	return nil
   168  }