github.com/zoomfoo/nomad@v0.8.5-0.20180907175415-f28fd3a1a056/plugins/device/plugin_test.go (about)

     1  package device
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	pb "github.com/golang/protobuf/proto"
    10  	plugin "github.com/hashicorp/go-plugin"
    11  	"github.com/hashicorp/nomad/nomad/structs"
    12  	"github.com/hashicorp/nomad/plugins/base"
    13  	"github.com/hashicorp/nomad/plugins/shared/hclspec"
    14  	"github.com/hashicorp/nomad/testutil"
    15  	"github.com/stretchr/testify/require"
    16  	"github.com/zclconf/go-cty/cty"
    17  	"github.com/zclconf/go-cty/cty/msgpack"
    18  	"google.golang.org/grpc/status"
    19  )
    20  
    21  func TestDevicePlugin_PluginInfo(t *testing.T) {
    22  	t.Parallel()
    23  	require := require.New(t)
    24  
    25  	const (
    26  		apiVersion    = "v0.1.0"
    27  		pluginVersion = "v0.2.1"
    28  		pluginName    = "mock"
    29  	)
    30  
    31  	knownType := func() (*base.PluginInfoResponse, error) {
    32  		info := &base.PluginInfoResponse{
    33  			Type:             base.PluginTypeDevice,
    34  			PluginApiVersion: apiVersion,
    35  			PluginVersion:    pluginVersion,
    36  			Name:             pluginName,
    37  		}
    38  		return info, nil
    39  	}
    40  	unknownType := func() (*base.PluginInfoResponse, error) {
    41  		info := &base.PluginInfoResponse{
    42  			Type:             "bad",
    43  			PluginApiVersion: apiVersion,
    44  			PluginVersion:    pluginVersion,
    45  			Name:             pluginName,
    46  		}
    47  		return info, nil
    48  	}
    49  
    50  	mock := &MockDevicePlugin{
    51  		MockPlugin: &base.MockPlugin{
    52  			PluginInfoF: knownType,
    53  		},
    54  	}
    55  
    56  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
    57  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
    58  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
    59  	})
    60  	defer server.Stop()
    61  	defer client.Close()
    62  
    63  	raw, err := client.Dispense(base.PluginTypeDevice)
    64  	if err != nil {
    65  		t.Fatalf("err: %s", err)
    66  	}
    67  
    68  	impl, ok := raw.(DevicePlugin)
    69  	if !ok {
    70  		t.Fatalf("bad: %#v", raw)
    71  	}
    72  
    73  	resp, err := impl.PluginInfo()
    74  	require.NoError(err)
    75  	require.Equal(apiVersion, resp.PluginApiVersion)
    76  	require.Equal(pluginVersion, resp.PluginVersion)
    77  	require.Equal(pluginName, resp.Name)
    78  	require.Equal(base.PluginTypeDevice, resp.Type)
    79  
    80  	// Swap the implementation to return an unknown type
    81  	mock.PluginInfoF = unknownType
    82  	_, err = impl.PluginInfo()
    83  	require.Error(err)
    84  	require.Contains(err.Error(), "unknown type")
    85  }
    86  
    87  func TestDevicePlugin_ConfigSchema(t *testing.T) {
    88  	t.Parallel()
    89  	require := require.New(t)
    90  
    91  	mock := &MockDevicePlugin{
    92  		MockPlugin: &base.MockPlugin{
    93  			ConfigSchemaF: func() (*hclspec.Spec, error) {
    94  				return base.TestSpec, nil
    95  			},
    96  		},
    97  	}
    98  
    99  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   100  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   101  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   102  	})
   103  	defer server.Stop()
   104  	defer client.Close()
   105  
   106  	raw, err := client.Dispense(base.PluginTypeDevice)
   107  	if err != nil {
   108  		t.Fatalf("err: %s", err)
   109  	}
   110  
   111  	impl, ok := raw.(DevicePlugin)
   112  	if !ok {
   113  		t.Fatalf("bad: %#v", raw)
   114  	}
   115  
   116  	specOut, err := impl.ConfigSchema()
   117  	require.NoError(err)
   118  	require.True(pb.Equal(base.TestSpec, specOut))
   119  }
   120  
   121  func TestDevicePlugin_SetConfig(t *testing.T) {
   122  	t.Parallel()
   123  	require := require.New(t)
   124  
   125  	var receivedData []byte
   126  	mock := &MockDevicePlugin{
   127  		MockPlugin: &base.MockPlugin{
   128  			ConfigSchemaF: func() (*hclspec.Spec, error) {
   129  				return base.TestSpec, nil
   130  			},
   131  			SetConfigF: func(data []byte) error {
   132  				receivedData = data
   133  				return nil
   134  			},
   135  		},
   136  	}
   137  
   138  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   139  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   140  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   141  	})
   142  	defer server.Stop()
   143  	defer client.Close()
   144  
   145  	raw, err := client.Dispense(base.PluginTypeDevice)
   146  	if err != nil {
   147  		t.Fatalf("err: %s", err)
   148  	}
   149  
   150  	impl, ok := raw.(DevicePlugin)
   151  	if !ok {
   152  		t.Fatalf("bad: %#v", raw)
   153  	}
   154  
   155  	config := cty.ObjectVal(map[string]cty.Value{
   156  		"foo": cty.StringVal("v1"),
   157  		"bar": cty.NumberIntVal(1337),
   158  		"baz": cty.BoolVal(true),
   159  	})
   160  	cdata, err := msgpack.Marshal(config, config.Type())
   161  	require.NoError(err)
   162  	require.NoError(impl.SetConfig(cdata))
   163  	require.Equal(cdata, receivedData)
   164  
   165  	// Decode the value back
   166  	var actual base.TestConfig
   167  	require.NoError(structs.Decode(receivedData, &actual))
   168  	require.Equal("v1", actual.Foo)
   169  	require.EqualValues(1337, actual.Bar)
   170  	require.True(actual.Baz)
   171  }
   172  
   173  func TestDevicePlugin_Fingerprint(t *testing.T) {
   174  	t.Parallel()
   175  	require := require.New(t)
   176  
   177  	devices1 := []*DeviceGroup{
   178  		{
   179  			Vendor: "nvidia",
   180  			Type:   DeviceTypeGPU,
   181  			Name:   "foo",
   182  		},
   183  	}
   184  	devices2 := []*DeviceGroup{
   185  		{
   186  			Vendor: "nvidia",
   187  			Type:   DeviceTypeGPU,
   188  			Name:   "foo",
   189  		},
   190  		{
   191  			Vendor: "nvidia",
   192  			Type:   DeviceTypeGPU,
   193  			Name:   "bar",
   194  		},
   195  	}
   196  
   197  	mock := &MockDevicePlugin{
   198  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   199  			outCh := make(chan *FingerprintResponse, 1)
   200  			go func() {
   201  				// Send two messages
   202  				for _, devs := range [][]*DeviceGroup{devices1, devices2} {
   203  					select {
   204  					case <-ctx.Done():
   205  						return
   206  					case outCh <- &FingerprintResponse{Devices: devs}:
   207  					}
   208  				}
   209  				close(outCh)
   210  				return
   211  			}()
   212  			return outCh, nil
   213  		},
   214  	}
   215  
   216  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   217  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   218  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   219  	})
   220  	defer server.Stop()
   221  	defer client.Close()
   222  
   223  	raw, err := client.Dispense(base.PluginTypeDevice)
   224  	if err != nil {
   225  		t.Fatalf("err: %s", err)
   226  	}
   227  
   228  	impl, ok := raw.(DevicePlugin)
   229  	if !ok {
   230  		t.Fatalf("bad: %#v", raw)
   231  	}
   232  
   233  	// Create a context
   234  	ctx, cancel := context.WithCancel(context.Background())
   235  	defer cancel()
   236  
   237  	// Get the stream
   238  	stream, err := impl.Fingerprint(ctx)
   239  	require.NoError(err)
   240  
   241  	// Get the first message
   242  	var first *FingerprintResponse
   243  	select {
   244  	case <-time.After(1 * time.Second):
   245  		t.Fatal("timeout")
   246  	case first = <-stream:
   247  	}
   248  
   249  	require.NoError(first.Error)
   250  	require.EqualValues(devices1, first.Devices)
   251  
   252  	// Get the second message
   253  	var second *FingerprintResponse
   254  	select {
   255  	case <-time.After(1 * time.Second):
   256  		t.Fatal("timeout")
   257  	case second = <-stream:
   258  	}
   259  
   260  	require.NoError(second.Error)
   261  	require.EqualValues(devices2, second.Devices)
   262  
   263  	select {
   264  	case _, ok := <-stream:
   265  		require.False(ok)
   266  	case <-time.After(1 * time.Second):
   267  		t.Fatal("stream should be closed")
   268  	}
   269  }
   270  
   271  func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) {
   272  	t.Parallel()
   273  	require := require.New(t)
   274  
   275  	ferr := fmt.Errorf("mock fingerprinting failed")
   276  	mock := &MockDevicePlugin{
   277  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   278  			outCh := make(chan *FingerprintResponse, 1)
   279  			go func() {
   280  				// Send the error
   281  				select {
   282  				case <-ctx.Done():
   283  					return
   284  				case outCh <- &FingerprintResponse{Error: ferr}:
   285  				}
   286  
   287  				close(outCh)
   288  				return
   289  			}()
   290  			return outCh, nil
   291  		},
   292  	}
   293  
   294  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   295  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   296  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   297  	})
   298  	defer server.Stop()
   299  	defer client.Close()
   300  
   301  	raw, err := client.Dispense(base.PluginTypeDevice)
   302  	if err != nil {
   303  		t.Fatalf("err: %s", err)
   304  	}
   305  
   306  	impl, ok := raw.(DevicePlugin)
   307  	if !ok {
   308  		t.Fatalf("bad: %#v", raw)
   309  	}
   310  
   311  	// Create a context
   312  	ctx, cancel := context.WithCancel(context.Background())
   313  	defer cancel()
   314  
   315  	// Get the stream
   316  	stream, err := impl.Fingerprint(ctx)
   317  	require.NoError(err)
   318  
   319  	// Get the first message
   320  	var first *FingerprintResponse
   321  	select {
   322  	case <-time.After(1 * time.Second):
   323  		t.Fatal("timeout")
   324  	case first = <-stream:
   325  	}
   326  
   327  	errStatus := status.Convert(ferr)
   328  	require.EqualError(first.Error, errStatus.Err().Error())
   329  }
   330  
   331  func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) {
   332  	t.Parallel()
   333  	require := require.New(t)
   334  
   335  	mock := &MockDevicePlugin{
   336  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   337  			outCh := make(chan *FingerprintResponse, 1)
   338  			go func() {
   339  				<-ctx.Done()
   340  				close(outCh)
   341  				return
   342  			}()
   343  			return outCh, nil
   344  		},
   345  	}
   346  
   347  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   348  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   349  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   350  	})
   351  	defer server.Stop()
   352  	defer client.Close()
   353  
   354  	raw, err := client.Dispense(base.PluginTypeDevice)
   355  	if err != nil {
   356  		t.Fatalf("err: %s", err)
   357  	}
   358  
   359  	impl, ok := raw.(DevicePlugin)
   360  	if !ok {
   361  		t.Fatalf("bad: %#v", raw)
   362  	}
   363  
   364  	// Create a context
   365  	ctx, cancel := context.WithCancel(context.Background())
   366  
   367  	// Get the stream
   368  	stream, err := impl.Fingerprint(ctx)
   369  	require.NoError(err)
   370  
   371  	// Get the first message
   372  	select {
   373  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   374  	case _ = <-stream:
   375  		t.Fatal("bad value")
   376  	}
   377  
   378  	// Cancel the context
   379  	cancel()
   380  
   381  	// Make sure we are done
   382  	select {
   383  	case <-time.After(100 * time.Millisecond):
   384  		t.Fatalf("timeout")
   385  	case v := <-stream:
   386  		require.Error(v.Error)
   387  		require.EqualError(v.Error, context.Canceled.Error())
   388  	}
   389  }
   390  
   391  func TestDevicePlugin_Reserve(t *testing.T) {
   392  	t.Parallel()
   393  	require := require.New(t)
   394  
   395  	reservation := &ContainerReservation{
   396  		Envs: map[string]string{
   397  			"foo": "bar",
   398  		},
   399  		Mounts: []*Mount{
   400  			{
   401  				TaskPath: "foo",
   402  				HostPath: "bar",
   403  				ReadOnly: true,
   404  			},
   405  		},
   406  		Devices: []*DeviceSpec{
   407  			{
   408  				TaskPath:    "foo",
   409  				HostPath:    "bar",
   410  				CgroupPerms: "rx",
   411  			},
   412  		},
   413  	}
   414  
   415  	var received []string
   416  	mock := &MockDevicePlugin{
   417  		ReserveF: func(devices []string) (*ContainerReservation, error) {
   418  			received = devices
   419  			return reservation, nil
   420  		},
   421  	}
   422  
   423  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   424  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   425  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   426  	})
   427  	defer server.Stop()
   428  	defer client.Close()
   429  
   430  	raw, err := client.Dispense(base.PluginTypeDevice)
   431  	if err != nil {
   432  		t.Fatalf("err: %s", err)
   433  	}
   434  
   435  	impl, ok := raw.(DevicePlugin)
   436  	if !ok {
   437  		t.Fatalf("bad: %#v", raw)
   438  	}
   439  
   440  	req := []string{"a", "b"}
   441  	containerRes, err := impl.Reserve(req)
   442  	require.NoError(err)
   443  	require.EqualValues(req, received)
   444  	require.EqualValues(reservation, containerRes)
   445  }
   446  
   447  func TestDevicePlugin_Stats(t *testing.T) {
   448  	t.Parallel()
   449  	require := require.New(t)
   450  
   451  	devices1 := []*DeviceGroupStats{
   452  		{
   453  			Vendor: "nvidia",
   454  			Type:   DeviceTypeGPU,
   455  			Name:   "foo",
   456  			InstanceStats: map[string]*DeviceStats{
   457  				"1": {
   458  					Summary: &StatValue{
   459  						IntNumeratorVal:   10,
   460  						IntDenominatorVal: 20,
   461  						Unit:              "MB",
   462  						Desc:              "Unit test",
   463  					},
   464  				},
   465  			},
   466  		},
   467  	}
   468  	devices2 := []*DeviceGroupStats{
   469  		{
   470  			Vendor: "nvidia",
   471  			Type:   DeviceTypeGPU,
   472  			Name:   "foo",
   473  			InstanceStats: map[string]*DeviceStats{
   474  				"1": {
   475  					Summary: &StatValue{
   476  						FloatNumeratorVal:   10.0,
   477  						FloatDenominatorVal: 20.0,
   478  						Unit:                "MB",
   479  						Desc:                "Unit test",
   480  					},
   481  				},
   482  			},
   483  		},
   484  		{
   485  			Vendor: "nvidia",
   486  			Type:   DeviceTypeGPU,
   487  			Name:   "bar",
   488  			InstanceStats: map[string]*DeviceStats{
   489  				"1": {
   490  					Summary: &StatValue{
   491  						StringVal: "foo",
   492  						Unit:      "MB",
   493  						Desc:      "Unit test",
   494  					},
   495  				},
   496  			},
   497  		},
   498  		{
   499  			Vendor: "nvidia",
   500  			Type:   DeviceTypeGPU,
   501  			Name:   "baz",
   502  			InstanceStats: map[string]*DeviceStats{
   503  				"1": {
   504  					Summary: &StatValue{
   505  						BoolVal: true,
   506  						Unit:    "MB",
   507  						Desc:    "Unit test",
   508  					},
   509  				},
   510  			},
   511  		},
   512  	}
   513  
   514  	mock := &MockDevicePlugin{
   515  		StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) {
   516  			outCh := make(chan *StatsResponse, 1)
   517  			go func() {
   518  				// Send two messages
   519  				for _, devs := range [][]*DeviceGroupStats{devices1, devices2} {
   520  					select {
   521  					case <-ctx.Done():
   522  						return
   523  					case outCh <- &StatsResponse{Groups: devs}:
   524  					}
   525  				}
   526  				close(outCh)
   527  				return
   528  			}()
   529  			return outCh, nil
   530  		},
   531  	}
   532  
   533  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   534  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   535  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   536  	})
   537  	defer server.Stop()
   538  	defer client.Close()
   539  
   540  	raw, err := client.Dispense(base.PluginTypeDevice)
   541  	if err != nil {
   542  		t.Fatalf("err: %s", err)
   543  	}
   544  
   545  	impl, ok := raw.(DevicePlugin)
   546  	if !ok {
   547  		t.Fatalf("bad: %#v", raw)
   548  	}
   549  
   550  	// Create a context
   551  	ctx, cancel := context.WithCancel(context.Background())
   552  	defer cancel()
   553  
   554  	// Get the stream
   555  	stream, err := impl.Stats(ctx)
   556  	require.NoError(err)
   557  
   558  	// Get the first message
   559  	var first *StatsResponse
   560  	select {
   561  	case <-time.After(1 * time.Second):
   562  		t.Fatal("timeout")
   563  	case first = <-stream:
   564  	}
   565  
   566  	require.NoError(first.Error)
   567  	require.EqualValues(devices1, first.Groups)
   568  
   569  	// Get the second message
   570  	var second *StatsResponse
   571  	select {
   572  	case <-time.After(1 * time.Second):
   573  		t.Fatal("timeout")
   574  	case second = <-stream:
   575  	}
   576  
   577  	require.NoError(second.Error)
   578  	require.EqualValues(devices2, second.Groups)
   579  
   580  	select {
   581  	case _, ok := <-stream:
   582  		require.False(ok)
   583  	case <-time.After(1 * time.Second):
   584  		t.Fatal("stream should be closed")
   585  	}
   586  }
   587  
   588  func TestDevicePlugin_Stats_StreamErr(t *testing.T) {
   589  	t.Parallel()
   590  	require := require.New(t)
   591  
   592  	ferr := fmt.Errorf("mock stats failed")
   593  	mock := &MockDevicePlugin{
   594  		StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) {
   595  			outCh := make(chan *StatsResponse, 1)
   596  			go func() {
   597  				// Send the error
   598  				select {
   599  				case <-ctx.Done():
   600  					return
   601  				case outCh <- &StatsResponse{Error: ferr}:
   602  				}
   603  
   604  				close(outCh)
   605  				return
   606  			}()
   607  			return outCh, nil
   608  		},
   609  	}
   610  
   611  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   612  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   613  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   614  	})
   615  	defer server.Stop()
   616  	defer client.Close()
   617  
   618  	raw, err := client.Dispense(base.PluginTypeDevice)
   619  	if err != nil {
   620  		t.Fatalf("err: %s", err)
   621  	}
   622  
   623  	impl, ok := raw.(DevicePlugin)
   624  	if !ok {
   625  		t.Fatalf("bad: %#v", raw)
   626  	}
   627  
   628  	// Create a context
   629  	ctx, cancel := context.WithCancel(context.Background())
   630  	defer cancel()
   631  
   632  	// Get the stream
   633  	stream, err := impl.Stats(ctx)
   634  	require.NoError(err)
   635  
   636  	// Get the first message
   637  	var first *StatsResponse
   638  	select {
   639  	case <-time.After(1 * time.Second):
   640  		t.Fatal("timeout")
   641  	case first = <-stream:
   642  	}
   643  
   644  	errStatus := status.Convert(ferr)
   645  	require.EqualError(first.Error, errStatus.Err().Error())
   646  }
   647  
   648  func TestDevicePlugin_Stats_CancelCtx(t *testing.T) {
   649  	t.Parallel()
   650  	require := require.New(t)
   651  
   652  	mock := &MockDevicePlugin{
   653  		StatsF: func(ctx context.Context) (<-chan *StatsResponse, error) {
   654  			outCh := make(chan *StatsResponse, 1)
   655  			go func() {
   656  				<-ctx.Done()
   657  				close(outCh)
   658  				return
   659  			}()
   660  			return outCh, nil
   661  		},
   662  	}
   663  
   664  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   665  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   666  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   667  	})
   668  	defer server.Stop()
   669  	defer client.Close()
   670  
   671  	raw, err := client.Dispense(base.PluginTypeDevice)
   672  	if err != nil {
   673  		t.Fatalf("err: %s", err)
   674  	}
   675  
   676  	impl, ok := raw.(DevicePlugin)
   677  	if !ok {
   678  		t.Fatalf("bad: %#v", raw)
   679  	}
   680  
   681  	// Create a context
   682  	ctx, cancel := context.WithCancel(context.Background())
   683  
   684  	// Get the stream
   685  	stream, err := impl.Stats(ctx)
   686  	require.NoError(err)
   687  
   688  	// Get the first message
   689  	select {
   690  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   691  	case _ = <-stream:
   692  		t.Fatal("bad value")
   693  	}
   694  
   695  	// Cancel the context
   696  	cancel()
   697  
   698  	// Make sure we are done
   699  	select {
   700  	case <-time.After(100 * time.Millisecond):
   701  		t.Fatalf("timeout")
   702  	case v := <-stream:
   703  		require.Error(v.Error)
   704  		require.EqualError(v.Error, context.Canceled.Error())
   705  	}
   706  }