github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/plugins/device/plugin_test.go (about) 1 package device 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 pb "github.com/golang/protobuf/proto" 10 plugin "github.com/hashicorp/go-plugin" 11 "github.com/hashicorp/nomad/ci" 12 "github.com/hashicorp/nomad/helper/pointer" 13 "github.com/hashicorp/nomad/nomad/structs" 14 "github.com/hashicorp/nomad/plugins/base" 15 "github.com/hashicorp/nomad/plugins/shared/hclspec" 16 psstructs "github.com/hashicorp/nomad/plugins/shared/structs" 17 "github.com/hashicorp/nomad/testutil" 18 "github.com/stretchr/testify/require" 19 "github.com/zclconf/go-cty/cty" 20 "github.com/zclconf/go-cty/cty/msgpack" 21 "google.golang.org/grpc/status" 22 ) 23 24 func TestDevicePlugin_PluginInfo(t *testing.T) { 25 ci.Parallel(t) 26 require := require.New(t) 27 28 var ( 29 apiVersions = []string{"v0.1.0", "v0.2.0"} 30 ) 31 32 const ( 33 pluginVersion = "v0.2.1" 34 pluginName = "mock_device" 35 ) 36 37 knownType := func() (*base.PluginInfoResponse, error) { 38 info := &base.PluginInfoResponse{ 39 Type: base.PluginTypeDevice, 40 PluginApiVersions: apiVersions, 41 PluginVersion: pluginVersion, 42 Name: pluginName, 43 } 44 return info, nil 45 } 46 unknownType := func() (*base.PluginInfoResponse, error) { 47 info := &base.PluginInfoResponse{ 48 Type: "bad", 49 PluginApiVersions: apiVersions, 50 PluginVersion: pluginVersion, 51 Name: pluginName, 52 } 53 return info, nil 54 } 55 56 mock := &MockDevicePlugin{ 57 MockPlugin: &base.MockPlugin{ 58 PluginInfoF: knownType, 59 }, 60 } 61 62 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 63 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 64 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 65 }) 66 defer server.Stop() 67 defer client.Close() 68 69 raw, err := client.Dispense(base.PluginTypeDevice) 70 if err != nil { 71 t.Fatalf("err: %s", err) 72 } 73 74 impl, ok := raw.(DevicePlugin) 75 if !ok { 76 t.Fatalf("bad: %#v", raw) 77 } 78 79 resp, err := impl.PluginInfo() 80 require.NoError(err) 81 require.Equal(apiVersions, resp.PluginApiVersions) 82 require.Equal(pluginVersion, resp.PluginVersion) 83 require.Equal(pluginName, resp.Name) 84 require.Equal(base.PluginTypeDevice, resp.Type) 85 86 // Swap the implementation to return an unknown type 87 mock.PluginInfoF = unknownType 88 _, err = impl.PluginInfo() 89 require.Error(err) 90 require.Contains(err.Error(), "unknown type") 91 } 92 93 func TestDevicePlugin_ConfigSchema(t *testing.T) { 94 ci.Parallel(t) 95 require := require.New(t) 96 97 mock := &MockDevicePlugin{ 98 MockPlugin: &base.MockPlugin{ 99 ConfigSchemaF: func() (*hclspec.Spec, error) { 100 return base.TestSpec, nil 101 }, 102 }, 103 } 104 105 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 106 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 107 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 108 }) 109 defer server.Stop() 110 defer client.Close() 111 112 raw, err := client.Dispense(base.PluginTypeDevice) 113 if err != nil { 114 t.Fatalf("err: %s", err) 115 } 116 117 impl, ok := raw.(DevicePlugin) 118 if !ok { 119 t.Fatalf("bad: %#v", raw) 120 } 121 122 specOut, err := impl.ConfigSchema() 123 require.NoError(err) 124 require.True(pb.Equal(base.TestSpec, specOut)) 125 } 126 127 func TestDevicePlugin_SetConfig(t *testing.T) { 128 ci.Parallel(t) 129 require := require.New(t) 130 131 var receivedData []byte 132 mock := &MockDevicePlugin{ 133 MockPlugin: &base.MockPlugin{ 134 PluginInfoF: func() (*base.PluginInfoResponse, error) { 135 return &base.PluginInfoResponse{ 136 Type: base.PluginTypeDevice, 137 PluginApiVersions: []string{"v0.0.1"}, 138 PluginVersion: "v0.0.1", 139 Name: "mock_device", 140 }, nil 141 }, 142 ConfigSchemaF: func() (*hclspec.Spec, error) { 143 return base.TestSpec, nil 144 }, 145 SetConfigF: func(cfg *base.Config) error { 146 receivedData = cfg.PluginConfig 147 return nil 148 }, 149 }, 150 } 151 152 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 153 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 154 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 155 }) 156 defer server.Stop() 157 defer client.Close() 158 159 raw, err := client.Dispense(base.PluginTypeDevice) 160 if err != nil { 161 t.Fatalf("err: %s", err) 162 } 163 164 impl, ok := raw.(DevicePlugin) 165 if !ok { 166 t.Fatalf("bad: %#v", raw) 167 } 168 169 config := cty.ObjectVal(map[string]cty.Value{ 170 "foo": cty.StringVal("v1"), 171 "bar": cty.NumberIntVal(1337), 172 "baz": cty.BoolVal(true), 173 }) 174 cdata, err := msgpack.Marshal(config, config.Type()) 175 require.NoError(err) 176 require.NoError(impl.SetConfig(&base.Config{PluginConfig: cdata})) 177 require.Equal(cdata, receivedData) 178 179 // Decode the value back 180 var actual base.TestConfig 181 require.NoError(structs.Decode(receivedData, &actual)) 182 require.Equal("v1", actual.Foo) 183 require.EqualValues(1337, actual.Bar) 184 require.True(actual.Baz) 185 } 186 187 func TestDevicePlugin_Fingerprint(t *testing.T) { 188 ci.Parallel(t) 189 require := require.New(t) 190 191 devices1 := []*DeviceGroup{ 192 { 193 Vendor: "nvidia", 194 Type: DeviceTypeGPU, 195 Name: "foo", 196 Attributes: map[string]*psstructs.Attribute{ 197 "memory": { 198 Int: pointer.Of(int64(4)), 199 Unit: "GiB", 200 }, 201 }, 202 }, 203 } 204 devices2 := []*DeviceGroup{ 205 { 206 Vendor: "nvidia", 207 Type: DeviceTypeGPU, 208 Name: "foo", 209 }, 210 { 211 Vendor: "nvidia", 212 Type: DeviceTypeGPU, 213 Name: "bar", 214 }, 215 } 216 217 mock := &MockDevicePlugin{ 218 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 219 outCh := make(chan *FingerprintResponse, 1) 220 go func() { 221 // Send two messages 222 for _, devs := range [][]*DeviceGroup{devices1, devices2} { 223 select { 224 case <-ctx.Done(): 225 return 226 case outCh <- &FingerprintResponse{Devices: devs}: 227 } 228 } 229 close(outCh) 230 return 231 }() 232 return outCh, nil 233 }, 234 } 235 236 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 237 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 238 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 239 }) 240 defer server.Stop() 241 defer client.Close() 242 243 raw, err := client.Dispense(base.PluginTypeDevice) 244 if err != nil { 245 t.Fatalf("err: %s", err) 246 } 247 248 impl, ok := raw.(DevicePlugin) 249 if !ok { 250 t.Fatalf("bad: %#v", raw) 251 } 252 253 // Create a context 254 ctx, cancel := context.WithCancel(context.Background()) 255 defer cancel() 256 257 // Get the stream 258 stream, err := impl.Fingerprint(ctx) 259 require.NoError(err) 260 261 // Get the first message 262 var first *FingerprintResponse 263 select { 264 case <-time.After(1 * time.Second): 265 t.Fatal("timeout") 266 case first = <-stream: 267 } 268 269 require.NoError(first.Error) 270 require.EqualValues(devices1, first.Devices) 271 272 // Get the second message 273 var second *FingerprintResponse 274 select { 275 case <-time.After(1 * time.Second): 276 t.Fatal("timeout") 277 case second = <-stream: 278 } 279 280 require.NoError(second.Error) 281 require.EqualValues(devices2, second.Devices) 282 283 select { 284 case _, ok := <-stream: 285 require.False(ok) 286 case <-time.After(1 * time.Second): 287 t.Fatal("stream should be closed") 288 } 289 } 290 291 func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) { 292 ci.Parallel(t) 293 require := require.New(t) 294 295 ferr := fmt.Errorf("mock fingerprinting failed") 296 mock := &MockDevicePlugin{ 297 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 298 outCh := make(chan *FingerprintResponse, 1) 299 go func() { 300 // Send the error 301 select { 302 case <-ctx.Done(): 303 return 304 case outCh <- &FingerprintResponse{Error: ferr}: 305 } 306 307 close(outCh) 308 return 309 }() 310 return outCh, nil 311 }, 312 } 313 314 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 315 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 316 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 317 }) 318 defer server.Stop() 319 defer client.Close() 320 321 raw, err := client.Dispense(base.PluginTypeDevice) 322 if err != nil { 323 t.Fatalf("err: %s", err) 324 } 325 326 impl, ok := raw.(DevicePlugin) 327 if !ok { 328 t.Fatalf("bad: %#v", raw) 329 } 330 331 // Create a context 332 ctx, cancel := context.WithCancel(context.Background()) 333 defer cancel() 334 335 // Get the stream 336 stream, err := impl.Fingerprint(ctx) 337 require.NoError(err) 338 339 // Get the first message 340 var first *FingerprintResponse 341 select { 342 case <-time.After(1 * time.Second): 343 t.Fatal("timeout") 344 case first = <-stream: 345 } 346 347 errStatus := status.Convert(ferr) 348 require.EqualError(first.Error, errStatus.Err().Error()) 349 } 350 351 func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) { 352 ci.Parallel(t) 353 require := require.New(t) 354 355 mock := &MockDevicePlugin{ 356 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 357 outCh := make(chan *FingerprintResponse, 1) 358 go func() { 359 <-ctx.Done() 360 close(outCh) 361 return 362 }() 363 return outCh, nil 364 }, 365 } 366 367 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 368 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 369 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 370 }) 371 defer server.Stop() 372 defer client.Close() 373 374 raw, err := client.Dispense(base.PluginTypeDevice) 375 if err != nil { 376 t.Fatalf("err: %s", err) 377 } 378 379 impl, ok := raw.(DevicePlugin) 380 if !ok { 381 t.Fatalf("bad: %#v", raw) 382 } 383 384 // Create a context 385 ctx, cancel := context.WithCancel(context.Background()) 386 387 // Get the stream 388 stream, err := impl.Fingerprint(ctx) 389 require.NoError(err) 390 391 // Get the first message 392 select { 393 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 394 case _ = <-stream: 395 t.Fatal("bad value") 396 } 397 398 // Cancel the context 399 cancel() 400 401 // Make sure we are done 402 select { 403 case <-time.After(100 * time.Millisecond): 404 t.Fatalf("timeout") 405 case v := <-stream: 406 require.Error(v.Error) 407 require.EqualError(v.Error, context.Canceled.Error()) 408 } 409 } 410 411 func TestDevicePlugin_Reserve(t *testing.T) { 412 ci.Parallel(t) 413 require := require.New(t) 414 415 reservation := &ContainerReservation{ 416 Envs: map[string]string{ 417 "foo": "bar", 418 }, 419 Mounts: []*Mount{ 420 { 421 TaskPath: "foo", 422 HostPath: "bar", 423 ReadOnly: true, 424 }, 425 }, 426 Devices: []*DeviceSpec{ 427 { 428 TaskPath: "foo", 429 HostPath: "bar", 430 CgroupPerms: "rx", 431 }, 432 }, 433 } 434 435 var received []string 436 mock := &MockDevicePlugin{ 437 ReserveF: func(devices []string) (*ContainerReservation, error) { 438 received = devices 439 return reservation, nil 440 }, 441 } 442 443 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 444 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 445 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 446 }) 447 defer server.Stop() 448 defer client.Close() 449 450 raw, err := client.Dispense(base.PluginTypeDevice) 451 if err != nil { 452 t.Fatalf("err: %s", err) 453 } 454 455 impl, ok := raw.(DevicePlugin) 456 if !ok { 457 t.Fatalf("bad: %#v", raw) 458 } 459 460 req := []string{"a", "b"} 461 containerRes, err := impl.Reserve(req) 462 require.NoError(err) 463 require.EqualValues(req, received) 464 require.EqualValues(reservation, containerRes) 465 } 466 467 func TestDevicePlugin_Stats(t *testing.T) { 468 ci.Parallel(t) 469 require := require.New(t) 470 471 devices1 := []*DeviceGroupStats{ 472 { 473 Vendor: "nvidia", 474 Type: DeviceTypeGPU, 475 Name: "foo", 476 InstanceStats: map[string]*DeviceStats{ 477 "1": { 478 Summary: &psstructs.StatValue{ 479 IntNumeratorVal: pointer.Of(int64(10)), 480 IntDenominatorVal: pointer.Of(int64(20)), 481 Unit: "MB", 482 Desc: "Unit test", 483 }, 484 }, 485 }, 486 }, 487 } 488 devices2 := []*DeviceGroupStats{ 489 { 490 Vendor: "nvidia", 491 Type: DeviceTypeGPU, 492 Name: "foo", 493 InstanceStats: map[string]*DeviceStats{ 494 "1": { 495 Summary: &psstructs.StatValue{ 496 FloatNumeratorVal: pointer.Of(float64(10.0)), 497 FloatDenominatorVal: pointer.Of(float64(20.0)), 498 Unit: "MB", 499 Desc: "Unit test", 500 }, 501 }, 502 }, 503 }, 504 { 505 Vendor: "nvidia", 506 Type: DeviceTypeGPU, 507 Name: "bar", 508 InstanceStats: map[string]*DeviceStats{ 509 "1": { 510 Summary: &psstructs.StatValue{ 511 StringVal: pointer.Of("foo"), 512 Unit: "MB", 513 Desc: "Unit test", 514 }, 515 }, 516 }, 517 }, 518 { 519 Vendor: "nvidia", 520 Type: DeviceTypeGPU, 521 Name: "baz", 522 InstanceStats: map[string]*DeviceStats{ 523 "1": { 524 Summary: &psstructs.StatValue{ 525 BoolVal: pointer.Of(true), 526 Unit: "MB", 527 Desc: "Unit test", 528 }, 529 }, 530 }, 531 }, 532 } 533 534 mock := &MockDevicePlugin{ 535 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 536 outCh := make(chan *StatsResponse, 1) 537 go func() { 538 // Send two messages 539 for _, devs := range [][]*DeviceGroupStats{devices1, devices2} { 540 select { 541 case <-ctx.Done(): 542 return 543 case outCh <- &StatsResponse{Groups: devs}: 544 } 545 } 546 close(outCh) 547 return 548 }() 549 return outCh, nil 550 }, 551 } 552 553 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 554 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 555 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 556 }) 557 defer server.Stop() 558 defer client.Close() 559 560 raw, err := client.Dispense(base.PluginTypeDevice) 561 if err != nil { 562 t.Fatalf("err: %s", err) 563 } 564 565 impl, ok := raw.(DevicePlugin) 566 if !ok { 567 t.Fatalf("bad: %#v", raw) 568 } 569 570 // Create a context 571 ctx, cancel := context.WithCancel(context.Background()) 572 defer cancel() 573 574 // Get the stream 575 stream, err := impl.Stats(ctx, time.Millisecond) 576 require.NoError(err) 577 578 // Get the first message 579 var first *StatsResponse 580 select { 581 case <-time.After(1 * time.Second): 582 t.Fatal("timeout") 583 case first = <-stream: 584 } 585 586 require.NoError(first.Error) 587 require.EqualValues(devices1, first.Groups) 588 589 // Get the second message 590 var second *StatsResponse 591 select { 592 case <-time.After(1 * time.Second): 593 t.Fatal("timeout") 594 case second = <-stream: 595 } 596 597 require.NoError(second.Error) 598 require.EqualValues(devices2, second.Groups) 599 600 select { 601 case _, ok := <-stream: 602 require.False(ok) 603 case <-time.After(1 * time.Second): 604 t.Fatal("stream should be closed") 605 } 606 } 607 608 func TestDevicePlugin_Stats_StreamErr(t *testing.T) { 609 ci.Parallel(t) 610 require := require.New(t) 611 612 ferr := fmt.Errorf("mock stats failed") 613 mock := &MockDevicePlugin{ 614 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 615 outCh := make(chan *StatsResponse, 1) 616 go func() { 617 // Send the error 618 select { 619 case <-ctx.Done(): 620 return 621 case outCh <- &StatsResponse{Error: ferr}: 622 } 623 624 close(outCh) 625 return 626 }() 627 return outCh, nil 628 }, 629 } 630 631 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 632 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 633 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 634 }) 635 defer server.Stop() 636 defer client.Close() 637 638 raw, err := client.Dispense(base.PluginTypeDevice) 639 if err != nil { 640 t.Fatalf("err: %s", err) 641 } 642 643 impl, ok := raw.(DevicePlugin) 644 if !ok { 645 t.Fatalf("bad: %#v", raw) 646 } 647 648 // Create a context 649 ctx, cancel := context.WithCancel(context.Background()) 650 defer cancel() 651 652 // Get the stream 653 stream, err := impl.Stats(ctx, time.Millisecond) 654 require.NoError(err) 655 656 // Get the first message 657 var first *StatsResponse 658 select { 659 case <-time.After(1 * time.Second): 660 t.Fatal("timeout") 661 case first = <-stream: 662 } 663 664 errStatus := status.Convert(ferr) 665 require.EqualError(first.Error, errStatus.Err().Error()) 666 } 667 668 func TestDevicePlugin_Stats_CancelCtx(t *testing.T) { 669 ci.Parallel(t) 670 require := require.New(t) 671 672 mock := &MockDevicePlugin{ 673 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 674 outCh := make(chan *StatsResponse, 1) 675 go func() { 676 <-ctx.Done() 677 close(outCh) 678 return 679 }() 680 return outCh, nil 681 }, 682 } 683 684 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 685 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 686 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 687 }) 688 defer server.Stop() 689 defer client.Close() 690 691 raw, err := client.Dispense(base.PluginTypeDevice) 692 if err != nil { 693 t.Fatalf("err: %s", err) 694 } 695 696 impl, ok := raw.(DevicePlugin) 697 if !ok { 698 t.Fatalf("bad: %#v", raw) 699 } 700 701 // Create a context 702 ctx, cancel := context.WithCancel(context.Background()) 703 704 // Get the stream 705 stream, err := impl.Stats(ctx, time.Millisecond) 706 require.NoError(err) 707 708 // Get the first message 709 select { 710 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 711 case _ = <-stream: 712 t.Fatal("bad value") 713 } 714 715 // Cancel the context 716 cancel() 717 718 // Make sure we are done 719 select { 720 case <-time.After(100 * time.Millisecond): 721 t.Fatalf("timeout") 722 case v := <-stream: 723 require.Error(v.Error) 724 require.EqualError(v.Error, context.Canceled.Error()) 725 } 726 }