github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/device_hook_test.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
     9  	"github.com/hashicorp/nomad/client/devicemanager"
    10  	"github.com/hashicorp/nomad/helper/testlog"
    11  	"github.com/hashicorp/nomad/nomad/structs"
    12  	"github.com/hashicorp/nomad/plugins/device"
    13  	"github.com/hashicorp/nomad/plugins/drivers"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestDeviceHook_CorrectDevice(t *testing.T) {
    18  	t.Parallel()
    19  	require := require.New(t)
    20  
    21  	dm := devicemanager.NoopMockManager()
    22  	l := testlog.HCLogger(t)
    23  	h := newDeviceHook(dm, l)
    24  
    25  	reqDev := &structs.AllocatedDeviceResource{
    26  		Vendor:    "foo",
    27  		Type:      "bar",
    28  		Name:      "baz",
    29  		DeviceIDs: []string{"123"},
    30  	}
    31  
    32  	// Build the hook request
    33  	req := &interfaces.TaskPrestartRequest{
    34  		TaskResources: &structs.AllocatedTaskResources{
    35  			Devices: []*structs.AllocatedDeviceResource{
    36  				reqDev,
    37  			},
    38  		},
    39  	}
    40  
    41  	// Setup the device manager to return a response
    42  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
    43  		if d.Vendor != reqDev.Vendor || d.Type != reqDev.Type ||
    44  			d.Name != reqDev.Name || len(d.DeviceIDs) != 1 || d.DeviceIDs[0] != reqDev.DeviceIDs[0] {
    45  			return nil, fmt.Errorf("unexpected request: %+v", d)
    46  		}
    47  
    48  		res := &device.ContainerReservation{
    49  			Envs: map[string]string{
    50  				"123": "456",
    51  			},
    52  			Mounts: []*device.Mount{
    53  				{
    54  					ReadOnly: true,
    55  					TaskPath: "foo",
    56  					HostPath: "bar",
    57  				},
    58  			},
    59  			Devices: []*device.DeviceSpec{
    60  				{
    61  					TaskPath:    "foo",
    62  					HostPath:    "bar",
    63  					CgroupPerms: "123",
    64  				},
    65  			},
    66  		}
    67  		return res, nil
    68  	}
    69  
    70  	var resp interfaces.TaskPrestartResponse
    71  	err := h.Prestart(context.Background(), req, &resp)
    72  	require.NoError(err)
    73  	require.NotNil(resp)
    74  
    75  	expEnv := map[string]string{
    76  		"123": "456",
    77  	}
    78  	require.EqualValues(expEnv, resp.Env)
    79  
    80  	expMounts := []*drivers.MountConfig{
    81  		{
    82  			Readonly: true,
    83  			TaskPath: "foo",
    84  			HostPath: "bar",
    85  		},
    86  	}
    87  	require.EqualValues(expMounts, resp.Mounts)
    88  
    89  	expDevices := []*drivers.DeviceConfig{
    90  		{
    91  			TaskPath:    "foo",
    92  			HostPath:    "bar",
    93  			Permissions: "123",
    94  		},
    95  	}
    96  	require.EqualValues(expDevices, resp.Devices)
    97  }
    98  
    99  func TestDeviceHook_IncorrectDevice(t *testing.T) {
   100  	t.Parallel()
   101  	require := require.New(t)
   102  
   103  	dm := devicemanager.NoopMockManager()
   104  	l := testlog.HCLogger(t)
   105  	h := newDeviceHook(dm, l)
   106  
   107  	reqDev := &structs.AllocatedDeviceResource{
   108  		Vendor:    "foo",
   109  		Type:      "bar",
   110  		Name:      "baz",
   111  		DeviceIDs: []string{"123"},
   112  	}
   113  
   114  	// Build the hook request
   115  	req := &interfaces.TaskPrestartRequest{
   116  		TaskResources: &structs.AllocatedTaskResources{
   117  			Devices: []*structs.AllocatedDeviceResource{
   118  				reqDev,
   119  			},
   120  		},
   121  	}
   122  
   123  	// Setup the device manager to return a response
   124  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
   125  		return nil, fmt.Errorf("bad request")
   126  	}
   127  
   128  	var resp interfaces.TaskPrestartResponse
   129  	err := h.Prestart(context.Background(), req, &resp)
   130  	require.Error(err)
   131  }