github.com/ferranbt/nomad@v0.9.3-0.20190607002617-85c449b7667c/plugins/device/plugin_test.go (about) 1 package device 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 pb "github.com/golang/protobuf/proto" 10 plugin "github.com/hashicorp/go-plugin" 11 "github.com/hashicorp/nomad/helper" 12 "github.com/hashicorp/nomad/nomad/structs" 13 "github.com/hashicorp/nomad/plugins/base" 14 "github.com/hashicorp/nomad/plugins/shared/hclspec" 15 psstructs "github.com/hashicorp/nomad/plugins/shared/structs" 16 "github.com/hashicorp/nomad/testutil" 17 "github.com/stretchr/testify/require" 18 "github.com/zclconf/go-cty/cty" 19 "github.com/zclconf/go-cty/cty/msgpack" 20 "google.golang.org/grpc/status" 21 ) 22 23 func TestDevicePlugin_PluginInfo(t *testing.T) { 24 t.Parallel() 25 require := require.New(t) 26 27 var ( 28 apiVersions = []string{"v0.1.0", "v0.2.0"} 29 ) 30 31 const ( 32 pluginVersion = "v0.2.1" 33 pluginName = "mock_device" 34 ) 35 36 knownType := func() (*base.PluginInfoResponse, error) { 37 info := &base.PluginInfoResponse{ 38 Type: base.PluginTypeDevice, 39 PluginApiVersions: apiVersions, 40 PluginVersion: pluginVersion, 41 Name: pluginName, 42 } 43 return info, nil 44 } 45 unknownType := func() (*base.PluginInfoResponse, error) { 46 info := &base.PluginInfoResponse{ 47 Type: "bad", 48 PluginApiVersions: apiVersions, 49 PluginVersion: pluginVersion, 50 Name: pluginName, 51 } 52 return info, nil 53 } 54 55 mock := &MockDevicePlugin{ 56 MockPlugin: &base.MockPlugin{ 57 PluginInfoF: knownType, 58 }, 59 } 60 61 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 62 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 63 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 64 }) 65 defer server.Stop() 66 defer client.Close() 67 68 raw, err := client.Dispense(base.PluginTypeDevice) 69 if err != nil { 70 t.Fatalf("err: %s", err) 71 } 72 73 impl, ok := raw.(DevicePlugin) 74 if !ok { 75 t.Fatalf("bad: %#v", raw) 76 } 77 78 resp, err := impl.PluginInfo() 79 require.NoError(err) 80 require.Equal(apiVersions, resp.PluginApiVersions) 81 require.Equal(pluginVersion, resp.PluginVersion) 82 require.Equal(pluginName, resp.Name) 83 require.Equal(base.PluginTypeDevice, resp.Type) 84 85 // Swap the implementation to return an unknown type 86 mock.PluginInfoF = unknownType 87 _, err = impl.PluginInfo() 88 require.Error(err) 89 require.Contains(err.Error(), "unknown type") 90 } 91 92 func TestDevicePlugin_ConfigSchema(t *testing.T) { 93 t.Parallel() 94 require := require.New(t) 95 96 mock := &MockDevicePlugin{ 97 MockPlugin: &base.MockPlugin{ 98 ConfigSchemaF: func() (*hclspec.Spec, error) { 99 return base.TestSpec, nil 100 }, 101 }, 102 } 103 104 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 105 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 106 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 107 }) 108 defer server.Stop() 109 defer client.Close() 110 111 raw, err := client.Dispense(base.PluginTypeDevice) 112 if err != nil { 113 t.Fatalf("err: %s", err) 114 } 115 116 impl, ok := raw.(DevicePlugin) 117 if !ok { 118 t.Fatalf("bad: %#v", raw) 119 } 120 121 specOut, err := impl.ConfigSchema() 122 require.NoError(err) 123 require.True(pb.Equal(base.TestSpec, specOut)) 124 } 125 126 func TestDevicePlugin_SetConfig(t *testing.T) { 127 t.Parallel() 128 require := require.New(t) 129 130 var receivedData []byte 131 mock := &MockDevicePlugin{ 132 MockPlugin: &base.MockPlugin{ 133 PluginInfoF: func() (*base.PluginInfoResponse, error) { 134 return &base.PluginInfoResponse{ 135 Type: base.PluginTypeDevice, 136 PluginApiVersions: []string{"v0.0.1"}, 137 PluginVersion: "v0.0.1", 138 Name: "mock_device", 139 }, nil 140 }, 141 ConfigSchemaF: func() (*hclspec.Spec, error) { 142 return base.TestSpec, nil 143 }, 144 SetConfigF: func(cfg *base.Config) error { 145 receivedData = cfg.PluginConfig 146 return nil 147 }, 148 }, 149 } 150 151 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 152 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 153 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 154 }) 155 defer server.Stop() 156 defer client.Close() 157 158 raw, err := client.Dispense(base.PluginTypeDevice) 159 if err != nil { 160 t.Fatalf("err: %s", err) 161 } 162 163 impl, ok := raw.(DevicePlugin) 164 if !ok { 165 t.Fatalf("bad: %#v", raw) 166 } 167 168 config := cty.ObjectVal(map[string]cty.Value{ 169 "foo": cty.StringVal("v1"), 170 "bar": cty.NumberIntVal(1337), 171 "baz": cty.BoolVal(true), 172 }) 173 cdata, err := msgpack.Marshal(config, config.Type()) 174 require.NoError(err) 175 require.NoError(impl.SetConfig(&base.Config{PluginConfig: cdata})) 176 require.Equal(cdata, receivedData) 177 178 // Decode the value back 179 var actual base.TestConfig 180 require.NoError(structs.Decode(receivedData, &actual)) 181 require.Equal("v1", actual.Foo) 182 require.EqualValues(1337, actual.Bar) 183 require.True(actual.Baz) 184 } 185 186 func TestDevicePlugin_Fingerprint(t *testing.T) { 187 t.Parallel() 188 require := require.New(t) 189 190 devices1 := []*DeviceGroup{ 191 { 192 Vendor: "nvidia", 193 Type: DeviceTypeGPU, 194 Name: "foo", 195 Attributes: map[string]*psstructs.Attribute{ 196 "memory": { 197 Int: helper.Int64ToPtr(4), 198 Unit: "GiB", 199 }, 200 }, 201 }, 202 } 203 devices2 := []*DeviceGroup{ 204 { 205 Vendor: "nvidia", 206 Type: DeviceTypeGPU, 207 Name: "foo", 208 }, 209 { 210 Vendor: "nvidia", 211 Type: DeviceTypeGPU, 212 Name: "bar", 213 }, 214 } 215 216 mock := &MockDevicePlugin{ 217 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 218 outCh := make(chan *FingerprintResponse, 1) 219 go func() { 220 // Send two messages 221 for _, devs := range [][]*DeviceGroup{devices1, devices2} { 222 select { 223 case <-ctx.Done(): 224 return 225 case outCh <- &FingerprintResponse{Devices: devs}: 226 } 227 } 228 close(outCh) 229 return 230 }() 231 return outCh, nil 232 }, 233 } 234 235 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 236 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 237 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 238 }) 239 defer server.Stop() 240 defer client.Close() 241 242 raw, err := client.Dispense(base.PluginTypeDevice) 243 if err != nil { 244 t.Fatalf("err: %s", err) 245 } 246 247 impl, ok := raw.(DevicePlugin) 248 if !ok { 249 t.Fatalf("bad: %#v", raw) 250 } 251 252 // Create a context 253 ctx, cancel := context.WithCancel(context.Background()) 254 defer cancel() 255 256 // Get the stream 257 stream, err := impl.Fingerprint(ctx) 258 require.NoError(err) 259 260 // Get the first message 261 var first *FingerprintResponse 262 select { 263 case <-time.After(1 * time.Second): 264 t.Fatal("timeout") 265 case first = <-stream: 266 } 267 268 require.NoError(first.Error) 269 require.EqualValues(devices1, first.Devices) 270 271 // Get the second message 272 var second *FingerprintResponse 273 select { 274 case <-time.After(1 * time.Second): 275 t.Fatal("timeout") 276 case second = <-stream: 277 } 278 279 require.NoError(second.Error) 280 require.EqualValues(devices2, second.Devices) 281 282 select { 283 case _, ok := <-stream: 284 require.False(ok) 285 case <-time.After(1 * time.Second): 286 t.Fatal("stream should be closed") 287 } 288 } 289 290 func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) { 291 t.Parallel() 292 require := require.New(t) 293 294 ferr := fmt.Errorf("mock fingerprinting failed") 295 mock := &MockDevicePlugin{ 296 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 297 outCh := make(chan *FingerprintResponse, 1) 298 go func() { 299 // Send the error 300 select { 301 case <-ctx.Done(): 302 return 303 case outCh <- &FingerprintResponse{Error: ferr}: 304 } 305 306 close(outCh) 307 return 308 }() 309 return outCh, nil 310 }, 311 } 312 313 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 314 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 315 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 316 }) 317 defer server.Stop() 318 defer client.Close() 319 320 raw, err := client.Dispense(base.PluginTypeDevice) 321 if err != nil { 322 t.Fatalf("err: %s", err) 323 } 324 325 impl, ok := raw.(DevicePlugin) 326 if !ok { 327 t.Fatalf("bad: %#v", raw) 328 } 329 330 // Create a context 331 ctx, cancel := context.WithCancel(context.Background()) 332 defer cancel() 333 334 // Get the stream 335 stream, err := impl.Fingerprint(ctx) 336 require.NoError(err) 337 338 // Get the first message 339 var first *FingerprintResponse 340 select { 341 case <-time.After(1 * time.Second): 342 t.Fatal("timeout") 343 case first = <-stream: 344 } 345 346 errStatus := status.Convert(ferr) 347 require.EqualError(first.Error, errStatus.Err().Error()) 348 } 349 350 func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) { 351 t.Parallel() 352 require := require.New(t) 353 354 mock := &MockDevicePlugin{ 355 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 356 outCh := make(chan *FingerprintResponse, 1) 357 go func() { 358 <-ctx.Done() 359 close(outCh) 360 return 361 }() 362 return outCh, nil 363 }, 364 } 365 366 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 367 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 368 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 369 }) 370 defer server.Stop() 371 defer client.Close() 372 373 raw, err := client.Dispense(base.PluginTypeDevice) 374 if err != nil { 375 t.Fatalf("err: %s", err) 376 } 377 378 impl, ok := raw.(DevicePlugin) 379 if !ok { 380 t.Fatalf("bad: %#v", raw) 381 } 382 383 // Create a context 384 ctx, cancel := context.WithCancel(context.Background()) 385 386 // Get the stream 387 stream, err := impl.Fingerprint(ctx) 388 require.NoError(err) 389 390 // Get the first message 391 select { 392 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 393 case _ = <-stream: 394 t.Fatal("bad value") 395 } 396 397 // Cancel the context 398 cancel() 399 400 // Make sure we are done 401 select { 402 case <-time.After(100 * time.Millisecond): 403 t.Fatalf("timeout") 404 case v := <-stream: 405 require.Error(v.Error) 406 require.EqualError(v.Error, context.Canceled.Error()) 407 } 408 } 409 410 func TestDevicePlugin_Reserve(t *testing.T) { 411 t.Parallel() 412 require := require.New(t) 413 414 reservation := &ContainerReservation{ 415 Envs: map[string]string{ 416 "foo": "bar", 417 }, 418 Mounts: []*Mount{ 419 { 420 TaskPath: "foo", 421 HostPath: "bar", 422 ReadOnly: true, 423 }, 424 }, 425 Devices: []*DeviceSpec{ 426 { 427 TaskPath: "foo", 428 HostPath: "bar", 429 CgroupPerms: "rx", 430 }, 431 }, 432 } 433 434 var received []string 435 mock := &MockDevicePlugin{ 436 ReserveF: func(devices []string) (*ContainerReservation, error) { 437 received = devices 438 return reservation, nil 439 }, 440 } 441 442 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 443 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 444 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 445 }) 446 defer server.Stop() 447 defer client.Close() 448 449 raw, err := client.Dispense(base.PluginTypeDevice) 450 if err != nil { 451 t.Fatalf("err: %s", err) 452 } 453 454 impl, ok := raw.(DevicePlugin) 455 if !ok { 456 t.Fatalf("bad: %#v", raw) 457 } 458 459 req := []string{"a", "b"} 460 containerRes, err := impl.Reserve(req) 461 require.NoError(err) 462 require.EqualValues(req, received) 463 require.EqualValues(reservation, containerRes) 464 } 465 466 func TestDevicePlugin_Stats(t *testing.T) { 467 t.Parallel() 468 require := require.New(t) 469 470 devices1 := []*DeviceGroupStats{ 471 { 472 Vendor: "nvidia", 473 Type: DeviceTypeGPU, 474 Name: "foo", 475 InstanceStats: map[string]*DeviceStats{ 476 "1": { 477 Summary: &psstructs.StatValue{ 478 IntNumeratorVal: helper.Int64ToPtr(10), 479 IntDenominatorVal: helper.Int64ToPtr(20), 480 Unit: "MB", 481 Desc: "Unit test", 482 }, 483 }, 484 }, 485 }, 486 } 487 devices2 := []*DeviceGroupStats{ 488 { 489 Vendor: "nvidia", 490 Type: DeviceTypeGPU, 491 Name: "foo", 492 InstanceStats: map[string]*DeviceStats{ 493 "1": { 494 Summary: &psstructs.StatValue{ 495 FloatNumeratorVal: helper.Float64ToPtr(10.0), 496 FloatDenominatorVal: helper.Float64ToPtr(20.0), 497 Unit: "MB", 498 Desc: "Unit test", 499 }, 500 }, 501 }, 502 }, 503 { 504 Vendor: "nvidia", 505 Type: DeviceTypeGPU, 506 Name: "bar", 507 InstanceStats: map[string]*DeviceStats{ 508 "1": { 509 Summary: &psstructs.StatValue{ 510 StringVal: helper.StringToPtr("foo"), 511 Unit: "MB", 512 Desc: "Unit test", 513 }, 514 }, 515 }, 516 }, 517 { 518 Vendor: "nvidia", 519 Type: DeviceTypeGPU, 520 Name: "baz", 521 InstanceStats: map[string]*DeviceStats{ 522 "1": { 523 Summary: &psstructs.StatValue{ 524 BoolVal: helper.BoolToPtr(true), 525 Unit: "MB", 526 Desc: "Unit test", 527 }, 528 }, 529 }, 530 }, 531 } 532 533 mock := &MockDevicePlugin{ 534 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 535 outCh := make(chan *StatsResponse, 1) 536 go func() { 537 // Send two messages 538 for _, devs := range [][]*DeviceGroupStats{devices1, devices2} { 539 select { 540 case <-ctx.Done(): 541 return 542 case outCh <- &StatsResponse{Groups: devs}: 543 } 544 } 545 close(outCh) 546 return 547 }() 548 return outCh, nil 549 }, 550 } 551 552 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 553 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 554 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 555 }) 556 defer server.Stop() 557 defer client.Close() 558 559 raw, err := client.Dispense(base.PluginTypeDevice) 560 if err != nil { 561 t.Fatalf("err: %s", err) 562 } 563 564 impl, ok := raw.(DevicePlugin) 565 if !ok { 566 t.Fatalf("bad: %#v", raw) 567 } 568 569 // Create a context 570 ctx, cancel := context.WithCancel(context.Background()) 571 defer cancel() 572 573 // Get the stream 574 stream, err := impl.Stats(ctx, time.Millisecond) 575 require.NoError(err) 576 577 // Get the first message 578 var first *StatsResponse 579 select { 580 case <-time.After(1 * time.Second): 581 t.Fatal("timeout") 582 case first = <-stream: 583 } 584 585 require.NoError(first.Error) 586 require.EqualValues(devices1, first.Groups) 587 588 // Get the second message 589 var second *StatsResponse 590 select { 591 case <-time.After(1 * time.Second): 592 t.Fatal("timeout") 593 case second = <-stream: 594 } 595 596 require.NoError(second.Error) 597 require.EqualValues(devices2, second.Groups) 598 599 select { 600 case _, ok := <-stream: 601 require.False(ok) 602 case <-time.After(1 * time.Second): 603 t.Fatal("stream should be closed") 604 } 605 } 606 607 func TestDevicePlugin_Stats_StreamErr(t *testing.T) { 608 t.Parallel() 609 require := require.New(t) 610 611 ferr := fmt.Errorf("mock stats failed") 612 mock := &MockDevicePlugin{ 613 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 614 outCh := make(chan *StatsResponse, 1) 615 go func() { 616 // Send the error 617 select { 618 case <-ctx.Done(): 619 return 620 case outCh <- &StatsResponse{Error: ferr}: 621 } 622 623 close(outCh) 624 return 625 }() 626 return outCh, nil 627 }, 628 } 629 630 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 631 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 632 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 633 }) 634 defer server.Stop() 635 defer client.Close() 636 637 raw, err := client.Dispense(base.PluginTypeDevice) 638 if err != nil { 639 t.Fatalf("err: %s", err) 640 } 641 642 impl, ok := raw.(DevicePlugin) 643 if !ok { 644 t.Fatalf("bad: %#v", raw) 645 } 646 647 // Create a context 648 ctx, cancel := context.WithCancel(context.Background()) 649 defer cancel() 650 651 // Get the stream 652 stream, err := impl.Stats(ctx, time.Millisecond) 653 require.NoError(err) 654 655 // Get the first message 656 var first *StatsResponse 657 select { 658 case <-time.After(1 * time.Second): 659 t.Fatal("timeout") 660 case first = <-stream: 661 } 662 663 errStatus := status.Convert(ferr) 664 require.EqualError(first.Error, errStatus.Err().Error()) 665 } 666 667 func TestDevicePlugin_Stats_CancelCtx(t *testing.T) { 668 t.Parallel() 669 require := require.New(t) 670 671 mock := &MockDevicePlugin{ 672 StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) { 673 outCh := make(chan *StatsResponse, 1) 674 go func() { 675 <-ctx.Done() 676 close(outCh) 677 return 678 }() 679 return outCh, nil 680 }, 681 } 682 683 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 684 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 685 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 686 }) 687 defer server.Stop() 688 defer client.Close() 689 690 raw, err := client.Dispense(base.PluginTypeDevice) 691 if err != nil { 692 t.Fatalf("err: %s", err) 693 } 694 695 impl, ok := raw.(DevicePlugin) 696 if !ok { 697 t.Fatalf("bad: %#v", raw) 698 } 699 700 // Create a context 701 ctx, cancel := context.WithCancel(context.Background()) 702 703 // Get the stream 704 stream, err := impl.Stats(ctx, time.Millisecond) 705 require.NoError(err) 706 707 // Get the first message 708 select { 709 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 710 case _ = <-stream: 711 t.Fatal("bad value") 712 } 713 714 // Cancel the context 715 cancel() 716 717 // Make sure we are done 718 select { 719 case <-time.After(100 * time.Millisecond): 720 t.Fatalf("timeout") 721 case v := <-stream: 722 require.Error(v.Error) 723 require.EqualError(v.Error, context.Canceled.Error()) 724 } 725 }