// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package drivers contains methods to download and install drivers.
package drivers

import (
	"context"
	"crypto/sha256"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"strings"

	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/sentry/devices/nvproxy"
)

const (
	// nvidiaSMIPath is the nvidia-smi binary that is stat'ed and run to
	// discover the installed driver version, and run again after an install.
	nvidiaSMIPath = "/usr/bin/nvidia-smi"
	// nvidiaUninstallPath is the binary run to remove an existing driver.
	nvidiaUninstallPath = "/usr/bin/nvidia-uninstall"
	// nvidiaBaseURL is the prefix that driver download URLs are built from.
	nvidiaBaseURL = "https://us.download.nvidia.com/tesla/"
)

// init runs nvproxy.Init when the package is loaded.
// NOTE(review): presumably this populates the supported-driver tables that
// the nvproxy lookups in this package rely on; confirm against nvproxy.
func init() {
	nvproxy.Init()
}

// Installer handles the logic to install drivers.
type Installer struct {
	// requestedVersion is the driver version that MaybeInstall installs.
	requestedVersion nvproxy.DriverVersion

	// The steps below are held as function fields so they can be mocked in
	// tests.

	// expectedChecksumFunc returns the expected checksum for a driver
	// version and whether that version is supported.
	expectedChecksumFunc func(nvproxy.DriverVersion) (string, bool)
	// getCurrentDriverFunc returns the driver version currently installed.
	getCurrentDriverFunc func() (nvproxy.DriverVersion, error)
	// downloadFunc is called with the requested version string and returns
	// a stream of the downloaded installer.
	downloadFunc func(context.Context, string) (io.ReadCloser, error)
	// installFunc is called with the path of the downloaded installer file.
	installFunc func(string) error
}

// NewInstaller returns a driver installer instance for requestedVersion, or
// for nvproxy.LatestDriver() when latest is true (requestedVersion is then not
// parsed). It returns an error if requestedVersion cannot be parsed.
53 func NewInstaller(requestedVersion string, latest bool) (*Installer, error) { 54 ret := &Installer{ 55 expectedChecksumFunc: nvproxy.ExpectedDriverChecksum, 56 getCurrentDriverFunc: getCurrentDriver, 57 downloadFunc: DownloadDriver, 58 installFunc: installDriver, 59 } 60 switch { 61 case latest: 62 ret.requestedVersion = nvproxy.LatestDriver() 63 default: 64 d, err := nvproxy.DriverVersionFrom(requestedVersion) 65 if err != nil { 66 return nil, fmt.Errorf("failed to parse requested driver version: %w", err) 67 } 68 ret.requestedVersion = d 69 } 70 71 return ret, nil 72 } 73 74 // MaybeInstall installs a driver if 1) no driver is present on the system already or 2) the 75 // driver currently installed does not match the requested version. 76 func (i *Installer) MaybeInstall(ctx context.Context) error { 77 // If we don't support the driver, don't attempt to install it. 78 if _, ok := i.expectedChecksumFunc(i.requestedVersion); !ok { 79 return fmt.Errorf("requested driver %q is not supported", i.requestedVersion) 80 } 81 82 existingDriver, err := i.getCurrentDriverFunc() 83 if err != nil { 84 log.Warningf("failed to get current driver: %v", err) 85 } 86 if existingDriver.Equals(i.requestedVersion) { 87 log.Infof("Driver already installed: %s", i.requestedVersion) 88 return nil 89 } 90 91 if !existingDriver.Equals(nvproxy.DriverVersion{}) { 92 log.Infof("Uninstalling driver: %s", existingDriver) 93 if err := i.uninstallDriver(ctx, existingDriver.String()); err != nil { 94 return fmt.Errorf("failed to uninstall driver: %w", err) 95 } 96 log.Infof("Driver uninstalled: %s", i.requestedVersion) 97 } 98 99 log.Infof("Downloading driver: %s", i.requestedVersion) 100 reader, err := i.downloadFunc(ctx, i.requestedVersion.String()) 101 if err != nil { 102 return fmt.Errorf("failed to download driver: %w", err) 103 } 104 105 f, err := os.CreateTemp("", "") 106 if err != nil { 107 return fmt.Errorf("failed to open driver file: %w", err) 108 } 109 defer os.Remove(f.Name()) 110 
if err := i.writeAndCheck(f, reader); err != nil { 111 f.Close() 112 return fmt.Errorf("writeAndCheck: %w", err) 113 } 114 if err := f.Chmod(0755); err != nil { 115 return fmt.Errorf("failed to chmod: %w", err) 116 } 117 if err := f.Close(); err != nil { 118 return fmt.Errorf("failed to close driver file: %w", err) 119 } 120 log.Infof("Driver downloaded: %s", i.requestedVersion) 121 log.Infof("Installing driver: %s", i.requestedVersion) 122 if err := i.installFunc(f.Name()); err != nil { 123 return fmt.Errorf("failed to install driver: %w", err) 124 } 125 log.Infof("Installation Complete!") 126 return nil 127 } 128 129 func (i *Installer) uninstallDriver(ctx context.Context, driverVersion string) error { 130 exec.Command(nvidiaUninstallPath, "-s", driverVersion) 131 cmd := exec.Command(nvidiaUninstallPath, "-s") 132 cmd.Stdout = os.Stdout 133 cmd.Stderr = os.Stderr 134 if err := cmd.Run(); err != nil { 135 return fmt.Errorf("failed to run nvidia-uninstall: %w", err) 136 } 137 return nil 138 } 139 140 func (i *Installer) writeAndCheck(f *os.File, reader io.ReadCloser) error { 141 checksum := sha256.New() 142 buf := make([]byte, 1024*1024) 143 for { 144 n, err := reader.Read(buf[0:]) 145 if err != nil && err != io.EOF { 146 return fmt.Errorf("failed to read: %w", err) 147 } 148 if n == 0 || err == io.EOF { 149 break 150 } 151 if _, err := checksum.Write(buf[:n]); err != nil { 152 return fmt.Errorf("failed to write: %w", err) 153 } 154 if _, err := f.Write(buf[:n]); err != nil { 155 return fmt.Errorf("failed to write: %w", err) 156 } 157 } 158 gotChecksum := fmt.Sprintf("%x", checksum.Sum(nil)) 159 wantChecksum, ok := i.expectedChecksumFunc(i.requestedVersion) 160 if !ok { 161 return fmt.Errorf("requested driver %q is not supported", i.requestedVersion) 162 } 163 if gotChecksum != wantChecksum { 164 return fmt.Errorf("driver %q checksum mismatch: got %q, want %q", i.requestedVersion, gotChecksum, wantChecksum) 165 } 166 return nil 167 } 168 169 func getCurrentDriver() 
(nvproxy.DriverVersion, error) { 170 _, err := os.Stat(nvidiaSMIPath) 171 // If the nvidia-smi executable does not exist, then we don't have a driver installed. 172 if os.IsNotExist(err) { 173 return nvproxy.DriverVersion{}, fmt.Errorf("nvidia-smi does not exist at path: %q", nvidiaSMIPath) 174 } 175 if err != nil { 176 return nvproxy.DriverVersion{}, fmt.Errorf("failed to stat nvidia-smi: %w", err) 177 } 178 out, err := exec.Command(nvidiaSMIPath, []string{"--query-gpu", "driver_version", "--format=csv,noheader"}...).CombinedOutput() 179 if err != nil { 180 log.Warningf("failed to run nvidia-smi: %v", err) 181 return nvproxy.DriverVersion{}, fmt.Errorf("failed to run nvidia-smi: %w", err) 182 } 183 // If there are multiple GPUs, there will be one version per line. 184 // Make sure they are all the same version. 185 sameVersion := "" 186 for _, line := range strings.Split(string(out), "\n") { 187 line = strings.TrimSpace(line) 188 if line == "" { 189 continue 190 } 191 if sameVersion == "" { 192 sameVersion = line 193 continue 194 } 195 if line != sameVersion { 196 return nvproxy.DriverVersion{}, fmt.Errorf("multiple driver versions found: %q and %q", sameVersion, line) 197 } 198 } 199 if sameVersion == "" { 200 return nvproxy.DriverVersion{}, fmt.Errorf("no driver version found") 201 } 202 return nvproxy.DriverVersionFrom(sameVersion) 203 } 204 205 // ListSupportedDrivers prints the driver to stderr in a format that can be 206 // consumed by the Makefile to iterate tests across drivers. 
207 func ListSupportedDrivers(outfile string) error { 208 out := os.Stdout 209 if outfile != "" { 210 f, err := os.OpenFile(outfile, os.O_WRONLY, 0644) 211 if err != nil { 212 return fmt.Errorf("failed to open outfile: %w", err) 213 } 214 defer f.Close() 215 out = f 216 } 217 218 var list []string 219 nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) { 220 list = append(list, version.String()) 221 }) 222 if _, err := out.WriteString(strings.Join(list, " ") + "\n"); err != nil { 223 return fmt.Errorf("failed to write to outfile: %w", err) 224 } 225 return nil 226 } 227 228 // ChecksumDriver downloads and returns the SHA265 checksum of the driver. 229 func ChecksumDriver(ctx context.Context, driverVersion string) (string, error) { 230 f, err := DownloadDriver(ctx, driverVersion) 231 if err != nil { 232 return "", fmt.Errorf("failed to download driver: %w", err) 233 } 234 checksum := sha256.New() 235 for { 236 n, err := io.Copy(checksum, f) 237 if err == io.EOF || n == 0 { 238 break 239 } 240 if err != nil { 241 return "", fmt.Errorf("failed to copy driver: %w", err) 242 } 243 } 244 return fmt.Sprintf("%x", checksum.Sum(nil)), nil 245 } 246 247 // DownloadDriver downloads the requested driver and returns the binary as a []byte so it can be 248 // checked before written to disk. 249 func DownloadDriver(ctx context.Context, driverVersion string) (io.ReadCloser, error) { 250 url := fmt.Sprintf("%s%s/NVIDIA-Linux-x86_64-%s.run", nvidiaBaseURL, driverVersion, driverVersion) 251 resp, err := http.Get(url) 252 if err != nil { 253 return nil, fmt.Errorf("failed to download driver: %w", err) 254 } 255 if resp.StatusCode != http.StatusOK { 256 return nil, fmt.Errorf("failed to download driver with status: %w", err) 257 } 258 return resp.Body, nil 259 } 260 261 func installDriver(driverPath string) error { 262 // Certain VMs can be broken if we attempt to install drivers on them. 
Do a simple check of the 263 // PCI device to ensure we have a GPU attached. 264 out, err := exec.Command("lspci").CombinedOutput() 265 if err != nil { 266 return fmt.Errorf("failed to run lspci: %w out: %s", err, string(out)) 267 } 268 if !strings.Contains(string(out), "NVIDIA") { 269 return fmt.Errorf("no NVIDIA PCI device on host:\n%s", string(out)) 270 } 271 272 driverArgs := strings.Split("--dkms -a -s --no-drm --install-libglvnd", " ") 273 cmd := exec.Command(driverPath, driverArgs...) 274 cmd.Stdout = os.Stdout 275 cmd.Stderr = os.Stderr 276 if err := cmd.Run(); err != nil { 277 tryToPrintFailureLogs() 278 return fmt.Errorf("failed to run nvidia-install: %w out: %s", err, string(out)) 279 } 280 281 cmd = exec.Command(nvidiaSMIPath) 282 cmd.Stdout = os.Stdout 283 cmd.Stderr = os.Stderr 284 if err := cmd.Run(); err != nil { 285 return fmt.Errorf("failed to run nvidia-smi post install: %w out: %s", err, string(out)) 286 } 287 return nil 288 } 289 290 func tryToPrintFailureLogs() { 291 // nvidia driver installers print failure logs to this path. 292 const logPath = "/var/log/nvidia-installer.log" 293 f, err := os.OpenFile(logPath, os.O_RDONLY, 0644) 294 if err != nil { 295 log.Warningf("failed to stat nvidia-installer.log: %v", err) 296 return 297 } 298 defer f.Close() 299 300 out, err := io.ReadAll(f) 301 if err != nil { 302 log.Warningf("failed to read nvidia-installer.log: %v", err) 303 return 304 } 305 306 for _, line := range strings.Split(string(out), "\n") { 307 fmt.Printf("[nvidia-installer]: %s\n", line) 308 } 309 }