gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/gpu/main.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package main downloads and installs drivers. 16 package main 17 18 import ( 19 "context" 20 "fmt" 21 "os" 22 23 "gvisor.dev/gvisor/pkg/log" 24 "gvisor.dev/gvisor/pkg/sentry/devices/nvproxy" 25 "gvisor.dev/gvisor/runsc/flag" 26 "gvisor.dev/gvisor/tools/gpu/drivers" 27 ) 28 29 const ( 30 installCmdStr = "install" 31 installDescription = "installs a driver on the host machine" 32 checksumCmdStr = "checksum" 33 checksumDescription = "computes the sha256 checksum for a given driver version" 34 validateChecksumCmdStr = "validate_checksum" 35 validateChecksumDescription = "validates the checksum of all supported drivers" 36 listCmdStr = "list" 37 listDescription = "lists the supported drivers" 38 ) 39 40 var ( 41 // Install installs a give driver on the host machine. 42 installCmd = flag.NewFlagSet(installCmdStr, flag.ContinueOnError) 43 latest = installCmd.Bool("latest", false, "install the latest supported driver") 44 version = installCmd.String("version", "", "version of the driver") 45 46 // Computes the sha256 checksum for a given driver's .run file from the nvidia site. 47 checksumCmd = flag.NewFlagSet(checksumCmdStr, flag.ContinueOnError) 48 checksumVersion = checksumCmd.String("version", "", "version of the driver") 49 50 // Validates all supported driver's checksums of each driver's .run file from the nvidia site. 51 validateChecksumCmd = flag.NewFlagSet(validateChecksumCmdStr, flag.ContinueOnError) 52 53 // The list command returns the list of supported drivers from this tool. 54 listCmd = flag.NewFlagSet(listCmdStr, flag.ContinueOnError) 55 outfile = listCmd.String("outfile", "", "if set, write the list output to this file") 56 57 commandSet = map[*flag.FlagSet]string{ 58 installCmd: installDescription, 59 checksumCmd: checksumDescription, 60 validateChecksumCmd: validateChecksumDescription, 61 listCmd: listDescription, 62 } 63 ) 64 65 // printUsage prints the top level usage string. 66 func printUsage() { 67 usage := `Usage: main <command> <flags> ... 68 69 Available commands:` 70 fmt.Println(usage) 71 for _, f := range []*flag.FlagSet{installCmd, checksumCmd, validateChecksumCmd, listCmd} { 72 fmt.Printf("%s %s\n", f.Name(), commandSet[f]) 73 f.PrintDefaults() 74 } 75 } 76 77 func main() { 78 ctx := context.Background() 79 if len(os.Args) < 2 { 80 printUsage() 81 os.Exit(1) 82 } 83 nvproxy.Init() 84 switch os.Args[1] { 85 case installCmdStr: 86 if err := installCmd.Parse(os.Args[2:]); err != nil { 87 log.Warningf("%s failed with: %v", installCmdStr, err) 88 os.Exit(1) 89 } 90 installer, err := drivers.NewInstaller(*version, *latest) 91 if err != nil { 92 log.Warningf("Failed to create installer: %v", err.Error()) 93 os.Exit(1) 94 } 95 if err := installer.MaybeInstall(ctx); err != nil { 96 log.Warningf("Failed to install driver: %v", err.Error()) 97 os.Exit(1) 98 } 99 case checksumCmdStr: 100 if err := checksumCmd.Parse(os.Args[2:]); err != nil { 101 log.Warningf("%s failed with: %v", checksumCmdStr, err) 102 os.Exit(1) 103 } 104 105 checksum, err := drivers.ChecksumDriver(ctx, *checksumVersion) 106 if err != nil { 107 log.Warningf("Failed to compute checksum: %v", err) 108 os.Exit(1) 109 } 110 fmt.Printf("Checksum: %q\n", checksum) 111 case validateChecksumCmdStr: 112 if err := validateChecksumCmd.Parse(os.Args[2:]); err != nil { 113 log.Warningf("%s failed with: %v", validateChecksumCmdStr, err) 114 os.Exit(1) 115 } 116 117 nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) { 118 wantChecksum, err := drivers.ChecksumDriver(ctx, version.String()) 119 if err != nil { 120 log.Warningf("error on version %q: %v", version.String(), err) 121 return 122 } 123 if checksum != wantChecksum { 124 log.Warningf("Checksum mismatch on driver %q got: %q want: %q", version.String(), checksum, wantChecksum) 125 return 126 } 127 log.Infof("Checksum matched on driver %q.", version.String()) 128 }) 129 case listCmdStr: 130 if err := listCmd.Parse(os.Args[2:]); err != nil { 131 log.Warningf("%s failed with: %v", listCmdStr, err) 132 os.Exit(1) 133 } 134 if err := drivers.ListSupportedDrivers(*outfile); err != nil { 135 log.Warningf("Failed to list drivers: %v", err) 136 os.Exit(1) 137 } 138 default: 139 printUsage() 140 os.Exit(1) 141 } 142 }