github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/devices/gpu/nvidia/nvml/client_test.go (about) 1 package nvml 2 3 import ( 4 "errors" 5 "testing" 6 7 "github.com/hashicorp/nomad/helper" 8 "github.com/stretchr/testify/require" 9 ) 10 11 type MockNVMLDriver struct { 12 systemDriverCallSuccessful bool 13 deviceCountCallSuccessful bool 14 deviceInfoByIndexCallSuccessful bool 15 deviceInfoAndStatusByIndexCallSuccessful bool 16 driverVersion string 17 devices []*DeviceInfo 18 deviceStatus []*DeviceStatus 19 } 20 21 func (m *MockNVMLDriver) Initialize() error { 22 return nil 23 } 24 25 func (m *MockNVMLDriver) Shutdown() error { 26 return nil 27 } 28 29 func (m *MockNVMLDriver) SystemDriverVersion() (string, error) { 30 if !m.systemDriverCallSuccessful { 31 return "", errors.New("failed to get system driver") 32 } 33 return m.driverVersion, nil 34 } 35 36 func (m *MockNVMLDriver) DeviceCount() (uint, error) { 37 if !m.deviceCountCallSuccessful { 38 return 0, errors.New("failed to get device length") 39 } 40 return uint(len(m.devices)), nil 41 } 42 43 func (m *MockNVMLDriver) DeviceInfoByIndex(index uint) (*DeviceInfo, error) { 44 if index >= uint(len(m.devices)) { 45 return nil, errors.New("index is out of range") 46 } 47 if !m.deviceInfoByIndexCallSuccessful { 48 return nil, errors.New("failed to get device info by index") 49 } 50 return m.devices[index], nil 51 } 52 53 func (m *MockNVMLDriver) DeviceInfoAndStatusByIndex(index uint) (*DeviceInfo, *DeviceStatus, error) { 54 if index >= uint(len(m.devices)) || index >= uint(len(m.deviceStatus)) { 55 return nil, nil, errors.New("index is out of range") 56 } 57 if !m.deviceInfoAndStatusByIndexCallSuccessful { 58 return nil, nil, errors.New("failed to get device info and status by index") 59 } 60 return m.devices[index], m.deviceStatus[index], nil 61 } 62 63 func TestGetFingerprintDataFromNVML(t *testing.T) { 64 for _, testCase := range []struct { 65 Name string 66 DriverConfiguration *MockNVMLDriver 67 ExpectedError bool 68 ExpectedResult *FingerprintData 69 }{ 70 { 71 Name: "fail on systemDriverCallSuccessful", 72 ExpectedError: true, 73 ExpectedResult: nil, 74 DriverConfiguration: &MockNVMLDriver{ 75 systemDriverCallSuccessful: false, 76 deviceCountCallSuccessful: true, 77 deviceInfoByIndexCallSuccessful: true, 78 }, 79 }, 80 { 81 Name: "fail on deviceCountCallSuccessful", 82 ExpectedError: true, 83 ExpectedResult: nil, 84 DriverConfiguration: &MockNVMLDriver{ 85 systemDriverCallSuccessful: true, 86 deviceCountCallSuccessful: false, 87 deviceInfoByIndexCallSuccessful: true, 88 }, 89 }, 90 { 91 Name: "fail on deviceInfoByIndexCall", 92 ExpectedError: true, 93 ExpectedResult: nil, 94 DriverConfiguration: &MockNVMLDriver{ 95 systemDriverCallSuccessful: true, 96 deviceCountCallSuccessful: true, 97 deviceInfoByIndexCallSuccessful: false, 98 devices: []*DeviceInfo{ 99 { 100 UUID: "UUID1", 101 Name: helper.StringToPtr("ModelName1"), 102 MemoryMiB: helper.Uint64ToPtr(16), 103 PCIBusID: "busId", 104 PowerW: helper.UintToPtr(100), 105 BAR1MiB: helper.Uint64ToPtr(100), 106 PCIBandwidthMBPerS: helper.UintToPtr(100), 107 CoresClockMHz: helper.UintToPtr(100), 108 MemoryClockMHz: helper.UintToPtr(100), 109 }, { 110 UUID: "UUID2", 111 Name: helper.StringToPtr("ModelName2"), 112 MemoryMiB: helper.Uint64ToPtr(8), 113 PCIBusID: "busId", 114 PowerW: helper.UintToPtr(100), 115 BAR1MiB: helper.Uint64ToPtr(100), 116 PCIBandwidthMBPerS: helper.UintToPtr(100), 117 CoresClockMHz: helper.UintToPtr(100), 118 MemoryClockMHz: helper.UintToPtr(100), 119 }, 120 }, 121 }, 122 }, 123 { 124 Name: "successful outcome", 125 ExpectedError: false, 126 ExpectedResult: &FingerprintData{ 127 DriverVersion: "driverVersion", 128 Devices: []*FingerprintDeviceData{ 129 { 130 DeviceData: &DeviceData{ 131 DeviceName: helper.StringToPtr("ModelName1"), 132 UUID: "UUID1", 133 MemoryMiB: helper.Uint64ToPtr(16), 134 PowerW: helper.UintToPtr(100), 135 BAR1MiB: helper.Uint64ToPtr(100), 136 }, 137 PCIBusID: "busId1", 138 PCIBandwidthMBPerS: helper.UintToPtr(100), 139 CoresClockMHz: helper.UintToPtr(100), 140 MemoryClockMHz: helper.UintToPtr(100), 141 DisplayState: "Enabled", 142 PersistenceMode: "Enabled", 143 }, { 144 DeviceData: &DeviceData{ 145 DeviceName: helper.StringToPtr("ModelName2"), 146 UUID: "UUID2", 147 MemoryMiB: helper.Uint64ToPtr(8), 148 PowerW: helper.UintToPtr(200), 149 BAR1MiB: helper.Uint64ToPtr(200), 150 }, 151 PCIBusID: "busId2", 152 PCIBandwidthMBPerS: helper.UintToPtr(200), 153 CoresClockMHz: helper.UintToPtr(200), 154 MemoryClockMHz: helper.UintToPtr(200), 155 DisplayState: "Enabled", 156 PersistenceMode: "Enabled", 157 }, 158 }, 159 }, 160 DriverConfiguration: &MockNVMLDriver{ 161 systemDriverCallSuccessful: true, 162 deviceCountCallSuccessful: true, 163 deviceInfoByIndexCallSuccessful: true, 164 driverVersion: "driverVersion", 165 devices: []*DeviceInfo{ 166 { 167 UUID: "UUID1", 168 Name: helper.StringToPtr("ModelName1"), 169 MemoryMiB: helper.Uint64ToPtr(16), 170 PCIBusID: "busId1", 171 PowerW: helper.UintToPtr(100), 172 BAR1MiB: helper.Uint64ToPtr(100), 173 PCIBandwidthMBPerS: helper.UintToPtr(100), 174 CoresClockMHz: helper.UintToPtr(100), 175 MemoryClockMHz: helper.UintToPtr(100), 176 DisplayState: "Enabled", 177 PersistenceMode: "Enabled", 178 }, { 179 UUID: "UUID2", 180 Name: helper.StringToPtr("ModelName2"), 181 MemoryMiB: helper.Uint64ToPtr(8), 182 PCIBusID: "busId2", 183 PowerW: helper.UintToPtr(200), 184 BAR1MiB: helper.Uint64ToPtr(200), 185 PCIBandwidthMBPerS: helper.UintToPtr(200), 186 CoresClockMHz: helper.UintToPtr(200), 187 MemoryClockMHz: helper.UintToPtr(200), 188 DisplayState: "Enabled", 189 PersistenceMode: "Enabled", 190 }, 191 }, 192 }, 193 }, 194 } { 195 cli := nvmlClient{driver: testCase.DriverConfiguration} 196 fingerprintData, err := cli.GetFingerprintData() 197 if testCase.ExpectedError && err == nil { 198 t.Errorf("case '%s' : expected Error, but didn't get one", testCase.Name) 199 } 200 if !testCase.ExpectedError && err != nil { 201 t.Errorf("case '%s' : unexpected Error '%v'", testCase.Name, err) 202 } 203 require.New(t).Equal(testCase.ExpectedResult, fingerprintData) 204 } 205 } 206 207 func TestGetStatsDataFromNVML(t *testing.T) { 208 for _, testCase := range []struct { 209 Name string 210 DriverConfiguration *MockNVMLDriver 211 ExpectedError bool 212 ExpectedResult []*StatsData 213 }{ 214 { 215 Name: "fail on deviceCountCallSuccessful", 216 ExpectedError: true, 217 ExpectedResult: nil, 218 DriverConfiguration: &MockNVMLDriver{ 219 systemDriverCallSuccessful: true, 220 deviceCountCallSuccessful: false, 221 deviceInfoByIndexCallSuccessful: true, 222 deviceInfoAndStatusByIndexCallSuccessful: true, 223 }, 224 }, 225 { 226 Name: "fail on DeviceInfoAndStatusByIndex call", 227 ExpectedError: true, 228 ExpectedResult: nil, 229 DriverConfiguration: &MockNVMLDriver{ 230 systemDriverCallSuccessful: true, 231 deviceCountCallSuccessful: true, 232 deviceInfoAndStatusByIndexCallSuccessful: false, 233 devices: []*DeviceInfo{ 234 { 235 UUID: "UUID1", 236 Name: helper.StringToPtr("ModelName1"), 237 MemoryMiB: helper.Uint64ToPtr(16), 238 PCIBusID: "busId1", 239 PowerW: helper.UintToPtr(100), 240 BAR1MiB: helper.Uint64ToPtr(100), 241 PCIBandwidthMBPerS: helper.UintToPtr(100), 242 CoresClockMHz: helper.UintToPtr(100), 243 MemoryClockMHz: helper.UintToPtr(100), 244 }, { 245 UUID: "UUID2", 246 Name: helper.StringToPtr("ModelName2"), 247 MemoryMiB: helper.Uint64ToPtr(8), 248 PCIBusID: "busId2", 249 PowerW: helper.UintToPtr(200), 250 BAR1MiB: helper.Uint64ToPtr(200), 251 PCIBandwidthMBPerS: helper.UintToPtr(200), 252 CoresClockMHz: helper.UintToPtr(200), 253 MemoryClockMHz: helper.UintToPtr(200), 254 }, 255 }, 256 deviceStatus: []*DeviceStatus{ 257 { 258 TemperatureC: helper.UintToPtr(1), 259 GPUUtilization: helper.UintToPtr(1), 260 MemoryUtilization: helper.UintToPtr(1), 261 EncoderUtilization: helper.UintToPtr(1), 262 DecoderUtilization: helper.UintToPtr(1), 263 UsedMemoryMiB: helper.Uint64ToPtr(1), 264 ECCErrorsL1Cache: helper.Uint64ToPtr(1), 265 ECCErrorsL2Cache: helper.Uint64ToPtr(1), 266 ECCErrorsDevice: helper.Uint64ToPtr(1), 267 PowerUsageW: helper.UintToPtr(1), 268 BAR1UsedMiB: helper.Uint64ToPtr(1), 269 }, 270 { 271 TemperatureC: helper.UintToPtr(2), 272 GPUUtilization: helper.UintToPtr(2), 273 MemoryUtilization: helper.UintToPtr(2), 274 EncoderUtilization: helper.UintToPtr(2), 275 DecoderUtilization: helper.UintToPtr(2), 276 UsedMemoryMiB: helper.Uint64ToPtr(2), 277 ECCErrorsL1Cache: helper.Uint64ToPtr(2), 278 ECCErrorsL2Cache: helper.Uint64ToPtr(2), 279 ECCErrorsDevice: helper.Uint64ToPtr(2), 280 PowerUsageW: helper.UintToPtr(2), 281 BAR1UsedMiB: helper.Uint64ToPtr(2), 282 }, 283 }, 284 }, 285 }, 286 { 287 Name: "successful outcome", 288 ExpectedError: false, 289 ExpectedResult: []*StatsData{ 290 { 291 DeviceData: &DeviceData{ 292 DeviceName: helper.StringToPtr("ModelName1"), 293 UUID: "UUID1", 294 MemoryMiB: helper.Uint64ToPtr(16), 295 PowerW: helper.UintToPtr(100), 296 BAR1MiB: helper.Uint64ToPtr(100), 297 }, 298 TemperatureC: helper.UintToPtr(1), 299 GPUUtilization: helper.UintToPtr(1), 300 MemoryUtilization: helper.UintToPtr(1), 301 EncoderUtilization: helper.UintToPtr(1), 302 DecoderUtilization: helper.UintToPtr(1), 303 UsedMemoryMiB: helper.Uint64ToPtr(1), 304 ECCErrorsL1Cache: helper.Uint64ToPtr(1), 305 ECCErrorsL2Cache: helper.Uint64ToPtr(1), 306 ECCErrorsDevice: helper.Uint64ToPtr(1), 307 PowerUsageW: helper.UintToPtr(1), 308 BAR1UsedMiB: helper.Uint64ToPtr(1), 309 }, 310 { 311 DeviceData: &DeviceData{ 312 DeviceName: helper.StringToPtr("ModelName2"), 313 UUID: "UUID2", 314 MemoryMiB: helper.Uint64ToPtr(8), 315 PowerW: helper.UintToPtr(200), 316 BAR1MiB: helper.Uint64ToPtr(200), 317 }, 318 TemperatureC: helper.UintToPtr(2), 319 GPUUtilization: helper.UintToPtr(2), 320 MemoryUtilization: helper.UintToPtr(2), 321 EncoderUtilization: helper.UintToPtr(2), 322 DecoderUtilization: helper.UintToPtr(2), 323 UsedMemoryMiB: helper.Uint64ToPtr(2), 324 ECCErrorsL1Cache: helper.Uint64ToPtr(2), 325 ECCErrorsL2Cache: helper.Uint64ToPtr(2), 326 ECCErrorsDevice: helper.Uint64ToPtr(2), 327 PowerUsageW: helper.UintToPtr(2), 328 BAR1UsedMiB: helper.Uint64ToPtr(2), 329 }, 330 }, 331 DriverConfiguration: &MockNVMLDriver{ 332 deviceCountCallSuccessful: true, 333 deviceInfoByIndexCallSuccessful: true, 334 deviceInfoAndStatusByIndexCallSuccessful: true, 335 devices: []*DeviceInfo{ 336 { 337 UUID: "UUID1", 338 Name: helper.StringToPtr("ModelName1"), 339 MemoryMiB: helper.Uint64ToPtr(16), 340 PCIBusID: "busId1", 341 PowerW: helper.UintToPtr(100), 342 BAR1MiB: helper.Uint64ToPtr(100), 343 PCIBandwidthMBPerS: helper.UintToPtr(100), 344 CoresClockMHz: helper.UintToPtr(100), 345 MemoryClockMHz: helper.UintToPtr(100), 346 }, { 347 UUID: "UUID2", 348 Name: helper.StringToPtr("ModelName2"), 349 MemoryMiB: helper.Uint64ToPtr(8), 350 PCIBusID: "busId2", 351 PowerW: helper.UintToPtr(200), 352 BAR1MiB: helper.Uint64ToPtr(200), 353 PCIBandwidthMBPerS: helper.UintToPtr(200), 354 CoresClockMHz: helper.UintToPtr(200), 355 MemoryClockMHz: helper.UintToPtr(200), 356 }, 357 }, 358 deviceStatus: []*DeviceStatus{ 359 { 360 TemperatureC: helper.UintToPtr(1), 361 GPUUtilization: helper.UintToPtr(1), 362 MemoryUtilization: helper.UintToPtr(1), 363 EncoderUtilization: helper.UintToPtr(1), 364 DecoderUtilization: helper.UintToPtr(1), 365 UsedMemoryMiB: helper.Uint64ToPtr(1), 366 ECCErrorsL1Cache: helper.Uint64ToPtr(1), 367 ECCErrorsL2Cache: helper.Uint64ToPtr(1), 368 ECCErrorsDevice: helper.Uint64ToPtr(1), 369 PowerUsageW: helper.UintToPtr(1), 370 BAR1UsedMiB: helper.Uint64ToPtr(1), 371 }, 372 { 373 TemperatureC: helper.UintToPtr(2), 374 GPUUtilization: helper.UintToPtr(2), 375 MemoryUtilization: helper.UintToPtr(2), 376 EncoderUtilization: helper.UintToPtr(2), 377 DecoderUtilization: helper.UintToPtr(2), 378 UsedMemoryMiB: helper.Uint64ToPtr(2), 379 ECCErrorsL1Cache: helper.Uint64ToPtr(2), 380 ECCErrorsL2Cache: helper.Uint64ToPtr(2), 381 ECCErrorsDevice: helper.Uint64ToPtr(2), 382 PowerUsageW: helper.UintToPtr(2), 383 BAR1UsedMiB: helper.Uint64ToPtr(2), 384 }, 385 }, 386 }, 387 }, 388 } { 389 cli := nvmlClient{driver: testCase.DriverConfiguration} 390 statsData, err := cli.GetStatsData() 391 if testCase.ExpectedError && err == nil { 392 t.Errorf("case '%s' : expected Error, but didn't get one", testCase.Name) 393 } 394 if !testCase.ExpectedError && err != nil { 395 t.Errorf("case '%s' : unexpected Error '%v'", testCase.Name, err) 396 } 397 require.New(t).Equal(testCase.ExpectedResult, statsData) 398 } 399 }