gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/gpu/main.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package main downloads and installs drivers.
    16  package main
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"os"
    22  
    23  	"gvisor.dev/gvisor/pkg/log"
    24  	"gvisor.dev/gvisor/pkg/sentry/devices/nvproxy"
    25  	"gvisor.dev/gvisor/runsc/flag"
    26  	"gvisor.dev/gvisor/tools/gpu/drivers"
    27  )
    28  
    29  const (
    30  	installCmdStr               = "install"
    31  	installDescription          = "installs a driver on the host machine"
    32  	checksumCmdStr              = "checksum"
    33  	checksumDescription         = "computes the sha256 checksum for a given driver version"
    34  	validateChecksumCmdStr      = "validate_checksum"
    35  	validateChecksumDescription = "validates the checksum of all supported drivers"
    36  	listCmdStr                  = "list"
    37  	listDescription             = "lists the supported drivers"
    38  )
    39  
    40  var (
    41  	// Install installs a give driver on the host machine.
    42  	installCmd = flag.NewFlagSet(installCmdStr, flag.ContinueOnError)
    43  	latest     = installCmd.Bool("latest", false, "install the latest supported driver")
    44  	version    = installCmd.String("version", "", "version of the driver")
    45  
    46  	// Computes the sha256 checksum for a given driver's .run file from the nvidia site.
    47  	checksumCmd     = flag.NewFlagSet(checksumCmdStr, flag.ContinueOnError)
    48  	checksumVersion = checksumCmd.String("version", "", "version of the driver")
    49  
    50  	// Validates all supported driver's checksums of each driver's .run file from the nvidia site.
    51  	validateChecksumCmd = flag.NewFlagSet(validateChecksumCmdStr, flag.ContinueOnError)
    52  
    53  	// The list command returns the list of supported drivers from this tool.
    54  	listCmd = flag.NewFlagSet(listCmdStr, flag.ContinueOnError)
    55  	outfile = listCmd.String("outfile", "", "if set, write the list output to this file")
    56  
    57  	commandSet = map[*flag.FlagSet]string{
    58  		installCmd:          installDescription,
    59  		checksumCmd:         checksumDescription,
    60  		validateChecksumCmd: validateChecksumDescription,
    61  		listCmd:             listDescription,
    62  	}
    63  )
    64  
    65  // printUsage prints the top level usage string.
    66  func printUsage() {
    67  	usage := `Usage: main <command> <flags> ...
    68  
    69  Available commands:`
    70  	fmt.Println(usage)
    71  	for _, f := range []*flag.FlagSet{installCmd, checksumCmd, validateChecksumCmd, listCmd} {
    72  		fmt.Printf("%s	%s\n", f.Name(), commandSet[f])
    73  		f.PrintDefaults()
    74  	}
    75  }
    76  
    77  func main() {
    78  	ctx := context.Background()
    79  	if len(os.Args) < 2 {
    80  		printUsage()
    81  		os.Exit(1)
    82  	}
    83  	nvproxy.Init()
    84  	switch os.Args[1] {
    85  	case installCmdStr:
    86  		if err := installCmd.Parse(os.Args[2:]); err != nil {
    87  			log.Warningf("%s failed with: %v", installCmdStr, err)
    88  			os.Exit(1)
    89  		}
    90  		installer, err := drivers.NewInstaller(*version, *latest)
    91  		if err != nil {
    92  			log.Warningf("Failed to create installer: %v", err.Error())
    93  			os.Exit(1)
    94  		}
    95  		if err := installer.MaybeInstall(ctx); err != nil {
    96  			log.Warningf("Failed to install driver: %v", err.Error())
    97  			os.Exit(1)
    98  		}
    99  	case checksumCmdStr:
   100  		if err := checksumCmd.Parse(os.Args[2:]); err != nil {
   101  			log.Warningf("%s failed with: %v", checksumCmdStr, err)
   102  			os.Exit(1)
   103  		}
   104  
   105  		checksum, err := drivers.ChecksumDriver(ctx, *checksumVersion)
   106  		if err != nil {
   107  			log.Warningf("Failed to compute checksum: %v", err)
   108  			os.Exit(1)
   109  		}
   110  		fmt.Printf("Checksum: %q\n", checksum)
   111  	case validateChecksumCmdStr:
   112  		if err := validateChecksumCmd.Parse(os.Args[2:]); err != nil {
   113  			log.Warningf("%s failed with: %v", validateChecksumCmdStr, err)
   114  			os.Exit(1)
   115  		}
   116  
   117  		nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) {
   118  			wantChecksum, err := drivers.ChecksumDriver(ctx, version.String())
   119  			if err != nil {
   120  				log.Warningf("error on version %q: %v", version.String(), err)
   121  				return
   122  			}
   123  			if checksum != wantChecksum {
   124  				log.Warningf("Checksum mismatch on driver %q got: %q want: %q", version.String(), checksum, wantChecksum)
   125  				return
   126  			}
   127  			log.Infof("Checksum matched on driver %q.", version.String())
   128  		})
   129  	case listCmdStr:
   130  		if err := listCmd.Parse(os.Args[2:]); err != nil {
   131  			log.Warningf("%s failed with: %v", listCmdStr, err)
   132  			os.Exit(1)
   133  		}
   134  		if err := drivers.ListSupportedDrivers(*outfile); err != nil {
   135  			log.Warningf("Failed to list drivers: %v", err)
   136  			os.Exit(1)
   137  		}
   138  	default:
   139  		printUsage()
   140  		os.Exit(1)
   141  	}
   142  }