gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/gpu/drivers/install_driver.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package drivers contains methods to download and install drivers.
    16  package drivers
    17  
    18  import (
    19  	"context"
    20  	"crypto/sha256"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  	"os"
    25  	"os/exec"
    26  	"strings"
    27  
    28  	"gvisor.dev/gvisor/pkg/log"
    29  	"gvisor.dev/gvisor/pkg/sentry/devices/nvproxy"
    30  )
    31  
// Host paths of the NVIDIA tools this package runs, and the download root
// for driver installers:
//   - nvidiaSMIPath is stat'ed and executed by getCurrentDriver, and executed
//     again by installDriver after an install.
//   - nvidiaUninstallPath is executed by uninstallDriver.
//   - nvidiaBaseURL is the prefix DownloadDriver uses to build the installer
//     URL ("<base><version>/NVIDIA-Linux-x86_64-<version>.run").
const (
	nvidiaSMIPath       = "/usr/bin/nvidia-smi"
	nvidiaUninstallPath = "/usr/bin/nvidia-uninstall"
	nvidiaBaseURL       = "https://us.download.nvidia.com/tesla/"
)
    37  
// init calls nvproxy.Init when the package is loaded.
//
// NOTE(review): presumably this populates the driver tables that
// nvproxy.ExpectedDriverChecksum, nvproxy.LatestDriver and
// nvproxy.ForEachSupportDriver read later in this file; confirm against
// the nvproxy package.
func init() {
	nvproxy.Init()
}
    41  
// Installer handles the logic to install drivers.
//
// It stores the driver version to install together with one function hook per
// step that reaches outside the process. NewInstaller points the hooks at the
// real implementations in this package.
type Installer struct {
	// requestedVersion is the driver version MaybeInstall tries to install.
	requestedVersion nvproxy.DriverVersion
	// include functions so they can be mocked in tests.
	//
	// expectedChecksumFunc returns the expected installer checksum for a
	// version, and false when the version is not supported.
	// getCurrentDriverFunc reports the driver version currently installed.
	// downloadFunc fetches the installer for a version string and returns it
	// as a stream.
	// installFunc runs the installer located at the given file path.
	expectedChecksumFunc func(nvproxy.DriverVersion) (string, bool)
	getCurrentDriverFunc func() (nvproxy.DriverVersion, error)
	downloadFunc         func(context.Context, string) (io.ReadCloser, error)
	installFunc          func(string) error
}
    51  
    52  // NewInstaller returns a driver installer instance.
    53  func NewInstaller(requestedVersion string, latest bool) (*Installer, error) {
    54  	ret := &Installer{
    55  		expectedChecksumFunc: nvproxy.ExpectedDriverChecksum,
    56  		getCurrentDriverFunc: getCurrentDriver,
    57  		downloadFunc:         DownloadDriver,
    58  		installFunc:          installDriver,
    59  	}
    60  	switch {
    61  	case latest:
    62  		ret.requestedVersion = nvproxy.LatestDriver()
    63  	default:
    64  		d, err := nvproxy.DriverVersionFrom(requestedVersion)
    65  		if err != nil {
    66  			return nil, fmt.Errorf("failed to parse requested driver version: %w", err)
    67  		}
    68  		ret.requestedVersion = d
    69  	}
    70  
    71  	return ret, nil
    72  }
    73  
    74  // MaybeInstall installs a driver if 1) no driver is present on the system already or 2) the
    75  // driver currently installed does not match the requested version.
    76  func (i *Installer) MaybeInstall(ctx context.Context) error {
    77  	// If we don't support the driver, don't attempt to install it.
    78  	if _, ok := i.expectedChecksumFunc(i.requestedVersion); !ok {
    79  		return fmt.Errorf("requested driver %q is not supported", i.requestedVersion)
    80  	}
    81  
    82  	existingDriver, err := i.getCurrentDriverFunc()
    83  	if err != nil {
    84  		log.Warningf("failed to get current driver: %v", err)
    85  	}
    86  	if existingDriver.Equals(i.requestedVersion) {
    87  		log.Infof("Driver already installed: %s", i.requestedVersion)
    88  		return nil
    89  	}
    90  
    91  	if !existingDriver.Equals(nvproxy.DriverVersion{}) {
    92  		log.Infof("Uninstalling driver: %s", existingDriver)
    93  		if err := i.uninstallDriver(ctx, existingDriver.String()); err != nil {
    94  			return fmt.Errorf("failed to uninstall driver: %w", err)
    95  		}
    96  		log.Infof("Driver uninstalled: %s", i.requestedVersion)
    97  	}
    98  
    99  	log.Infof("Downloading driver: %s", i.requestedVersion)
   100  	reader, err := i.downloadFunc(ctx, i.requestedVersion.String())
   101  	if err != nil {
   102  		return fmt.Errorf("failed to download driver: %w", err)
   103  	}
   104  
   105  	f, err := os.CreateTemp("", "")
   106  	if err != nil {
   107  		return fmt.Errorf("failed to open driver file: %w", err)
   108  	}
   109  	defer os.Remove(f.Name())
   110  	if err := i.writeAndCheck(f, reader); err != nil {
   111  		f.Close()
   112  		return fmt.Errorf("writeAndCheck: %w", err)
   113  	}
   114  	if err := f.Chmod(0755); err != nil {
   115  		return fmt.Errorf("failed to chmod: %w", err)
   116  	}
   117  	if err := f.Close(); err != nil {
   118  		return fmt.Errorf("failed to close driver file: %w", err)
   119  	}
   120  	log.Infof("Driver downloaded: %s", i.requestedVersion)
   121  	log.Infof("Installing driver: %s", i.requestedVersion)
   122  	if err := i.installFunc(f.Name()); err != nil {
   123  		return fmt.Errorf("failed to install driver: %w", err)
   124  	}
   125  	log.Infof("Installation Complete!")
   126  	return nil
   127  }
   128  
   129  func (i *Installer) uninstallDriver(ctx context.Context, driverVersion string) error {
   130  	exec.Command(nvidiaUninstallPath, "-s", driverVersion)
   131  	cmd := exec.Command(nvidiaUninstallPath, "-s")
   132  	cmd.Stdout = os.Stdout
   133  	cmd.Stderr = os.Stderr
   134  	if err := cmd.Run(); err != nil {
   135  		return fmt.Errorf("failed to run nvidia-uninstall: %w", err)
   136  	}
   137  	return nil
   138  }
   139  
   140  func (i *Installer) writeAndCheck(f *os.File, reader io.ReadCloser) error {
   141  	checksum := sha256.New()
   142  	buf := make([]byte, 1024*1024)
   143  	for {
   144  		n, err := reader.Read(buf[0:])
   145  		if err != nil && err != io.EOF {
   146  			return fmt.Errorf("failed to read: %w", err)
   147  		}
   148  		if n == 0 || err == io.EOF {
   149  			break
   150  		}
   151  		if _, err := checksum.Write(buf[:n]); err != nil {
   152  			return fmt.Errorf("failed to write: %w", err)
   153  		}
   154  		if _, err := f.Write(buf[:n]); err != nil {
   155  			return fmt.Errorf("failed to write: %w", err)
   156  		}
   157  	}
   158  	gotChecksum := fmt.Sprintf("%x", checksum.Sum(nil))
   159  	wantChecksum, ok := i.expectedChecksumFunc(i.requestedVersion)
   160  	if !ok {
   161  		return fmt.Errorf("requested driver %q is not supported", i.requestedVersion)
   162  	}
   163  	if gotChecksum != wantChecksum {
   164  		return fmt.Errorf("driver %q checksum mismatch: got %q, want %q", i.requestedVersion, gotChecksum, wantChecksum)
   165  	}
   166  	return nil
   167  }
   168  
   169  func getCurrentDriver() (nvproxy.DriverVersion, error) {
   170  	_, err := os.Stat(nvidiaSMIPath)
   171  	// If the nvidia-smi executable does not exist, then we don't have a driver installed.
   172  	if os.IsNotExist(err) {
   173  		return nvproxy.DriverVersion{}, fmt.Errorf("nvidia-smi does not exist at path: %q", nvidiaSMIPath)
   174  	}
   175  	if err != nil {
   176  		return nvproxy.DriverVersion{}, fmt.Errorf("failed to stat nvidia-smi: %w", err)
   177  	}
   178  	out, err := exec.Command(nvidiaSMIPath, []string{"--query-gpu", "driver_version", "--format=csv,noheader"}...).CombinedOutput()
   179  	if err != nil {
   180  		log.Warningf("failed to run nvidia-smi: %v", err)
   181  		return nvproxy.DriverVersion{}, fmt.Errorf("failed to run nvidia-smi: %w", err)
   182  	}
   183  	// If there are multiple GPUs, there will be one version per line.
   184  	// Make sure they are all the same version.
   185  	sameVersion := ""
   186  	for _, line := range strings.Split(string(out), "\n") {
   187  		line = strings.TrimSpace(line)
   188  		if line == "" {
   189  			continue
   190  		}
   191  		if sameVersion == "" {
   192  			sameVersion = line
   193  			continue
   194  		}
   195  		if line != sameVersion {
   196  			return nvproxy.DriverVersion{}, fmt.Errorf("multiple driver versions found: %q and %q", sameVersion, line)
   197  		}
   198  	}
   199  	if sameVersion == "" {
   200  		return nvproxy.DriverVersion{}, fmt.Errorf("no driver version found")
   201  	}
   202  	return nvproxy.DriverVersionFrom(sameVersion)
   203  }
   204  
   205  // ListSupportedDrivers prints the driver to stderr in a format that can be
   206  // consumed by the Makefile to iterate tests across drivers.
   207  func ListSupportedDrivers(outfile string) error {
   208  	out := os.Stdout
   209  	if outfile != "" {
   210  		f, err := os.OpenFile(outfile, os.O_WRONLY, 0644)
   211  		if err != nil {
   212  			return fmt.Errorf("failed to open outfile: %w", err)
   213  		}
   214  		defer f.Close()
   215  		out = f
   216  	}
   217  
   218  	var list []string
   219  	nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) {
   220  		list = append(list, version.String())
   221  	})
   222  	if _, err := out.WriteString(strings.Join(list, " ") + "\n"); err != nil {
   223  		return fmt.Errorf("failed to write to outfile: %w", err)
   224  	}
   225  	return nil
   226  }
   227  
   228  // ChecksumDriver downloads and returns the SHA265 checksum of the driver.
   229  func ChecksumDriver(ctx context.Context, driverVersion string) (string, error) {
   230  	f, err := DownloadDriver(ctx, driverVersion)
   231  	if err != nil {
   232  		return "", fmt.Errorf("failed to download driver: %w", err)
   233  	}
   234  	checksum := sha256.New()
   235  	for {
   236  		n, err := io.Copy(checksum, f)
   237  		if err == io.EOF || n == 0 {
   238  			break
   239  		}
   240  		if err != nil {
   241  			return "", fmt.Errorf("failed to copy driver: %w", err)
   242  		}
   243  	}
   244  	return fmt.Sprintf("%x", checksum.Sum(nil)), nil
   245  }
   246  
   247  // DownloadDriver downloads the requested driver and returns the binary as a []byte so it can be
   248  // checked before written to disk.
   249  func DownloadDriver(ctx context.Context, driverVersion string) (io.ReadCloser, error) {
   250  	url := fmt.Sprintf("%s%s/NVIDIA-Linux-x86_64-%s.run", nvidiaBaseURL, driverVersion, driverVersion)
   251  	resp, err := http.Get(url)
   252  	if err != nil {
   253  		return nil, fmt.Errorf("failed to download driver: %w", err)
   254  	}
   255  	if resp.StatusCode != http.StatusOK {
   256  		return nil, fmt.Errorf("failed to download driver with status: %w", err)
   257  	}
   258  	return resp.Body, nil
   259  }
   260  
   261  func installDriver(driverPath string) error {
   262  	// Certain VMs can be broken if we attempt to install drivers on them. Do a simple check of the
   263  	// PCI device to ensure we have a GPU attached.
   264  	out, err := exec.Command("lspci").CombinedOutput()
   265  	if err != nil {
   266  		return fmt.Errorf("failed to run lspci: %w out: %s", err, string(out))
   267  	}
   268  	if !strings.Contains(string(out), "NVIDIA") {
   269  		return fmt.Errorf("no NVIDIA PCI device on host:\n%s", string(out))
   270  	}
   271  
   272  	driverArgs := strings.Split("--dkms -a -s --no-drm --install-libglvnd", " ")
   273  	cmd := exec.Command(driverPath, driverArgs...)
   274  	cmd.Stdout = os.Stdout
   275  	cmd.Stderr = os.Stderr
   276  	if err := cmd.Run(); err != nil {
   277  		tryToPrintFailureLogs()
   278  		return fmt.Errorf("failed to run nvidia-install: %w out: %s", err, string(out))
   279  	}
   280  
   281  	cmd = exec.Command(nvidiaSMIPath)
   282  	cmd.Stdout = os.Stdout
   283  	cmd.Stderr = os.Stderr
   284  	if err := cmd.Run(); err != nil {
   285  		return fmt.Errorf("failed to run nvidia-smi post install: %w out: %s", err, string(out))
   286  	}
   287  	return nil
   288  }
   289  
   290  func tryToPrintFailureLogs() {
   291  	// nvidia driver installers print failure logs to this path.
   292  	const logPath = "/var/log/nvidia-installer.log"
   293  	f, err := os.OpenFile(logPath, os.O_RDONLY, 0644)
   294  	if err != nil {
   295  		log.Warningf("failed to stat nvidia-installer.log: %v", err)
   296  		return
   297  	}
   298  	defer f.Close()
   299  
   300  	out, err := io.ReadAll(f)
   301  	if err != nil {
   302  		log.Warningf("failed to read nvidia-installer.log: %v", err)
   303  		return
   304  	}
   305  
   306  	for _, line := range strings.Split(string(out), "\n") {
   307  		fmt.Printf("[nvidia-installer]: %s\n", line)
   308  	}
   309  }