github.com/zoomfoo/nomad@v0.8.5-0.20180907175415-f28fd3a1a056/plugins/device/plugin_test.go (about) 1 package device 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 pb "github.com/golang/protobuf/proto" 10 plugin "github.com/hashicorp/go-plugin" 11 "github.com/hashicorp/nomad/nomad/structs" 12 "github.com/hashicorp/nomad/plugins/base" 13 "github.com/hashicorp/nomad/plugins/shared/hclspec" 14 "github.com/hashicorp/nomad/testutil" 15 "github.com/stretchr/testify/require" 16 "github.com/zclconf/go-cty/cty" 17 "github.com/zclconf/go-cty/cty/msgpack" 18 "google.golang.org/grpc/status" 19 ) 20 21 func TestDevicePlugin_PluginInfo(t *testing.T) { 22 t.Parallel() 23 require := require.New(t) 24 25 const ( 26 apiVersion = "v0.1.0" 27 pluginVersion = "v0.2.1" 28 pluginName = "mock" 29 ) 30 31 knownType := func() (*base.PluginInfoResponse, error) { 32 info := &base.PluginInfoResponse{ 33 Type: base.PluginTypeDevice, 34 PluginApiVersion: apiVersion, 35 PluginVersion: pluginVersion, 36 Name: pluginName, 37 } 38 return info, nil 39 } 40 unknownType := func() (*base.PluginInfoResponse, error) { 41 info := &base.PluginInfoResponse{ 42 Type: "bad", 43 PluginApiVersion: apiVersion, 44 PluginVersion: pluginVersion, 45 Name: pluginName, 46 } 47 return info, nil 48 } 49 50 mock := &MockDevicePlugin{ 51 MockPlugin: &base.MockPlugin{ 52 PluginInfoF: knownType, 53 }, 54 } 55 56 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 57 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 58 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 59 }) 60 defer server.Stop() 61 defer client.Close() 62 63 raw, err := client.Dispense(base.PluginTypeDevice) 64 if err != nil { 65 t.Fatalf("err: %s", err) 66 } 67 68 impl, ok := raw.(DevicePlugin) 69 if !ok { 70 t.Fatalf("bad: %#v", raw) 71 } 72 73 resp, err := impl.PluginInfo() 74 require.NoError(err) 75 require.Equal(apiVersion, resp.PluginApiVersion) 76 require.Equal(pluginVersion, resp.PluginVersion) 77 require.Equal(pluginName, resp.Name) 78 require.Equal(base.PluginTypeDevice, resp.Type) 79 80 // Swap the implementation to return an unknown type 81 mock.PluginInfoF = unknownType 82 _, err = impl.PluginInfo() 83 require.Error(err) 84 require.Contains(err.Error(), "unknown type") 85 } 86 87 func TestDevicePlugin_ConfigSchema(t *testing.T) { 88 t.Parallel() 89 require := require.New(t) 90 91 mock := &MockDevicePlugin{ 92 MockPlugin: &base.MockPlugin{ 93 ConfigSchemaF: func() (*hclspec.Spec, error) { 94 return base.TestSpec, nil 95 }, 96 }, 97 } 98 99 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 100 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 101 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 102 }) 103 defer server.Stop() 104 defer client.Close() 105 106 raw, err := client.Dispense(base.PluginTypeDevice) 107 if err != nil { 108 t.Fatalf("err: %s", err) 109 } 110 111 impl, ok := raw.(DevicePlugin) 112 if !ok { 113 t.Fatalf("bad: %#v", raw) 114 } 115 116 specOut, err := impl.ConfigSchema() 117 require.NoError(err) 118 require.True(pb.Equal(base.TestSpec, specOut)) 119 } 120 121 func TestDevicePlugin_SetConfig(t *testing.T) { 122 t.Parallel() 123 require := require.New(t) 124 125 var receivedData []byte 126 mock := &MockDevicePlugin{ 127 MockPlugin: &base.MockPlugin{ 128 ConfigSchemaF: func() (*hclspec.Spec, error) { 129 return base.TestSpec, nil 130 }, 131 SetConfigF: func(data []byte) error { 132 receivedData = data 133 return nil 134 }, 135 }, 136 } 137 138 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 139 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 140 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 141 }) 142 defer server.Stop() 143 defer client.Close() 144 145 raw, err := client.Dispense(base.PluginTypeDevice) 146 if err != nil { 147 t.Fatalf("err: %s", err) 148 } 149 150 impl, ok := raw.(DevicePlugin) 151 if !ok { 152 t.Fatalf("bad: %#v", raw) 153 } 154 155 config := cty.ObjectVal(map[string]cty.Value{ 156 "foo": cty.StringVal("v1"), 157 "bar": cty.NumberIntVal(1337), 158 "baz": cty.BoolVal(true), 159 }) 160 cdata, err := msgpack.Marshal(config, config.Type()) 161 require.NoError(err) 162 require.NoError(impl.SetConfig(cdata)) 163 require.Equal(cdata, receivedData) 164 165 // Decode the value back 166 var actual base.TestConfig 167 require.NoError(structs.Decode(receivedData, &actual)) 168 require.Equal("v1", actual.Foo) 169 require.EqualValues(1337, actual.Bar) 170 require.True(actual.Baz) 171 } 172 173 func TestDevicePlugin_Fingerprint(t *testing.T) { 174 t.Parallel() 175 require := require.New(t) 176 177 devices1 := []*DeviceGroup{ 178 { 179 Vendor: "nvidia", 180 Type: DeviceTypeGPU, 181 Name: "foo", 182 }, 183 } 184 devices2 := []*DeviceGroup{ 185 { 186 Vendor: "nvidia", 187 Type: DeviceTypeGPU, 188 Name: "foo", 189 }, 190 { 191 Vendor: "nvidia", 192 Type: DeviceTypeGPU, 193 Name: "bar", 194 }, 195 } 196 197 mock := &MockDevicePlugin{ 198 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 199 outCh := make(chan *FingerprintResponse, 1) 200 go func() { 201 // Send two messages 202 for _, devs := range [][]*DeviceGroup{devices1, devices2} { 203 select { 204 case <-ctx.Done(): 205 return 206 case outCh <- &FingerprintResponse{Devices: devs}: 207 } 208 } 209 close(outCh) 210 return 211 }() 212 return outCh, nil 213 }, 214 } 215 216 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 217 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 218 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 219 }) 220 defer server.Stop() 221 defer client.Close() 222 223 raw, err := client.Dispense(base.PluginTypeDevice) 224 if err != nil { 225 t.Fatalf("err: %s", err) 226 } 227 228 impl, ok := raw.(DevicePlugin) 229 if !ok { 230 t.Fatalf("bad: %#v", raw) 231 } 232 233 // Create a context 234 ctx, cancel := context.WithCancel(context.Background()) 235 defer cancel() 236 237 // Get the stream 238 stream, err := impl.Fingerprint(ctx) 239 require.NoError(err) 240 241 // Get the first message 242 var first *FingerprintResponse 243 select { 244 case <-time.After(1 * time.Second): 245 t.Fatal("timeout") 246 case first = <-stream: 247 } 248 249 require.NoError(first.Error) 250 require.EqualValues(devices1, first.Devices) 251 252 // Get the second message 253 var second *FingerprintResponse 254 select { 255 case <-time.After(1 * time.Second): 256 t.Fatal("timeout") 257 case second = <-stream: 258 } 259 260 require.NoError(second.Error) 261 require.EqualValues(devices2, second.Devices) 262 263 select { 264 case _, ok := <-stream: 265 require.False(ok) 266 case <-time.After(1 * time.Second): 267 t.Fatal("stream should be closed") 268 } 269 } 270 271 func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) { 272 t.Parallel() 273 require := require.New(t) 274 275 ferr := fmt.Errorf("mock fingerprinting failed") 276 mock := &MockDevicePlugin{ 277 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 278 outCh := make(chan *FingerprintResponse, 1) 279 go func() { 280 // Send the error 281 select { 282 case <-ctx.Done(): 283 return 284 case outCh <- &FingerprintResponse{Error: ferr}: 285 } 286 287 close(outCh) 288 return 289 }() 290 return outCh, nil 291 }, 292 } 293 294 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 295 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 296 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 297 }) 298 defer server.Stop() 299 defer client.Close() 300 301 raw, err := client.Dispense(base.PluginTypeDevice) 302 if err != nil { 303 t.Fatalf("err: %s", err) 304 } 305 306 impl, ok := raw.(DevicePlugin) 307 if !ok { 308 t.Fatalf("bad: %#v", raw) 309 } 310 311 // Create a context 312 ctx, cancel := context.WithCancel(context.Background()) 313 defer cancel() 314 315 // Get the stream 316 stream, err := impl.Fingerprint(ctx) 317 require.NoError(err) 318 319 // Get the first message 320 var first *FingerprintResponse 321 select { 322 case <-time.After(1 * time.Second): 323 t.Fatal("timeout") 324 case first = <-stream: 325 } 326 327 errStatus := status.Convert(ferr) 328 require.EqualError(first.Error, errStatus.Err().Error()) 329 } 330 331 func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) { 332 t.Parallel() 333 require := require.New(t) 334 335 mock := &MockDevicePlugin{ 336 FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) { 337 outCh := make(chan *FingerprintResponse, 1) 338 go func() { 339 <-ctx.Done() 340 close(outCh) 341 return 342 }() 343 return outCh, nil 344 }, 345 } 346 347 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 348 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 349 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 350 }) 351 defer server.Stop() 352 defer client.Close() 353 354 raw, err := client.Dispense(base.PluginTypeDevice) 355 if err != nil { 356 t.Fatalf("err: %s", err) 357 } 358 359 impl, ok := raw.(DevicePlugin) 360 if !ok { 361 t.Fatalf("bad: %#v", raw) 362 } 363 364 // Create a context 365 ctx, cancel := context.WithCancel(context.Background()) 366 367 // Get the stream 368 stream, err := impl.Fingerprint(ctx) 369 require.NoError(err) 370 371 // Get the first message 372 select { 373 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 374 case _ = <-stream: 375 t.Fatal("bad value") 376 } 377 378 // Cancel the context 379 cancel() 380 381 // Make sure we are done 382 select { 383 case <-time.After(100 * time.Millisecond): 384 t.Fatalf("timeout") 385 case v := <-stream: 386 require.Error(v.Error) 387 require.EqualError(v.Error, context.Canceled.Error()) 388 } 389 } 390 391 func TestDevicePlugin_Reserve(t *testing.T) { 392 t.Parallel() 393 require := require.New(t) 394 395 reservation := &ContainerReservation{ 396 Envs: map[string]string{ 397 "foo": "bar", 398 }, 399 Mounts: []*Mount{ 400 { 401 TaskPath: "foo", 402 HostPath: "bar", 403 ReadOnly: true, 404 }, 405 }, 406 Devices: []*DeviceSpec{ 407 { 408 TaskPath: "foo", 409 HostPath: "bar", 410 CgroupPerms: "rx", 411 }, 412 }, 413 } 414 415 var received []string 416 mock := &MockDevicePlugin{ 417 ReserveF: func(devices []string) (*ContainerReservation, error) { 418 received = devices 419 return reservation, nil 420 }, 421 } 422 423 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 424 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 425 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 426 }) 427 defer server.Stop() 428 defer client.Close() 429 430 raw, err := client.Dispense(base.PluginTypeDevice) 431 if err != nil { 432 t.Fatalf("err: %s", err) 433 } 434 435 impl, ok := raw.(DevicePlugin) 436 if !ok { 437 t.Fatalf("bad: %#v", raw) 438 } 439 440 req := []string{"a", "b"} 441 containerRes, err := impl.Reserve(req) 442 require.NoError(err) 443 require.EqualValues(req, received) 444 require.EqualValues(reservation, containerRes) 445 } 446 447 func TestDevicePlugin_Stats(t *testing.T) { 448 t.Parallel() 449 require := require.New(t) 450 451 devices1 := []*DeviceGroupStats{ 452 { 453 Vendor: "nvidia", 454 Type: DeviceTypeGPU, 455 Name: "foo", 456 InstanceStats: map[string]*DeviceStats{ 457 "1": { 458 Summary: &StatValue{ 459 IntNumeratorVal: 10, 460 IntDenominatorVal: 20, 461 Unit: "MB", 462 Desc: "Unit test", 463 }, 464 }, 465 }, 466 }, 467 } 468 devices2 := []*DeviceGroupStats{ 469 { 470 Vendor: "nvidia", 471 Type: DeviceTypeGPU, 472 Name: "foo", 473 InstanceStats: map[string]*DeviceStats{ 474 "1": { 475 Summary: &StatValue{ 476 FloatNumeratorVal: 10.0, 477 FloatDenominatorVal: 20.0, 478 Unit: "MB", 479 Desc: "Unit test", 480 }, 481 }, 482 }, 483 }, 484 { 485 Vendor: "nvidia", 486 Type: DeviceTypeGPU, 487 Name: "bar", 488 InstanceStats: map[string]*DeviceStats{ 489 "1": { 490 Summary: &StatValue{ 491 StringVal: "foo", 492 Unit: "MB", 493 Desc: "Unit test", 494 }, 495 }, 496 }, 497 }, 498 { 499 Vendor: "nvidia", 500 Type: DeviceTypeGPU, 501 Name: "baz", 502 InstanceStats: map[string]*DeviceStats{ 503 "1": { 504 Summary: &StatValue{ 505 BoolVal: true, 506 Unit: "MB", 507 Desc: "Unit test", 508 }, 509 }, 510 }, 511 }, 512 } 513 514 mock := &MockDevicePlugin{ 515 StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) { 516 outCh := make(chan *StatsResponse, 1) 517 go func() { 518 // Send two messages 519 for _, devs := range [][]*DeviceGroupStats{devices1, devices2} { 520 select { 521 case <-ctx.Done(): 522 return 523 case outCh <- &StatsResponse{Groups: devs}: 524 } 525 } 526 close(outCh) 527 return 528 }() 529 return outCh, nil 530 }, 531 } 532 533 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 534 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 535 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 536 }) 537 defer server.Stop() 538 defer client.Close() 539 540 raw, err := client.Dispense(base.PluginTypeDevice) 541 if err != nil { 542 t.Fatalf("err: %s", err) 543 } 544 545 impl, ok := raw.(DevicePlugin) 546 if !ok { 547 t.Fatalf("bad: %#v", raw) 548 } 549 550 // Create a context 551 ctx, cancel := context.WithCancel(context.Background()) 552 defer cancel() 553 554 // Get the stream 555 stream, err := impl.Stats(ctx) 556 require.NoError(err) 557 558 // Get the first message 559 var first *StatsResponse 560 select { 561 case <-time.After(1 * time.Second): 562 t.Fatal("timeout") 563 case first = <-stream: 564 } 565 566 require.NoError(first.Error) 567 require.EqualValues(devices1, first.Groups) 568 569 // Get the second message 570 var second *StatsResponse 571 select { 572 case <-time.After(1 * time.Second): 573 t.Fatal("timeout") 574 case second = <-stream: 575 } 576 577 require.NoError(second.Error) 578 require.EqualValues(devices2, second.Groups) 579 580 select { 581 case _, ok := <-stream: 582 require.False(ok) 583 case <-time.After(1 * time.Second): 584 t.Fatal("stream should be closed") 585 } 586 } 587 588 func TestDevicePlugin_Stats_StreamErr(t *testing.T) { 589 t.Parallel() 590 require := require.New(t) 591 592 ferr := fmt.Errorf("mock stats failed") 593 mock := &MockDevicePlugin{ 594 StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) { 595 outCh := make(chan *StatsResponse, 1) 596 go func() { 597 // Send the error 598 select { 599 case <-ctx.Done(): 600 return 601 case outCh <- &StatsResponse{Error: ferr}: 602 } 603 604 close(outCh) 605 return 606 }() 607 return outCh, nil 608 }, 609 } 610 611 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 612 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 613 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 614 }) 615 defer server.Stop() 616 defer client.Close() 617 618 raw, err := client.Dispense(base.PluginTypeDevice) 619 if err != nil { 620 t.Fatalf("err: %s", err) 621 } 622 623 impl, ok := raw.(DevicePlugin) 624 if !ok { 625 t.Fatalf("bad: %#v", raw) 626 } 627 628 // Create a context 629 ctx, cancel := context.WithCancel(context.Background()) 630 defer cancel() 631 632 // Get the stream 633 stream, err := impl.Stats(ctx) 634 require.NoError(err) 635 636 // Get the first message 637 var first *StatsResponse 638 select { 639 case <-time.After(1 * time.Second): 640 t.Fatal("timeout") 641 case first = <-stream: 642 } 643 644 errStatus := status.Convert(ferr) 645 require.EqualError(first.Error, errStatus.Err().Error()) 646 } 647 648 func TestDevicePlugin_Stats_CancelCtx(t *testing.T) { 649 t.Parallel() 650 require := require.New(t) 651 652 mock := &MockDevicePlugin{ 653 StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) { 654 outCh := make(chan *StatsResponse, 1) 655 go func() { 656 <-ctx.Done() 657 close(outCh) 658 return 659 }() 660 return outCh, nil 661 }, 662 } 663 664 client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 665 base.PluginTypeBase: &base.PluginBase{Impl: mock}, 666 base.PluginTypeDevice: &PluginDevice{Impl: mock}, 667 }) 668 defer server.Stop() 669 defer client.Close() 670 671 raw, err := client.Dispense(base.PluginTypeDevice) 672 if err != nil { 673 t.Fatalf("err: %s", err) 674 } 675 676 impl, ok := raw.(DevicePlugin) 677 if !ok { 678 t.Fatalf("bad: %#v", raw) 679 } 680 681 // Create a context 682 ctx, cancel := context.WithCancel(context.Background()) 683 684 // Get the stream 685 stream, err := impl.Stats(ctx) 686 require.NoError(err) 687 688 // Get the first message 689 select { 690 case <-time.After(testutil.Timeout(10 * time.Millisecond)): 691 case _ = <-stream: 692 t.Fatal("bad value") 693 } 694 695 // Cancel the context 696 cancel() 697 698 // Make sure we are done 699 select { 700 case <-time.After(100 * time.Millisecond): 701 t.Fatalf("timeout") 702 case v := <-stream: 703 require.Error(v.Error) 704 require.EqualError(v.Error, context.Canceled.Error()) 705 } 706 }