github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/internal/runner/cuda_linux.go (about)

     1  // +build !NO_CUDA
     2  
     3  // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     4  
     5  package runner
     6  
     7  // This file contains the implementation and interface code for the CUDA capable devices
     8  // that are provisioned on a system
     9  
    10  import (
    11  	"fmt"
    12  	"sync"
    13  
    14  	"github.com/go-stack/stack"
    15  	"github.com/jjeffery/kv" // MIT License
    16  
    17  	nvml "github.com/karlmutch/go-nvml" // MIT License
    18  )
    19  
    20  var (
    21  	initErr  kv.Error
    22  	nvmlOnce sync.Once
    23  
    24  	nvmlInit = func() {
    25  		if errGo := nvml.NVMLInit(); errGo != nil {
    26  			initErr = kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
    27  
    28  			return
    29  		}
    30  
    31  		// If the cuda management layer started and is working then check
    32  		// what hardware capabilities exist and print warning etc if needed as the server is started
    33  		devs, errGo := nvml.GetAllGPUs()
    34  		if errGo != nil {
    35  			fmt.Println(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
    36  			return
    37  		}
    38  		for _, dev := range devs {
    39  			name, _ := dev.Name()
    40  
    41  			uuid, errGo := dev.UUID()
    42  			if errGo != nil {
    43  				fmt.Println(kv.Wrap(errGo).With("name", name).With("stack", stack.Trace().TrimRuntime()))
    44  				continue
    45  			}
    46  
    47  			if _, errGo = dev.MemoryInfo(); errGo != nil {
    48  				fmt.Println(kv.Wrap(errGo).With("name", name).With("GPUID", uuid).With("stack", stack.Trace().TrimRuntime()))
    49  				continue
    50  			}
    51  
    52  			if errEcc := dev.EccErrors(); errEcc != nil {
    53  				fmt.Println(kv.Wrap(errEcc).With("name", name).With("GPUID", uuid).With("stack", stack.Trace().TrimRuntime()))
    54  				continue
    55  			}
    56  		}
    57  	}
    58  )
    59  
    60  // HasCUDA allows an external package to test for the presence of CUDA support
    61  // in the go code of this package
    62  func HasCUDA() bool {
    63  	nvmlOnce.Do(nvmlInit)
    64  	return true
    65  }
    66  
    67  func getCUDAInfo() (outDevs cudaDevices, err kv.Error) {
    68  
    69  	nvmlOnce.Do(nvmlInit)
    70  
    71  	outDevs = cudaDevices{
    72  		Devices: []device{},
    73  	}
    74  
    75  	// Dont let the GetAllGPUs log a fatal error catch it first
    76  	if initErr != nil {
    77  		return outDevs, initErr
    78  	}
    79  
    80  	devs, errGo := nvml.GetAllGPUs()
    81  	if errGo != nil {
    82  		return outDevs, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
    83  	}
    84  
    85  	for _, dev := range devs {
    86  
    87  		name, _ := dev.Name()
    88  
    89  		uuid, errGo := dev.UUID()
    90  		if errGo != nil {
    91  			return outDevs, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
    92  		}
    93  
    94  		temp, _ := dev.Temp()
    95  		powr, _ := dev.PowerUsage()
    96  
    97  		mem, errGo := dev.MemoryInfo()
    98  		if errGo != nil {
    99  			return outDevs, kv.Wrap(errGo).With("GPUID", uuid).With("stack", stack.Trace().TrimRuntime())
   100  		}
   101  
   102  		runnerDev := device{
   103  			Name:    name,
   104  			UUID:    uuid,
   105  			Temp:    temp,
   106  			Powr:    powr,
   107  			MemTot:  mem.Total,
   108  			MemUsed: mem.Used,
   109  			MemFree: mem.Free,
   110  		}
   111  		// Dont use the ECC Error check on AWS as the NVML APIs do not appear to return the expected values
   112  		if isAWS, _ := IsAWS(); !isAWS && !CudaInTest {
   113  			_, _, errGo := dev.EccCounts()
   114  			if errGo != nil && errGo.Error() != "nvmlDeviceGetMemoryErrorCounter is not supported on this hardware" {
   115  				if errEcc := dev.EccVolatileErrors(); errEcc != nil {
   116  					err := kv.Wrap(errEcc).With("stack", stack.Trace().TrimRuntime())
   117  					runnerDev.EccFailure = &err
   118  				}
   119  			}
   120  		}
   121  		outDevs.Devices = append(outDevs.Devices, runnerDev)
   122  	}
   123  	return outDevs, nil
   124  }