github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/devicemanager/manager_test.go (about) 1 package devicemanager 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 "testing" 8 "time" 9 10 log "github.com/hashicorp/go-hclog" 11 plugin "github.com/hashicorp/go-plugin" 12 "github.com/hashicorp/nomad/ci" 13 "github.com/hashicorp/nomad/client/state" 14 "github.com/hashicorp/nomad/helper/pluginutils/loader" 15 "github.com/hashicorp/nomad/helper/pointer" 16 "github.com/hashicorp/nomad/helper/testlog" 17 "github.com/hashicorp/nomad/helper/uuid" 18 "github.com/hashicorp/nomad/nomad/structs" 19 "github.com/hashicorp/nomad/plugins/base" 20 "github.com/hashicorp/nomad/plugins/device" 21 psstructs "github.com/hashicorp/nomad/plugins/shared/structs" 22 "github.com/hashicorp/nomad/testutil" 23 "github.com/stretchr/testify/require" 24 ) 25 26 var ( 27 nvidiaDevice0ID = uuid.Generate() 28 nvidiaDevice1ID = uuid.Generate() 29 nvidiaDeviceGroup = &device.DeviceGroup{ 30 Vendor: "nvidia", 31 Type: "gpu", 32 Name: "1080ti", 33 Devices: []*device.Device{ 34 { 35 ID: nvidiaDevice0ID, 36 Healthy: true, 37 }, 38 { 39 ID: nvidiaDevice1ID, 40 Healthy: true, 41 }, 42 }, 43 Attributes: map[string]*psstructs.Attribute{ 44 "memory": { 45 Int: pointer.Of(int64(4)), 46 Unit: "GB", 47 }, 48 }, 49 } 50 51 intelDeviceID = uuid.Generate() 52 intelDeviceGroup = &device.DeviceGroup{ 53 Vendor: "intel", 54 Type: "gpu", 55 Name: "640GT", 56 Devices: []*device.Device{ 57 { 58 ID: intelDeviceID, 59 Healthy: true, 60 }, 61 }, 62 Attributes: map[string]*psstructs.Attribute{ 63 "memory": { 64 Int: pointer.Of(int64(2)), 65 Unit: "GB", 66 }, 67 }, 68 } 69 70 nvidiaDeviceGroupStats = &device.DeviceGroupStats{ 71 Vendor: "nvidia", 72 Type: "gpu", 73 Name: "1080ti", 74 InstanceStats: map[string]*device.DeviceStats{ 75 nvidiaDevice0ID: { 76 Summary: &psstructs.StatValue{ 77 IntNumeratorVal: pointer.Of(int64(212)), 78 Unit: "F", 79 Desc: "Temperature", 80 }, 81 }, 82 nvidiaDevice1ID: { 83 Summary: &psstructs.StatValue{ 84 IntNumeratorVal: pointer.Of(int64(218)), 85 Unit: "F", 86 Desc: "Temperature", 87 }, 88 }, 89 }, 90 } 91 92 intelDeviceGroupStats = &device.DeviceGroupStats{ 93 Vendor: "intel", 94 Type: "gpu", 95 Name: "640GT", 96 InstanceStats: map[string]*device.DeviceStats{ 97 intelDeviceID: { 98 Summary: &psstructs.StatValue{ 99 IntNumeratorVal: pointer.Of(int64(220)), 100 Unit: "F", 101 Desc: "Temperature", 102 }, 103 }, 104 }, 105 } 106 ) 107 108 func baseTestConfig(t *testing.T) ( 109 config *Config, 110 deviceUpdateCh chan []*structs.NodeDeviceResource, 111 catalog *loader.MockCatalog) { 112 113 // Create an update handler 114 deviceUpdates := make(chan []*structs.NodeDeviceResource, 1) 115 updateFn := func(devices []*structs.NodeDeviceResource) { 116 deviceUpdates <- devices 117 } 118 119 // Create a mock plugin catalog 120 mc := &loader.MockCatalog{} 121 122 // Create the config 123 logger := testlog.HCLogger(t) 124 config = &Config{ 125 Logger: logger, 126 PluginConfig: &base.AgentConfig{}, 127 StatsInterval: 100 * time.Millisecond, 128 State: state.NewMemDB(logger), 129 Updater: updateFn, 130 Loader: mc, 131 } 132 133 return config, deviceUpdates, mc 134 } 135 136 func configureCatalogWith(catalog *loader.MockCatalog, plugins map[*base.PluginInfoResponse]loader.PluginInstance) { 137 138 catalog.DispenseF = func(name, _ string, _ *base.AgentConfig, _ log.Logger) (loader.PluginInstance, error) { 139 for info, v := range plugins { 140 if info.Name == name { 141 return v, nil 142 } 143 } 144 145 return nil, fmt.Errorf("no matching plugin") 146 } 147 148 catalog.ReattachF = func(name, _ string, _ *plugin.ReattachConfig) (loader.PluginInstance, error) { 149 for info, v := range plugins { 150 if info.Name == name { 151 return v, nil 152 } 153 } 154 155 return nil, fmt.Errorf("no matching plugin") 156 } 157 158 catalog.CatalogF = func() map[string][]*base.PluginInfoResponse { 159 devices := make([]*base.PluginInfoResponse, 0, len(plugins)) 160 for k := range plugins { 161 devices = append(devices, k) 162 } 163 out := map[string][]*base.PluginInfoResponse{ 164 base.PluginTypeDevice: devices, 165 } 166 return out 167 } 168 } 169 170 func pluginInfoResponse(name string) *base.PluginInfoResponse { 171 return &base.PluginInfoResponse{ 172 Type: base.PluginTypeDevice, 173 PluginApiVersions: []string{"v0.0.1"}, 174 PluginVersion: "v0.0.1", 175 Name: name, 176 } 177 } 178 179 // drainNodeDeviceUpdates drains all updates to the node device fingerprint channel 180 func drainNodeDeviceUpdates(ctx context.Context, in chan []*structs.NodeDeviceResource) { 181 go func() { 182 for { 183 select { 184 case <-ctx.Done(): 185 return 186 case <-in: 187 } 188 } 189 }() 190 } 191 192 func deviceReserveFn(ids []string) (*device.ContainerReservation, error) { 193 return &device.ContainerReservation{ 194 Envs: map[string]string{ 195 "DEVICES": strings.Join(ids, ","), 196 }, 197 }, nil 198 } 199 200 // nvidiaAndIntelDefaultPlugins adds an nvidia and intel mock plugin to the 201 // catalog 202 func nvidiaAndIntelDefaultPlugins(catalog *loader.MockCatalog) { 203 pluginInfoNvidia := pluginInfoResponse("nvidia") 204 deviceNvidia := &device.MockDevicePlugin{ 205 MockPlugin: &base.MockPlugin{ 206 PluginInfoF: base.StaticInfo(pluginInfoNvidia), 207 ConfigSchemaF: base.TestConfigSchema(), 208 SetConfigF: base.NoopSetConfig(), 209 }, 210 FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{nvidiaDeviceGroup}), 211 ReserveF: deviceReserveFn, 212 StatsF: device.StaticStats([]*device.DeviceGroupStats{nvidiaDeviceGroupStats}), 213 } 214 pluginNvidia := loader.MockBasicExternalPlugin(deviceNvidia, device.ApiVersion010) 215 216 pluginInfoIntel := pluginInfoResponse("intel") 217 deviceIntel := &device.MockDevicePlugin{ 218 MockPlugin: &base.MockPlugin{ 219 PluginInfoF: base.StaticInfo(pluginInfoIntel), 220 ConfigSchemaF: base.TestConfigSchema(), 221 SetConfigF: base.NoopSetConfig(), 222 }, 223 FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{intelDeviceGroup}), 224 ReserveF: deviceReserveFn, 225 StatsF: device.StaticStats([]*device.DeviceGroupStats{intelDeviceGroupStats}), 226 } 227 pluginIntel := loader.MockBasicExternalPlugin(deviceIntel, device.ApiVersion010) 228 229 // Configure the catalog with two plugins 230 configureCatalogWith(catalog, map[*base.PluginInfoResponse]loader.PluginInstance{ 231 pluginInfoNvidia: pluginNvidia, 232 pluginInfoIntel: pluginIntel, 233 }) 234 } 235 236 // Test collecting statistics from all devices 237 func TestManager_AllStats(t *testing.T) { 238 ci.Parallel(t) 239 require := require.New(t) 240 241 config, _, catalog := baseTestConfig(t) 242 nvidiaAndIntelDefaultPlugins(catalog) 243 244 m := New(config) 245 m.Run() 246 defer m.Shutdown() 247 require.Len(m.instances, 2) 248 249 // Wait till we get a fingerprint result 250 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 251 defer cancel() 252 <-m.WaitForFirstFingerprint(ctx) 253 require.NoError(ctx.Err()) 254 255 // Now collect all the stats 256 var stats []*device.DeviceGroupStats 257 testutil.WaitForResult(func() (bool, error) { 258 stats = m.AllStats() 259 l := len(stats) 260 if l == 2 { 261 return true, nil 262 } 263 264 return false, fmt.Errorf("expected count 2; got %d", l) 265 }, func(err error) { 266 t.Fatal(err) 267 }) 268 269 // Check we got stats from both the devices 270 var nstats, istats bool 271 for _, stat := range stats { 272 switch stat.Vendor { 273 case "intel": 274 istats = true 275 case "nvidia": 276 nstats = true 277 default: 278 t.Fatalf("unexpected vendor %q", stat.Vendor) 279 } 280 } 281 require.True(nstats) 282 require.True(istats) 283 } 284 285 // Test collecting statistics from a particular device 286 func TestManager_DeviceStats(t *testing.T) { 287 ci.Parallel(t) 288 require := require.New(t) 289 290 config, _, catalog := baseTestConfig(t) 291 nvidiaAndIntelDefaultPlugins(catalog) 292 293 m := New(config) 294 m.Run() 295 defer m.Shutdown() 296 297 // Wait till we get a fingerprint result 298 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 299 defer cancel() 300 <-m.WaitForFirstFingerprint(ctx) 301 require.NoError(ctx.Err()) 302 303 testutil.WaitForResult(func() (bool, error) { 304 stats := m.AllStats() 305 l := len(stats) 306 if l == 2 { 307 return true, nil 308 } 309 310 return false, fmt.Errorf("expected count 2; got %d", l) 311 }, func(err error) { 312 t.Fatal(err) 313 }) 314 315 // Now collect the stats for one nvidia device 316 stat, err := m.DeviceStats(&structs.AllocatedDeviceResource{ 317 Vendor: "nvidia", 318 Type: "gpu", 319 Name: "1080ti", 320 DeviceIDs: []string{nvidiaDevice1ID}, 321 }) 322 require.NoError(err) 323 require.NotNil(stat) 324 325 require.Len(stat.InstanceStats, 1) 326 require.Contains(stat.InstanceStats, nvidiaDevice1ID) 327 328 istat := stat.InstanceStats[nvidiaDevice1ID] 329 require.EqualValues(218, *istat.Summary.IntNumeratorVal) 330 } 331 332 // Test reserving a particular device 333 func TestManager_Reserve(t *testing.T) { 334 ci.Parallel(t) 335 r := require.New(t) 336 337 config, _, catalog := baseTestConfig(t) 338 nvidiaAndIntelDefaultPlugins(catalog) 339 340 m := New(config) 341 m.Run() 342 defer m.Shutdown() 343 344 // Wait till we get a fingerprint result 345 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 346 defer cancel() 347 <-m.WaitForFirstFingerprint(ctx) 348 r.NoError(ctx.Err()) 349 350 cases := []struct { 351 in *structs.AllocatedDeviceResource 352 expected string 353 err bool 354 }{ 355 { 356 in: &structs.AllocatedDeviceResource{ 357 Vendor: "nvidia", 358 Type: "gpu", 359 Name: "1080ti", 360 DeviceIDs: []string{nvidiaDevice1ID}, 361 }, 362 expected: nvidiaDevice1ID, 363 }, 364 { 365 in: &structs.AllocatedDeviceResource{ 366 Vendor: "nvidia", 367 Type: "gpu", 368 Name: "1080ti", 369 DeviceIDs: []string{nvidiaDevice0ID}, 370 }, 371 expected: nvidiaDevice0ID, 372 }, 373 { 374 in: &structs.AllocatedDeviceResource{ 375 Vendor: "nvidia", 376 Type: "gpu", 377 Name: "1080ti", 378 DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID}, 379 }, 380 expected: fmt.Sprintf("%s,%s", nvidiaDevice0ID, nvidiaDevice1ID), 381 }, 382 { 383 in: &structs.AllocatedDeviceResource{ 384 Vendor: "nvidia", 385 Type: "gpu", 386 Name: "1080ti", 387 DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID, "foo"}, 388 }, 389 err: true, 390 }, 391 { 392 in: &structs.AllocatedDeviceResource{ 393 Vendor: "intel", 394 Type: "gpu", 395 Name: "640GT", 396 DeviceIDs: []string{intelDeviceID}, 397 }, 398 expected: intelDeviceID, 399 }, 400 { 401 in: &structs.AllocatedDeviceResource{ 402 Vendor: "intel", 403 Type: "gpu", 404 Name: "foo", 405 DeviceIDs: []string{intelDeviceID}, 406 }, 407 err: true, 408 }, 409 } 410 411 for i, c := range cases { 412 t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { 413 r = require.New(t) 414 415 // Reserve a particular device 416 res, err := m.Reserve(c.in) 417 if !c.err { 418 r.NoError(err) 419 r.NotNil(res) 420 421 r.Len(res.Envs, 1) 422 r.Equal(res.Envs["DEVICES"], c.expected) 423 } else { 424 r.Error(err) 425 } 426 }) 427 } 428 } 429 430 // Test that shutdown shutsdown the plugins 431 func TestManager_Shutdown(t *testing.T) { 432 ci.Parallel(t) 433 require := require.New(t) 434 435 config, _, catalog := baseTestConfig(t) 436 nvidiaAndIntelDefaultPlugins(catalog) 437 438 m := New(config) 439 m.Run() 440 defer m.Shutdown() 441 442 // Wait till we get a fingerprint result 443 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 444 defer cancel() 445 <-m.WaitForFirstFingerprint(ctx) 446 require.NoError(ctx.Err()) 447 448 // Call shutdown and assert that we killed the plugins 449 m.Shutdown() 450 451 for _, resp := range catalog.Catalog()[base.PluginTypeDevice] { 452 pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.AgentConfig{}, config.Logger) 453 require.True(pinst.Exited()) 454 } 455 } 456 457 // Test that startup shutsdown previously launched plugins 458 func TestManager_Run_ShutdownOld(t *testing.T) { 459 ci.Parallel(t) 460 require := require.New(t) 461 462 config, _, catalog := baseTestConfig(t) 463 nvidiaAndIntelDefaultPlugins(catalog) 464 465 m := New(config) 466 m.Run() 467 defer m.Shutdown() 468 469 // Wait till we get a fingerprint result 470 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 471 defer cancel() 472 <-m.WaitForFirstFingerprint(ctx) 473 require.NoError(ctx.Err()) 474 475 // Create a new manager with the same config so that it reads the old state 476 m2 := New(config) 477 go m2.Run() 478 defer m2.Shutdown() 479 480 testutil.WaitForResult(func() (bool, error) { 481 for _, resp := range catalog.Catalog()[base.PluginTypeDevice] { 482 pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.AgentConfig{}, config.Logger) 483 if !pinst.Exited() { 484 return false, fmt.Errorf("plugin %q not shutdown", resp.Name) 485 } 486 } 487 488 return true, nil 489 }, func(err error) { 490 t.Fatal(err) 491 }) 492 }