github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/devices/gpu/nvidia/nvml/client_test.go (about)

     1  package nvml
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/hashicorp/nomad/helper"
     8  	"github.com/stretchr/testify/require"
     9  )
    10  
// MockNVMLDriver is a configurable in-memory test double for the NVML driver
// used by nvmlClient in these tests.
//
// Each *CallSuccessful flag controls whether the method of the matching name
// succeeds: when a flag is false the method returns an error instead of data.
// driverVersion is the value returned by SystemDriverVersion, devices backs
// DeviceCount, DeviceInfoByIndex and DeviceInfoAndStatusByIndex, and
// deviceStatus supplies the per-index status returned by
// DeviceInfoAndStatusByIndex.
type MockNVMLDriver struct {
	systemDriverCallSuccessful               bool
	deviceCountCallSuccessful                bool
	deviceInfoByIndexCallSuccessful          bool
	deviceInfoAndStatusByIndexCallSuccessful bool
	driverVersion                            string
	devices                                  []*DeviceInfo
	deviceStatus                             []*DeviceStatus
}
    20  
// Initialize is a no-op for the mock and always returns nil.
func (m *MockNVMLDriver) Initialize() error {
	return nil
}
    24  
// Shutdown is a no-op for the mock and always returns nil.
func (m *MockNVMLDriver) Shutdown() error {
	return nil
}
    28  
    29  func (m *MockNVMLDriver) SystemDriverVersion() (string, error) {
    30  	if !m.systemDriverCallSuccessful {
    31  		return "", errors.New("failed to get system driver")
    32  	}
    33  	return m.driverVersion, nil
    34  }
    35  
    36  func (m *MockNVMLDriver) DeviceCount() (uint, error) {
    37  	if !m.deviceCountCallSuccessful {
    38  		return 0, errors.New("failed to get device length")
    39  	}
    40  	return uint(len(m.devices)), nil
    41  }
    42  
    43  func (m *MockNVMLDriver) DeviceInfoByIndex(index uint) (*DeviceInfo, error) {
    44  	if index >= uint(len(m.devices)) {
    45  		return nil, errors.New("index is out of range")
    46  	}
    47  	if !m.deviceInfoByIndexCallSuccessful {
    48  		return nil, errors.New("failed to get device info by index")
    49  	}
    50  	return m.devices[index], nil
    51  }
    52  
    53  func (m *MockNVMLDriver) DeviceInfoAndStatusByIndex(index uint) (*DeviceInfo, *DeviceStatus, error) {
    54  	if index >= uint(len(m.devices)) || index >= uint(len(m.deviceStatus)) {
    55  		return nil, nil, errors.New("index is out of range")
    56  	}
    57  	if !m.deviceInfoAndStatusByIndexCallSuccessful {
    58  		return nil, nil, errors.New("failed to get device info and status by index")
    59  	}
    60  	return m.devices[index], m.deviceStatus[index], nil
    61  }
    62  
    63  func TestGetFingerprintDataFromNVML(t *testing.T) {
    64  	for _, testCase := range []struct {
    65  		Name                string
    66  		DriverConfiguration *MockNVMLDriver
    67  		ExpectedError       bool
    68  		ExpectedResult      *FingerprintData
    69  	}{
    70  		{
    71  			Name:           "fail on systemDriverCallSuccessful",
    72  			ExpectedError:  true,
    73  			ExpectedResult: nil,
    74  			DriverConfiguration: &MockNVMLDriver{
    75  				systemDriverCallSuccessful:      false,
    76  				deviceCountCallSuccessful:       true,
    77  				deviceInfoByIndexCallSuccessful: true,
    78  			},
    79  		},
    80  		{
    81  			Name:           "fail on deviceCountCallSuccessful",
    82  			ExpectedError:  true,
    83  			ExpectedResult: nil,
    84  			DriverConfiguration: &MockNVMLDriver{
    85  				systemDriverCallSuccessful:      true,
    86  				deviceCountCallSuccessful:       false,
    87  				deviceInfoByIndexCallSuccessful: true,
    88  			},
    89  		},
    90  		{
    91  			Name:           "fail on deviceInfoByIndexCall",
    92  			ExpectedError:  true,
    93  			ExpectedResult: nil,
    94  			DriverConfiguration: &MockNVMLDriver{
    95  				systemDriverCallSuccessful:      true,
    96  				deviceCountCallSuccessful:       true,
    97  				deviceInfoByIndexCallSuccessful: false,
    98  				devices: []*DeviceInfo{
    99  					{
   100  						UUID:               "UUID1",
   101  						Name:               helper.StringToPtr("ModelName1"),
   102  						MemoryMiB:          helper.Uint64ToPtr(16),
   103  						PCIBusID:           "busId",
   104  						PowerW:             helper.UintToPtr(100),
   105  						BAR1MiB:            helper.Uint64ToPtr(100),
   106  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   107  						CoresClockMHz:      helper.UintToPtr(100),
   108  						MemoryClockMHz:     helper.UintToPtr(100),
   109  					}, {
   110  						UUID:               "UUID2",
   111  						Name:               helper.StringToPtr("ModelName2"),
   112  						MemoryMiB:          helper.Uint64ToPtr(8),
   113  						PCIBusID:           "busId",
   114  						PowerW:             helper.UintToPtr(100),
   115  						BAR1MiB:            helper.Uint64ToPtr(100),
   116  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   117  						CoresClockMHz:      helper.UintToPtr(100),
   118  						MemoryClockMHz:     helper.UintToPtr(100),
   119  					},
   120  				},
   121  			},
   122  		},
   123  		{
   124  			Name:          "successful outcome",
   125  			ExpectedError: false,
   126  			ExpectedResult: &FingerprintData{
   127  				DriverVersion: "driverVersion",
   128  				Devices: []*FingerprintDeviceData{
   129  					{
   130  						DeviceData: &DeviceData{
   131  							DeviceName: helper.StringToPtr("ModelName1"),
   132  							UUID:       "UUID1",
   133  							MemoryMiB:  helper.Uint64ToPtr(16),
   134  							PowerW:     helper.UintToPtr(100),
   135  							BAR1MiB:    helper.Uint64ToPtr(100),
   136  						},
   137  						PCIBusID:           "busId1",
   138  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   139  						CoresClockMHz:      helper.UintToPtr(100),
   140  						MemoryClockMHz:     helper.UintToPtr(100),
   141  						DisplayState:       "Enabled",
   142  						PersistenceMode:    "Enabled",
   143  					}, {
   144  						DeviceData: &DeviceData{
   145  							DeviceName: helper.StringToPtr("ModelName2"),
   146  							UUID:       "UUID2",
   147  							MemoryMiB:  helper.Uint64ToPtr(8),
   148  							PowerW:     helper.UintToPtr(200),
   149  							BAR1MiB:    helper.Uint64ToPtr(200),
   150  						},
   151  						PCIBusID:           "busId2",
   152  						PCIBandwidthMBPerS: helper.UintToPtr(200),
   153  						CoresClockMHz:      helper.UintToPtr(200),
   154  						MemoryClockMHz:     helper.UintToPtr(200),
   155  						DisplayState:       "Enabled",
   156  						PersistenceMode:    "Enabled",
   157  					},
   158  				},
   159  			},
   160  			DriverConfiguration: &MockNVMLDriver{
   161  				systemDriverCallSuccessful:      true,
   162  				deviceCountCallSuccessful:       true,
   163  				deviceInfoByIndexCallSuccessful: true,
   164  				driverVersion:                   "driverVersion",
   165  				devices: []*DeviceInfo{
   166  					{
   167  						UUID:               "UUID1",
   168  						Name:               helper.StringToPtr("ModelName1"),
   169  						MemoryMiB:          helper.Uint64ToPtr(16),
   170  						PCIBusID:           "busId1",
   171  						PowerW:             helper.UintToPtr(100),
   172  						BAR1MiB:            helper.Uint64ToPtr(100),
   173  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   174  						CoresClockMHz:      helper.UintToPtr(100),
   175  						MemoryClockMHz:     helper.UintToPtr(100),
   176  						DisplayState:       "Enabled",
   177  						PersistenceMode:    "Enabled",
   178  					}, {
   179  						UUID:               "UUID2",
   180  						Name:               helper.StringToPtr("ModelName2"),
   181  						MemoryMiB:          helper.Uint64ToPtr(8),
   182  						PCIBusID:           "busId2",
   183  						PowerW:             helper.UintToPtr(200),
   184  						BAR1MiB:            helper.Uint64ToPtr(200),
   185  						PCIBandwidthMBPerS: helper.UintToPtr(200),
   186  						CoresClockMHz:      helper.UintToPtr(200),
   187  						MemoryClockMHz:     helper.UintToPtr(200),
   188  						DisplayState:       "Enabled",
   189  						PersistenceMode:    "Enabled",
   190  					},
   191  				},
   192  			},
   193  		},
   194  	} {
   195  		cli := nvmlClient{driver: testCase.DriverConfiguration}
   196  		fingerprintData, err := cli.GetFingerprintData()
   197  		if testCase.ExpectedError && err == nil {
   198  			t.Errorf("case '%s' : expected Error, but didn't get one", testCase.Name)
   199  		}
   200  		if !testCase.ExpectedError && err != nil {
   201  			t.Errorf("case '%s' : unexpected Error '%v'", testCase.Name, err)
   202  		}
   203  		require.New(t).Equal(testCase.ExpectedResult, fingerprintData)
   204  	}
   205  }
   206  
   207  func TestGetStatsDataFromNVML(t *testing.T) {
   208  	for _, testCase := range []struct {
   209  		Name                string
   210  		DriverConfiguration *MockNVMLDriver
   211  		ExpectedError       bool
   212  		ExpectedResult      []*StatsData
   213  	}{
   214  		{
   215  			Name:           "fail on deviceCountCallSuccessful",
   216  			ExpectedError:  true,
   217  			ExpectedResult: nil,
   218  			DriverConfiguration: &MockNVMLDriver{
   219  				systemDriverCallSuccessful:               true,
   220  				deviceCountCallSuccessful:                false,
   221  				deviceInfoByIndexCallSuccessful:          true,
   222  				deviceInfoAndStatusByIndexCallSuccessful: true,
   223  			},
   224  		},
   225  		{
   226  			Name:           "fail on DeviceInfoAndStatusByIndex call",
   227  			ExpectedError:  true,
   228  			ExpectedResult: nil,
   229  			DriverConfiguration: &MockNVMLDriver{
   230  				systemDriverCallSuccessful:               true,
   231  				deviceCountCallSuccessful:                true,
   232  				deviceInfoAndStatusByIndexCallSuccessful: false,
   233  				devices: []*DeviceInfo{
   234  					{
   235  						UUID:               "UUID1",
   236  						Name:               helper.StringToPtr("ModelName1"),
   237  						MemoryMiB:          helper.Uint64ToPtr(16),
   238  						PCIBusID:           "busId1",
   239  						PowerW:             helper.UintToPtr(100),
   240  						BAR1MiB:            helper.Uint64ToPtr(100),
   241  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   242  						CoresClockMHz:      helper.UintToPtr(100),
   243  						MemoryClockMHz:     helper.UintToPtr(100),
   244  					}, {
   245  						UUID:               "UUID2",
   246  						Name:               helper.StringToPtr("ModelName2"),
   247  						MemoryMiB:          helper.Uint64ToPtr(8),
   248  						PCIBusID:           "busId2",
   249  						PowerW:             helper.UintToPtr(200),
   250  						BAR1MiB:            helper.Uint64ToPtr(200),
   251  						PCIBandwidthMBPerS: helper.UintToPtr(200),
   252  						CoresClockMHz:      helper.UintToPtr(200),
   253  						MemoryClockMHz:     helper.UintToPtr(200),
   254  					},
   255  				},
   256  				deviceStatus: []*DeviceStatus{
   257  					{
   258  						TemperatureC:       helper.UintToPtr(1),
   259  						GPUUtilization:     helper.UintToPtr(1),
   260  						MemoryUtilization:  helper.UintToPtr(1),
   261  						EncoderUtilization: helper.UintToPtr(1),
   262  						DecoderUtilization: helper.UintToPtr(1),
   263  						UsedMemoryMiB:      helper.Uint64ToPtr(1),
   264  						ECCErrorsL1Cache:   helper.Uint64ToPtr(1),
   265  						ECCErrorsL2Cache:   helper.Uint64ToPtr(1),
   266  						ECCErrorsDevice:    helper.Uint64ToPtr(1),
   267  						PowerUsageW:        helper.UintToPtr(1),
   268  						BAR1UsedMiB:        helper.Uint64ToPtr(1),
   269  					},
   270  					{
   271  						TemperatureC:       helper.UintToPtr(2),
   272  						GPUUtilization:     helper.UintToPtr(2),
   273  						MemoryUtilization:  helper.UintToPtr(2),
   274  						EncoderUtilization: helper.UintToPtr(2),
   275  						DecoderUtilization: helper.UintToPtr(2),
   276  						UsedMemoryMiB:      helper.Uint64ToPtr(2),
   277  						ECCErrorsL1Cache:   helper.Uint64ToPtr(2),
   278  						ECCErrorsL2Cache:   helper.Uint64ToPtr(2),
   279  						ECCErrorsDevice:    helper.Uint64ToPtr(2),
   280  						PowerUsageW:        helper.UintToPtr(2),
   281  						BAR1UsedMiB:        helper.Uint64ToPtr(2),
   282  					},
   283  				},
   284  			},
   285  		},
   286  		{
   287  			Name:          "successful outcome",
   288  			ExpectedError: false,
   289  			ExpectedResult: []*StatsData{
   290  				{
   291  					DeviceData: &DeviceData{
   292  						DeviceName: helper.StringToPtr("ModelName1"),
   293  						UUID:       "UUID1",
   294  						MemoryMiB:  helper.Uint64ToPtr(16),
   295  						PowerW:     helper.UintToPtr(100),
   296  						BAR1MiB:    helper.Uint64ToPtr(100),
   297  					},
   298  					TemperatureC:       helper.UintToPtr(1),
   299  					GPUUtilization:     helper.UintToPtr(1),
   300  					MemoryUtilization:  helper.UintToPtr(1),
   301  					EncoderUtilization: helper.UintToPtr(1),
   302  					DecoderUtilization: helper.UintToPtr(1),
   303  					UsedMemoryMiB:      helper.Uint64ToPtr(1),
   304  					ECCErrorsL1Cache:   helper.Uint64ToPtr(1),
   305  					ECCErrorsL2Cache:   helper.Uint64ToPtr(1),
   306  					ECCErrorsDevice:    helper.Uint64ToPtr(1),
   307  					PowerUsageW:        helper.UintToPtr(1),
   308  					BAR1UsedMiB:        helper.Uint64ToPtr(1),
   309  				},
   310  				{
   311  					DeviceData: &DeviceData{
   312  						DeviceName: helper.StringToPtr("ModelName2"),
   313  						UUID:       "UUID2",
   314  						MemoryMiB:  helper.Uint64ToPtr(8),
   315  						PowerW:     helper.UintToPtr(200),
   316  						BAR1MiB:    helper.Uint64ToPtr(200),
   317  					},
   318  					TemperatureC:       helper.UintToPtr(2),
   319  					GPUUtilization:     helper.UintToPtr(2),
   320  					MemoryUtilization:  helper.UintToPtr(2),
   321  					EncoderUtilization: helper.UintToPtr(2),
   322  					DecoderUtilization: helper.UintToPtr(2),
   323  					UsedMemoryMiB:      helper.Uint64ToPtr(2),
   324  					ECCErrorsL1Cache:   helper.Uint64ToPtr(2),
   325  					ECCErrorsL2Cache:   helper.Uint64ToPtr(2),
   326  					ECCErrorsDevice:    helper.Uint64ToPtr(2),
   327  					PowerUsageW:        helper.UintToPtr(2),
   328  					BAR1UsedMiB:        helper.Uint64ToPtr(2),
   329  				},
   330  			},
   331  			DriverConfiguration: &MockNVMLDriver{
   332  				deviceCountCallSuccessful:                true,
   333  				deviceInfoByIndexCallSuccessful:          true,
   334  				deviceInfoAndStatusByIndexCallSuccessful: true,
   335  				devices: []*DeviceInfo{
   336  					{
   337  						UUID:               "UUID1",
   338  						Name:               helper.StringToPtr("ModelName1"),
   339  						MemoryMiB:          helper.Uint64ToPtr(16),
   340  						PCIBusID:           "busId1",
   341  						PowerW:             helper.UintToPtr(100),
   342  						BAR1MiB:            helper.Uint64ToPtr(100),
   343  						PCIBandwidthMBPerS: helper.UintToPtr(100),
   344  						CoresClockMHz:      helper.UintToPtr(100),
   345  						MemoryClockMHz:     helper.UintToPtr(100),
   346  					}, {
   347  						UUID:               "UUID2",
   348  						Name:               helper.StringToPtr("ModelName2"),
   349  						MemoryMiB:          helper.Uint64ToPtr(8),
   350  						PCIBusID:           "busId2",
   351  						PowerW:             helper.UintToPtr(200),
   352  						BAR1MiB:            helper.Uint64ToPtr(200),
   353  						PCIBandwidthMBPerS: helper.UintToPtr(200),
   354  						CoresClockMHz:      helper.UintToPtr(200),
   355  						MemoryClockMHz:     helper.UintToPtr(200),
   356  					},
   357  				},
   358  				deviceStatus: []*DeviceStatus{
   359  					{
   360  						TemperatureC:       helper.UintToPtr(1),
   361  						GPUUtilization:     helper.UintToPtr(1),
   362  						MemoryUtilization:  helper.UintToPtr(1),
   363  						EncoderUtilization: helper.UintToPtr(1),
   364  						DecoderUtilization: helper.UintToPtr(1),
   365  						UsedMemoryMiB:      helper.Uint64ToPtr(1),
   366  						ECCErrorsL1Cache:   helper.Uint64ToPtr(1),
   367  						ECCErrorsL2Cache:   helper.Uint64ToPtr(1),
   368  						ECCErrorsDevice:    helper.Uint64ToPtr(1),
   369  						PowerUsageW:        helper.UintToPtr(1),
   370  						BAR1UsedMiB:        helper.Uint64ToPtr(1),
   371  					},
   372  					{
   373  						TemperatureC:       helper.UintToPtr(2),
   374  						GPUUtilization:     helper.UintToPtr(2),
   375  						MemoryUtilization:  helper.UintToPtr(2),
   376  						EncoderUtilization: helper.UintToPtr(2),
   377  						DecoderUtilization: helper.UintToPtr(2),
   378  						UsedMemoryMiB:      helper.Uint64ToPtr(2),
   379  						ECCErrorsL1Cache:   helper.Uint64ToPtr(2),
   380  						ECCErrorsL2Cache:   helper.Uint64ToPtr(2),
   381  						ECCErrorsDevice:    helper.Uint64ToPtr(2),
   382  						PowerUsageW:        helper.UintToPtr(2),
   383  						BAR1UsedMiB:        helper.Uint64ToPtr(2),
   384  					},
   385  				},
   386  			},
   387  		},
   388  	} {
   389  		cli := nvmlClient{driver: testCase.DriverConfiguration}
   390  		statsData, err := cli.GetStatsData()
   391  		if testCase.ExpectedError && err == nil {
   392  			t.Errorf("case '%s' : expected Error, but didn't get one", testCase.Name)
   393  		}
   394  		if !testCase.ExpectedError && err != nil {
   395  			t.Errorf("case '%s' : unexpected Error '%v'", testCase.Name, err)
   396  		}
   397  		require.New(t).Equal(testCase.ExpectedResult, statsData)
   398  	}
   399  }