github.com/zoomfoo/nomad@v0.8.5-0.20180907175415-f28fd3a1a056/plugins/base/plugin_test.go (about)

     1  package base
     2  
     3  import (
     4  	"testing"
     5  
     6  	pb "github.com/golang/protobuf/proto"
     7  	plugin "github.com/hashicorp/go-plugin"
     8  	"github.com/hashicorp/nomad/nomad/structs"
     9  	"github.com/hashicorp/nomad/plugins/shared/hclspec"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/zclconf/go-cty/cty"
    12  	"github.com/zclconf/go-cty/cty/msgpack"
    13  )
    14  
    15  func TestBasePlugin_PluginInfo_GRPC(t *testing.T) {
    16  	t.Parallel()
    17  	require := require.New(t)
    18  
    19  	const (
    20  		apiVersion    = "v0.1.0"
    21  		pluginVersion = "v0.2.1"
    22  		pluginName    = "mock"
    23  	)
    24  
    25  	knownType := func() (*PluginInfoResponse, error) {
    26  		info := &PluginInfoResponse{
    27  			Type:             PluginTypeDriver,
    28  			PluginApiVersion: apiVersion,
    29  			PluginVersion:    pluginVersion,
    30  			Name:             pluginName,
    31  		}
    32  		return info, nil
    33  	}
    34  	unknownType := func() (*PluginInfoResponse, error) {
    35  		info := &PluginInfoResponse{
    36  			Type:             "bad",
    37  			PluginApiVersion: apiVersion,
    38  			PluginVersion:    pluginVersion,
    39  			Name:             pluginName,
    40  		}
    41  		return info, nil
    42  	}
    43  
    44  	mock := &MockPlugin{
    45  		PluginInfoF: knownType,
    46  	}
    47  
    48  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
    49  		PluginTypeBase: &PluginBase{Impl: mock},
    50  	})
    51  	defer server.Stop()
    52  	defer client.Close()
    53  
    54  	raw, err := client.Dispense(PluginTypeBase)
    55  	if err != nil {
    56  		t.Fatalf("err: %s", err)
    57  	}
    58  
    59  	impl, ok := raw.(BasePlugin)
    60  	if !ok {
    61  		t.Fatalf("bad: %#v", raw)
    62  	}
    63  
    64  	resp, err := impl.PluginInfo()
    65  	require.NoError(err)
    66  	require.Equal(apiVersion, resp.PluginApiVersion)
    67  	require.Equal(pluginVersion, resp.PluginVersion)
    68  	require.Equal(pluginName, resp.Name)
    69  	require.Equal(PluginTypeDriver, resp.Type)
    70  
    71  	// Swap the implementation to return an unknown type
    72  	mock.PluginInfoF = unknownType
    73  	_, err = impl.PluginInfo()
    74  	require.Error(err)
    75  	require.Contains(err.Error(), "unknown type")
    76  }
    77  
    78  func TestBasePlugin_ConfigSchema(t *testing.T) {
    79  	t.Parallel()
    80  	require := require.New(t)
    81  
    82  	mock := &MockPlugin{
    83  		ConfigSchemaF: func() (*hclspec.Spec, error) {
    84  			return TestSpec, nil
    85  		},
    86  	}
    87  
    88  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
    89  		PluginTypeBase: &PluginBase{Impl: mock},
    90  	})
    91  	defer server.Stop()
    92  	defer client.Close()
    93  
    94  	raw, err := client.Dispense(PluginTypeBase)
    95  	if err != nil {
    96  		t.Fatalf("err: %s", err)
    97  	}
    98  
    99  	impl, ok := raw.(BasePlugin)
   100  	if !ok {
   101  		t.Fatalf("bad: %#v", raw)
   102  	}
   103  
   104  	specOut, err := impl.ConfigSchema()
   105  	require.NoError(err)
   106  	require.True(pb.Equal(TestSpec, specOut))
   107  }
   108  
   109  func TestBasePlugin_SetConfig(t *testing.T) {
   110  	t.Parallel()
   111  	require := require.New(t)
   112  
   113  	var receivedData []byte
   114  	mock := &MockPlugin{
   115  		ConfigSchemaF: func() (*hclspec.Spec, error) {
   116  			return TestSpec, nil
   117  		},
   118  		SetConfigF: func(data []byte) error {
   119  			receivedData = data
   120  			return nil
   121  		},
   122  	}
   123  
   124  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   125  		PluginTypeBase: &PluginBase{Impl: mock},
   126  	})
   127  	defer server.Stop()
   128  	defer client.Close()
   129  
   130  	raw, err := client.Dispense(PluginTypeBase)
   131  	if err != nil {
   132  		t.Fatalf("err: %s", err)
   133  	}
   134  
   135  	impl, ok := raw.(BasePlugin)
   136  	if !ok {
   137  		t.Fatalf("bad: %#v", raw)
   138  	}
   139  
   140  	config := cty.ObjectVal(map[string]cty.Value{
   141  		"foo": cty.StringVal("v1"),
   142  		"bar": cty.NumberIntVal(1337),
   143  		"baz": cty.BoolVal(true),
   144  	})
   145  	cdata, err := msgpack.Marshal(config, config.Type())
   146  	require.NoError(err)
   147  	require.NoError(impl.SetConfig(cdata))
   148  	require.Equal(cdata, receivedData)
   149  
   150  	// Decode the value back
   151  	var actual TestConfig
   152  	require.NoError(structs.Decode(receivedData, &actual))
   153  	require.Equal("v1", actual.Foo)
   154  	require.EqualValues(1337, actual.Bar)
   155  	require.True(actual.Baz)
   156  }