github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/device_hook_test.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/hashicorp/nomad/ci"
     9  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    10  	"github.com/hashicorp/nomad/client/devicemanager"
    11  	"github.com/hashicorp/nomad/helper/testlog"
    12  	"github.com/hashicorp/nomad/nomad/structs"
    13  	"github.com/hashicorp/nomad/plugins/device"
    14  	"github.com/hashicorp/nomad/plugins/drivers"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestDeviceHook_CorrectDevice(t *testing.T) {
    19  	ci.Parallel(t)
    20  	require := require.New(t)
    21  
    22  	dm := devicemanager.NoopMockManager()
    23  	l := testlog.HCLogger(t)
    24  	h := newDeviceHook(dm, l)
    25  
    26  	reqDev := &structs.AllocatedDeviceResource{
    27  		Vendor:    "foo",
    28  		Type:      "bar",
    29  		Name:      "baz",
    30  		DeviceIDs: []string{"123"},
    31  	}
    32  
    33  	// Build the hook request
    34  	req := &interfaces.TaskPrestartRequest{
    35  		TaskResources: &structs.AllocatedTaskResources{
    36  			Devices: []*structs.AllocatedDeviceResource{
    37  				reqDev,
    38  			},
    39  		},
    40  	}
    41  
    42  	// Setup the device manager to return a response
    43  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
    44  		if d.Vendor != reqDev.Vendor || d.Type != reqDev.Type ||
    45  			d.Name != reqDev.Name || len(d.DeviceIDs) != 1 || d.DeviceIDs[0] != reqDev.DeviceIDs[0] {
    46  			return nil, fmt.Errorf("unexpected request: %+v", d)
    47  		}
    48  
    49  		res := &device.ContainerReservation{
    50  			Envs: map[string]string{
    51  				"123": "456",
    52  			},
    53  			Mounts: []*device.Mount{
    54  				{
    55  					ReadOnly: true,
    56  					TaskPath: "foo",
    57  					HostPath: "bar",
    58  				},
    59  			},
    60  			Devices: []*device.DeviceSpec{
    61  				{
    62  					TaskPath:    "foo",
    63  					HostPath:    "bar",
    64  					CgroupPerms: "123",
    65  				},
    66  			},
    67  		}
    68  		return res, nil
    69  	}
    70  
    71  	var resp interfaces.TaskPrestartResponse
    72  	err := h.Prestart(context.Background(), req, &resp)
    73  	require.NoError(err)
    74  	require.NotNil(resp)
    75  
    76  	expEnv := map[string]string{
    77  		"123": "456",
    78  	}
    79  	require.EqualValues(expEnv, resp.Env)
    80  
    81  	expMounts := []*drivers.MountConfig{
    82  		{
    83  			Readonly: true,
    84  			TaskPath: "foo",
    85  			HostPath: "bar",
    86  		},
    87  	}
    88  	require.EqualValues(expMounts, resp.Mounts)
    89  
    90  	expDevices := []*drivers.DeviceConfig{
    91  		{
    92  			TaskPath:    "foo",
    93  			HostPath:    "bar",
    94  			Permissions: "123",
    95  		},
    96  	}
    97  	require.EqualValues(expDevices, resp.Devices)
    98  }
    99  
   100  func TestDeviceHook_IncorrectDevice(t *testing.T) {
   101  	ci.Parallel(t)
   102  	require := require.New(t)
   103  
   104  	dm := devicemanager.NoopMockManager()
   105  	l := testlog.HCLogger(t)
   106  	h := newDeviceHook(dm, l)
   107  
   108  	reqDev := &structs.AllocatedDeviceResource{
   109  		Vendor:    "foo",
   110  		Type:      "bar",
   111  		Name:      "baz",
   112  		DeviceIDs: []string{"123"},
   113  	}
   114  
   115  	// Build the hook request
   116  	req := &interfaces.TaskPrestartRequest{
   117  		TaskResources: &structs.AllocatedTaskResources{
   118  			Devices: []*structs.AllocatedDeviceResource{
   119  				reqDev,
   120  			},
   121  		},
   122  	}
   123  
   124  	// Setup the device manager to return a response
   125  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
   126  		return nil, fmt.Errorf("bad request")
   127  	}
   128  
   129  	var resp interfaces.TaskPrestartResponse
   130  	err := h.Prestart(context.Background(), req, &resp)
   131  	require.Error(err)
   132  }