github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/devicemanager/manager_test.go (about)

     1  package devicemanager
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	log "github.com/hashicorp/go-hclog"
    11  	plugin "github.com/hashicorp/go-plugin"
    12  	"github.com/hashicorp/nomad/ci"
    13  	"github.com/hashicorp/nomad/client/state"
    14  	"github.com/hashicorp/nomad/helper/pluginutils/loader"
    15  	"github.com/hashicorp/nomad/helper/pointer"
    16  	"github.com/hashicorp/nomad/helper/testlog"
    17  	"github.com/hashicorp/nomad/helper/uuid"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  	"github.com/hashicorp/nomad/plugins/base"
    20  	"github.com/hashicorp/nomad/plugins/device"
    21  	psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    22  	"github.com/hashicorp/nomad/testutil"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
// Shared fixtures for the manager tests: a mock nvidia device group with two
// devices, a mock intel device group with one device, and the per-instance
// stats each mock plugin reports. The mock plugins are wired up in
// nvidiaAndIntelDefaultPlugins.
var (
	// Device IDs are random UUIDs, generated once when the package's
	// variables are initialised, so they are stable across tests in one run.
	nvidiaDevice0ID   = uuid.Generate()
	nvidiaDevice1ID   = uuid.Generate()
	nvidiaDeviceGroup = &device.DeviceGroup{
		Vendor: "nvidia",
		Type:   "gpu",
		Name:   "1080ti",
		Devices: []*device.Device{
			{
				ID:      nvidiaDevice0ID,
				Healthy: true,
			},
			{
				ID:      nvidiaDevice1ID,
				Healthy: true,
			},
		},
		Attributes: map[string]*psstructs.Attribute{
			"memory": {
				Int:  pointer.Of(int64(4)),
				Unit: "GB",
			},
		},
	}

	// Single-device intel group; its Name ("640GT") is what Reserve must
	// match, and TestManager_Reserve relies on a different name failing.
	intelDeviceID    = uuid.Generate()
	intelDeviceGroup = &device.DeviceGroup{
		Vendor: "intel",
		Type:   "gpu",
		Name:   "640GT",
		Devices: []*device.Device{
			{
				ID:      intelDeviceID,
				Healthy: true,
			},
		},
		Attributes: map[string]*psstructs.Attribute{
			"memory": {
				Int:  pointer.Of(int64(2)),
				Unit: "GB",
			},
		},
	}

	// Stats reported by the mock nvidia plugin, keyed by device ID. The
	// Vendor/Type/Name triple mirrors nvidiaDeviceGroup. Each summary is a
	// "Temperature" value in unit "F"; TestManager_DeviceStats checks the
	// value for nvidiaDevice1ID.
	nvidiaDeviceGroupStats = &device.DeviceGroupStats{
		Vendor: "nvidia",
		Type:   "gpu",
		Name:   "1080ti",
		InstanceStats: map[string]*device.DeviceStats{
			nvidiaDevice0ID: {
				Summary: &psstructs.StatValue{
					IntNumeratorVal: pointer.Of(int64(212)),
					Unit:            "F",
					Desc:            "Temperature",
				},
			},
			nvidiaDevice1ID: {
				Summary: &psstructs.StatValue{
					IntNumeratorVal: pointer.Of(int64(218)),
					Unit:            "F",
					Desc:            "Temperature",
				},
			},
		},
	}

	// Stats reported by the mock intel plugin; mirrors intelDeviceGroup.
	intelDeviceGroupStats = &device.DeviceGroupStats{
		Vendor: "intel",
		Type:   "gpu",
		Name:   "640GT",
		InstanceStats: map[string]*device.DeviceStats{
			intelDeviceID: {
				Summary: &psstructs.StatValue{
					IntNumeratorVal: pointer.Of(int64(220)),
					Unit:            "F",
					Desc:            "Temperature",
				},
			},
		},
	}
)
   107  
   108  func baseTestConfig(t *testing.T) (
   109  	config *Config,
   110  	deviceUpdateCh chan []*structs.NodeDeviceResource,
   111  	catalog *loader.MockCatalog) {
   112  
   113  	// Create an update handler
   114  	deviceUpdates := make(chan []*structs.NodeDeviceResource, 1)
   115  	updateFn := func(devices []*structs.NodeDeviceResource) {
   116  		deviceUpdates <- devices
   117  	}
   118  
   119  	// Create a mock plugin catalog
   120  	mc := &loader.MockCatalog{}
   121  
   122  	// Create the config
   123  	logger := testlog.HCLogger(t)
   124  	config = &Config{
   125  		Logger:        logger,
   126  		PluginConfig:  &base.AgentConfig{},
   127  		StatsInterval: 100 * time.Millisecond,
   128  		State:         state.NewMemDB(logger),
   129  		Updater:       updateFn,
   130  		Loader:        mc,
   131  	}
   132  
   133  	return config, deviceUpdates, mc
   134  }
   135  
   136  func configureCatalogWith(catalog *loader.MockCatalog, plugins map[*base.PluginInfoResponse]loader.PluginInstance) {
   137  
   138  	catalog.DispenseF = func(name, _ string, _ *base.AgentConfig, _ log.Logger) (loader.PluginInstance, error) {
   139  		for info, v := range plugins {
   140  			if info.Name == name {
   141  				return v, nil
   142  			}
   143  		}
   144  
   145  		return nil, fmt.Errorf("no matching plugin")
   146  	}
   147  
   148  	catalog.ReattachF = func(name, _ string, _ *plugin.ReattachConfig) (loader.PluginInstance, error) {
   149  		for info, v := range plugins {
   150  			if info.Name == name {
   151  				return v, nil
   152  			}
   153  		}
   154  
   155  		return nil, fmt.Errorf("no matching plugin")
   156  	}
   157  
   158  	catalog.CatalogF = func() map[string][]*base.PluginInfoResponse {
   159  		devices := make([]*base.PluginInfoResponse, 0, len(plugins))
   160  		for k := range plugins {
   161  			devices = append(devices, k)
   162  		}
   163  		out := map[string][]*base.PluginInfoResponse{
   164  			base.PluginTypeDevice: devices,
   165  		}
   166  		return out
   167  	}
   168  }
   169  
   170  func pluginInfoResponse(name string) *base.PluginInfoResponse {
   171  	return &base.PluginInfoResponse{
   172  		Type:              base.PluginTypeDevice,
   173  		PluginApiVersions: []string{"v0.0.1"},
   174  		PluginVersion:     "v0.0.1",
   175  		Name:              name,
   176  	}
   177  }
   178  
   179  // drainNodeDeviceUpdates drains all updates to the node device fingerprint channel
   180  func drainNodeDeviceUpdates(ctx context.Context, in chan []*structs.NodeDeviceResource) {
   181  	go func() {
   182  		for {
   183  			select {
   184  			case <-ctx.Done():
   185  				return
   186  			case <-in:
   187  			}
   188  		}
   189  	}()
   190  }
   191  
   192  func deviceReserveFn(ids []string) (*device.ContainerReservation, error) {
   193  	return &device.ContainerReservation{
   194  		Envs: map[string]string{
   195  			"DEVICES": strings.Join(ids, ","),
   196  		},
   197  	}, nil
   198  }
   199  
   200  // nvidiaAndIntelDefaultPlugins adds an nvidia and intel mock plugin to the
   201  // catalog
   202  func nvidiaAndIntelDefaultPlugins(catalog *loader.MockCatalog) {
   203  	pluginInfoNvidia := pluginInfoResponse("nvidia")
   204  	deviceNvidia := &device.MockDevicePlugin{
   205  		MockPlugin: &base.MockPlugin{
   206  			PluginInfoF:   base.StaticInfo(pluginInfoNvidia),
   207  			ConfigSchemaF: base.TestConfigSchema(),
   208  			SetConfigF:    base.NoopSetConfig(),
   209  		},
   210  		FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{nvidiaDeviceGroup}),
   211  		ReserveF:     deviceReserveFn,
   212  		StatsF:       device.StaticStats([]*device.DeviceGroupStats{nvidiaDeviceGroupStats}),
   213  	}
   214  	pluginNvidia := loader.MockBasicExternalPlugin(deviceNvidia, device.ApiVersion010)
   215  
   216  	pluginInfoIntel := pluginInfoResponse("intel")
   217  	deviceIntel := &device.MockDevicePlugin{
   218  		MockPlugin: &base.MockPlugin{
   219  			PluginInfoF:   base.StaticInfo(pluginInfoIntel),
   220  			ConfigSchemaF: base.TestConfigSchema(),
   221  			SetConfigF:    base.NoopSetConfig(),
   222  		},
   223  		FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{intelDeviceGroup}),
   224  		ReserveF:     deviceReserveFn,
   225  		StatsF:       device.StaticStats([]*device.DeviceGroupStats{intelDeviceGroupStats}),
   226  	}
   227  	pluginIntel := loader.MockBasicExternalPlugin(deviceIntel, device.ApiVersion010)
   228  
   229  	// Configure the catalog with two plugins
   230  	configureCatalogWith(catalog, map[*base.PluginInfoResponse]loader.PluginInstance{
   231  		pluginInfoNvidia: pluginNvidia,
   232  		pluginInfoIntel:  pluginIntel,
   233  	})
   234  }
   235  
   236  // Test collecting statistics from all devices
   237  func TestManager_AllStats(t *testing.T) {
   238  	ci.Parallel(t)
   239  	require := require.New(t)
   240  
   241  	config, _, catalog := baseTestConfig(t)
   242  	nvidiaAndIntelDefaultPlugins(catalog)
   243  
   244  	m := New(config)
   245  	m.Run()
   246  	defer m.Shutdown()
   247  	require.Len(m.instances, 2)
   248  
   249  	// Wait till we get a fingerprint result
   250  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   251  	defer cancel()
   252  	<-m.WaitForFirstFingerprint(ctx)
   253  	require.NoError(ctx.Err())
   254  
   255  	// Now collect all the stats
   256  	var stats []*device.DeviceGroupStats
   257  	testutil.WaitForResult(func() (bool, error) {
   258  		stats = m.AllStats()
   259  		l := len(stats)
   260  		if l == 2 {
   261  			return true, nil
   262  		}
   263  
   264  		return false, fmt.Errorf("expected count 2; got %d", l)
   265  	}, func(err error) {
   266  		t.Fatal(err)
   267  	})
   268  
   269  	// Check we got stats from both the devices
   270  	var nstats, istats bool
   271  	for _, stat := range stats {
   272  		switch stat.Vendor {
   273  		case "intel":
   274  			istats = true
   275  		case "nvidia":
   276  			nstats = true
   277  		default:
   278  			t.Fatalf("unexpected vendor %q", stat.Vendor)
   279  		}
   280  	}
   281  	require.True(nstats)
   282  	require.True(istats)
   283  }
   284  
   285  // Test collecting statistics from a particular device
   286  func TestManager_DeviceStats(t *testing.T) {
   287  	ci.Parallel(t)
   288  	require := require.New(t)
   289  
   290  	config, _, catalog := baseTestConfig(t)
   291  	nvidiaAndIntelDefaultPlugins(catalog)
   292  
   293  	m := New(config)
   294  	m.Run()
   295  	defer m.Shutdown()
   296  
   297  	// Wait till we get a fingerprint result
   298  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   299  	defer cancel()
   300  	<-m.WaitForFirstFingerprint(ctx)
   301  	require.NoError(ctx.Err())
   302  
   303  	testutil.WaitForResult(func() (bool, error) {
   304  		stats := m.AllStats()
   305  		l := len(stats)
   306  		if l == 2 {
   307  			return true, nil
   308  		}
   309  
   310  		return false, fmt.Errorf("expected count 2; got %d", l)
   311  	}, func(err error) {
   312  		t.Fatal(err)
   313  	})
   314  
   315  	// Now collect the stats for one nvidia device
   316  	stat, err := m.DeviceStats(&structs.AllocatedDeviceResource{
   317  		Vendor:    "nvidia",
   318  		Type:      "gpu",
   319  		Name:      "1080ti",
   320  		DeviceIDs: []string{nvidiaDevice1ID},
   321  	})
   322  	require.NoError(err)
   323  	require.NotNil(stat)
   324  
   325  	require.Len(stat.InstanceStats, 1)
   326  	require.Contains(stat.InstanceStats, nvidiaDevice1ID)
   327  
   328  	istat := stat.InstanceStats[nvidiaDevice1ID]
   329  	require.EqualValues(218, *istat.Summary.IntNumeratorVal)
   330  }
   331  
   332  // Test reserving a particular device
   333  func TestManager_Reserve(t *testing.T) {
   334  	ci.Parallel(t)
   335  	r := require.New(t)
   336  
   337  	config, _, catalog := baseTestConfig(t)
   338  	nvidiaAndIntelDefaultPlugins(catalog)
   339  
   340  	m := New(config)
   341  	m.Run()
   342  	defer m.Shutdown()
   343  
   344  	// Wait till we get a fingerprint result
   345  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   346  	defer cancel()
   347  	<-m.WaitForFirstFingerprint(ctx)
   348  	r.NoError(ctx.Err())
   349  
   350  	cases := []struct {
   351  		in       *structs.AllocatedDeviceResource
   352  		expected string
   353  		err      bool
   354  	}{
   355  		{
   356  			in: &structs.AllocatedDeviceResource{
   357  				Vendor:    "nvidia",
   358  				Type:      "gpu",
   359  				Name:      "1080ti",
   360  				DeviceIDs: []string{nvidiaDevice1ID},
   361  			},
   362  			expected: nvidiaDevice1ID,
   363  		},
   364  		{
   365  			in: &structs.AllocatedDeviceResource{
   366  				Vendor:    "nvidia",
   367  				Type:      "gpu",
   368  				Name:      "1080ti",
   369  				DeviceIDs: []string{nvidiaDevice0ID},
   370  			},
   371  			expected: nvidiaDevice0ID,
   372  		},
   373  		{
   374  			in: &structs.AllocatedDeviceResource{
   375  				Vendor:    "nvidia",
   376  				Type:      "gpu",
   377  				Name:      "1080ti",
   378  				DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID},
   379  			},
   380  			expected: fmt.Sprintf("%s,%s", nvidiaDevice0ID, nvidiaDevice1ID),
   381  		},
   382  		{
   383  			in: &structs.AllocatedDeviceResource{
   384  				Vendor:    "nvidia",
   385  				Type:      "gpu",
   386  				Name:      "1080ti",
   387  				DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID, "foo"},
   388  			},
   389  			err: true,
   390  		},
   391  		{
   392  			in: &structs.AllocatedDeviceResource{
   393  				Vendor:    "intel",
   394  				Type:      "gpu",
   395  				Name:      "640GT",
   396  				DeviceIDs: []string{intelDeviceID},
   397  			},
   398  			expected: intelDeviceID,
   399  		},
   400  		{
   401  			in: &structs.AllocatedDeviceResource{
   402  				Vendor:    "intel",
   403  				Type:      "gpu",
   404  				Name:      "foo",
   405  				DeviceIDs: []string{intelDeviceID},
   406  			},
   407  			err: true,
   408  		},
   409  	}
   410  
   411  	for i, c := range cases {
   412  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   413  			r = require.New(t)
   414  
   415  			// Reserve a particular device
   416  			res, err := m.Reserve(c.in)
   417  			if !c.err {
   418  				r.NoError(err)
   419  				r.NotNil(res)
   420  
   421  				r.Len(res.Envs, 1)
   422  				r.Equal(res.Envs["DEVICES"], c.expected)
   423  			} else {
   424  				r.Error(err)
   425  			}
   426  		})
   427  	}
   428  }
   429  
   430  // Test that shutdown shutsdown the plugins
   431  func TestManager_Shutdown(t *testing.T) {
   432  	ci.Parallel(t)
   433  	require := require.New(t)
   434  
   435  	config, _, catalog := baseTestConfig(t)
   436  	nvidiaAndIntelDefaultPlugins(catalog)
   437  
   438  	m := New(config)
   439  	m.Run()
   440  	defer m.Shutdown()
   441  
   442  	// Wait till we get a fingerprint result
   443  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   444  	defer cancel()
   445  	<-m.WaitForFirstFingerprint(ctx)
   446  	require.NoError(ctx.Err())
   447  
   448  	// Call shutdown and assert that we killed the plugins
   449  	m.Shutdown()
   450  
   451  	for _, resp := range catalog.Catalog()[base.PluginTypeDevice] {
   452  		pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.AgentConfig{}, config.Logger)
   453  		require.True(pinst.Exited())
   454  	}
   455  }
   456  
   457  // Test that startup shutsdown previously launched plugins
   458  func TestManager_Run_ShutdownOld(t *testing.T) {
   459  	ci.Parallel(t)
   460  	require := require.New(t)
   461  
   462  	config, _, catalog := baseTestConfig(t)
   463  	nvidiaAndIntelDefaultPlugins(catalog)
   464  
   465  	m := New(config)
   466  	m.Run()
   467  	defer m.Shutdown()
   468  
   469  	// Wait till we get a fingerprint result
   470  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   471  	defer cancel()
   472  	<-m.WaitForFirstFingerprint(ctx)
   473  	require.NoError(ctx.Err())
   474  
   475  	// Create a new manager with the same config so that it reads the old state
   476  	m2 := New(config)
   477  	go m2.Run()
   478  	defer m2.Shutdown()
   479  
   480  	testutil.WaitForResult(func() (bool, error) {
   481  		for _, resp := range catalog.Catalog()[base.PluginTypeDevice] {
   482  			pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.AgentConfig{}, config.Logger)
   483  			if !pinst.Exited() {
   484  				return false, fmt.Errorf("plugin %q not shutdown", resp.Name)
   485  			}
   486  		}
   487  
   488  		return true, nil
   489  	}, func(err error) {
   490  		t.Fatal(err)
   491  	})
   492  }