github.com/outbrain/consul@v1.4.5/agent/connect/ca/plugin/plugin_test.go (about)

     1  package plugin
     2  
     3  import (
     4  	"crypto/x509"
     5  	"encoding/pem"
     6  	"errors"
     7  	"testing"
     8  
     9  	"github.com/hashicorp/consul/agent/connect"
    10  	"github.com/hashicorp/consul/agent/connect/ca"
    11  	"github.com/hashicorp/go-plugin"
    12  	"github.com/stretchr/testify/mock"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func TestProvider_Configure(t *testing.T) {
    17  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
    18  		require := require.New(t)
    19  
    20  		// Basic configure
    21  		m.On("Configure", "foo", false, map[string]interface{}{
    22  			"string": "bar",
    23  			"number": float64(42), // because json
    24  		}).Once().Return(nil)
    25  		require.NoError(p.Configure("foo", false, map[string]interface{}{
    26  			"string": "bar",
    27  			"number": float64(42),
    28  		}))
    29  		m.AssertExpectations(t)
    30  
    31  		// Try with an error
    32  		m.Mock = mock.Mock{}
    33  		m.On("Configure", "foo", false, map[string]interface{}{}).Once().Return(errors.New("hello world"))
    34  		err := p.Configure("foo", false, map[string]interface{}{})
    35  		require.Error(err)
    36  		require.Contains(err.Error(), "hello")
    37  		m.AssertExpectations(t)
    38  	})
    39  }
    40  
    41  func TestProvider_GenerateRoot(t *testing.T) {
    42  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
    43  		require := require.New(t)
    44  
    45  		// Try cleanup with no error
    46  		m.On("GenerateRoot").Once().Return(nil)
    47  		require.NoError(p.GenerateRoot())
    48  		m.AssertExpectations(t)
    49  
    50  		// Try with an error
    51  		m.Mock = mock.Mock{}
    52  		m.On("GenerateRoot").Once().Return(errors.New("hello world"))
    53  		err := p.GenerateRoot()
    54  		require.Error(err)
    55  		require.Contains(err.Error(), "hello")
    56  		m.AssertExpectations(t)
    57  	})
    58  }
    59  
    60  func TestProvider_ActiveRoot(t *testing.T) {
    61  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
    62  		require := require.New(t)
    63  
    64  		// Try cleanup with no error
    65  		m.On("ActiveRoot").Once().Return("foo", nil)
    66  		actual, err := p.ActiveRoot()
    67  		require.NoError(err)
    68  		require.Equal(actual, "foo")
    69  		m.AssertExpectations(t)
    70  
    71  		// Try with an error
    72  		m.Mock = mock.Mock{}
    73  		m.On("ActiveRoot").Once().Return("", errors.New("hello world"))
    74  		actual, err = p.ActiveRoot()
    75  		require.Error(err)
    76  		require.Contains(err.Error(), "hello")
    77  		m.AssertExpectations(t)
    78  	})
    79  }
    80  
    81  func TestProvider_GenerateIntermediateCSR(t *testing.T) {
    82  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
    83  		require := require.New(t)
    84  
    85  		// Try cleanup with no error
    86  		m.On("GenerateIntermediateCSR").Once().Return("foo", nil)
    87  		actual, err := p.GenerateIntermediateCSR()
    88  		require.NoError(err)
    89  		require.Equal(actual, "foo")
    90  		m.AssertExpectations(t)
    91  
    92  		// Try with an error
    93  		m.Mock = mock.Mock{}
    94  		m.On("GenerateIntermediateCSR").Once().Return("", errors.New("hello world"))
    95  		actual, err = p.GenerateIntermediateCSR()
    96  		require.Error(err)
    97  		require.Contains(err.Error(), "hello")
    98  		m.AssertExpectations(t)
    99  	})
   100  }
   101  
   102  func TestProvider_SetIntermediate(t *testing.T) {
   103  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   104  		require := require.New(t)
   105  
   106  		// Try cleanup with no error
   107  		m.On("SetIntermediate", "foo", "bar").Once().Return(nil)
   108  		err := p.SetIntermediate("foo", "bar")
   109  		require.NoError(err)
   110  		m.AssertExpectations(t)
   111  
   112  		// Try with an error
   113  		m.Mock = mock.Mock{}
   114  		m.On("SetIntermediate", "foo", "bar").Once().Return(errors.New("hello world"))
   115  		err = p.SetIntermediate("foo", "bar")
   116  		require.Error(err)
   117  		require.Contains(err.Error(), "hello")
   118  		m.AssertExpectations(t)
   119  	})
   120  }
   121  
   122  func TestProvider_ActiveIntermediate(t *testing.T) {
   123  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   124  		require := require.New(t)
   125  
   126  		// Try cleanup with no error
   127  		m.On("ActiveIntermediate").Once().Return("foo", nil)
   128  		actual, err := p.ActiveIntermediate()
   129  		require.NoError(err)
   130  		require.Equal(actual, "foo")
   131  		m.AssertExpectations(t)
   132  
   133  		// Try with an error
   134  		m.Mock = mock.Mock{}
   135  		m.On("ActiveIntermediate").Once().Return("", errors.New("hello world"))
   136  		actual, err = p.ActiveIntermediate()
   137  		require.Error(err)
   138  		require.Contains(err.Error(), "hello")
   139  		m.AssertExpectations(t)
   140  	})
   141  }
   142  
   143  func TestProvider_GenerateIntermediate(t *testing.T) {
   144  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   145  		require := require.New(t)
   146  
   147  		// Try cleanup with no error
   148  		m.On("GenerateIntermediate").Once().Return("foo", nil)
   149  		actual, err := p.GenerateIntermediate()
   150  		require.NoError(err)
   151  		require.Equal(actual, "foo")
   152  		m.AssertExpectations(t)
   153  
   154  		// Try with an error
   155  		m.Mock = mock.Mock{}
   156  		m.On("GenerateIntermediate").Once().Return("", errors.New("hello world"))
   157  		actual, err = p.GenerateIntermediate()
   158  		require.Error(err)
   159  		require.Contains(err.Error(), "hello")
   160  		m.AssertExpectations(t)
   161  	})
   162  }
   163  
   164  func TestProvider_Sign(t *testing.T) {
   165  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   166  		require := require.New(t)
   167  
   168  		// Create a CSR
   169  		csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "web"))
   170  		block, _ := pem.Decode([]byte(csrPEM))
   171  		csr, err := x509.ParseCertificateRequest(block.Bytes)
   172  		require.NoError(err)
   173  		require.NoError(csr.CheckSignature())
   174  
   175  		// No error
   176  		m.On("Sign", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) {
   177  			csr := args.Get(0).(*x509.CertificateRequest)
   178  			require.NoError(csr.CheckSignature())
   179  		})
   180  		actual, err := p.Sign(csr)
   181  		require.NoError(err)
   182  		require.Equal(actual, "foo")
   183  		m.AssertExpectations(t)
   184  
   185  		// Try with an error
   186  		m.Mock = mock.Mock{}
   187  		m.On("Sign", mock.Anything).Once().Return("", errors.New("hello world"))
   188  		actual, err = p.Sign(csr)
   189  		require.Error(err)
   190  		require.Contains(err.Error(), "hello")
   191  		m.AssertExpectations(t)
   192  	})
   193  }
   194  
   195  func TestProvider_SignIntermediate(t *testing.T) {
   196  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   197  		require := require.New(t)
   198  
   199  		// Create a CSR
   200  		csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "web"))
   201  		block, _ := pem.Decode([]byte(csrPEM))
   202  		csr, err := x509.ParseCertificateRequest(block.Bytes)
   203  		require.NoError(err)
   204  		require.NoError(csr.CheckSignature())
   205  
   206  		// No error
   207  		m.On("SignIntermediate", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) {
   208  			csr := args.Get(0).(*x509.CertificateRequest)
   209  			require.NoError(csr.CheckSignature())
   210  		})
   211  		actual, err := p.SignIntermediate(csr)
   212  		require.NoError(err)
   213  		require.Equal(actual, "foo")
   214  		m.AssertExpectations(t)
   215  
   216  		// Try with an error
   217  		m.Mock = mock.Mock{}
   218  		m.On("SignIntermediate", mock.Anything).Once().Return("", errors.New("hello world"))
   219  		actual, err = p.SignIntermediate(csr)
   220  		require.Error(err)
   221  		require.Contains(err.Error(), "hello")
   222  		m.AssertExpectations(t)
   223  	})
   224  }
   225  
   226  func TestProvider_CrossSignCA(t *testing.T) {
   227  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   228  		require := require.New(t)
   229  
   230  		// Create a CSR
   231  		root := connect.TestCA(t, nil)
   232  		block, _ := pem.Decode([]byte(root.RootCert))
   233  		crt, err := x509.ParseCertificate(block.Bytes)
   234  		require.NoError(err)
   235  
   236  		// No error
   237  		m.On("CrossSignCA", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) {
   238  			actual := args.Get(0).(*x509.Certificate)
   239  			require.True(crt.Equal(actual))
   240  		})
   241  		actual, err := p.CrossSignCA(crt)
   242  		require.NoError(err)
   243  		require.Equal(actual, "foo")
   244  		m.AssertExpectations(t)
   245  
   246  		// Try with an error
   247  		m.Mock = mock.Mock{}
   248  		m.On("CrossSignCA", mock.Anything).Once().Return("", errors.New("hello world"))
   249  		actual, err = p.CrossSignCA(crt)
   250  		require.Error(err)
   251  		require.Contains(err.Error(), "hello")
   252  		m.AssertExpectations(t)
   253  	})
   254  }
   255  
   256  func TestProvider_Cleanup(t *testing.T) {
   257  	testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) {
   258  		require := require.New(t)
   259  
   260  		// Try cleanup with no error
   261  		m.On("Cleanup").Once().Return(nil)
   262  		require.NoError(p.Cleanup())
   263  		m.AssertExpectations(t)
   264  
   265  		// Try with an error
   266  		m.Mock = mock.Mock{}
   267  		m.On("Cleanup").Once().Return(errors.New("hello world"))
   268  		err := p.Cleanup()
   269  		require.Error(err)
   270  		require.Contains(err.Error(), "hello")
   271  		m.AssertExpectations(t)
   272  	})
   273  }
   274  
   275  // testPlugin runs the given test function callback for all supported
   276  // transports of the plugin RPC layer.
   277  func testPlugin(t *testing.T, f func(t *testing.T, m *ca.MockProvider, actual ca.Provider)) {
   278  	t.Run("net/rpc", func(t *testing.T) {
   279  		// Create a mock provider
   280  		mockP := new(ca.MockProvider)
   281  		client, _ := plugin.TestPluginRPCConn(t, map[string]plugin.Plugin{
   282  			Name: &ProviderPlugin{Impl: mockP},
   283  		}, nil)
   284  		defer client.Close()
   285  
   286  		// Request the provider
   287  		raw, err := client.Dispense(Name)
   288  		require.NoError(t, err)
   289  		provider := raw.(ca.Provider)
   290  
   291  		// Call the test function
   292  		f(t, mockP, provider)
   293  	})
   294  
   295  	t.Run("gRPC", func(t *testing.T) {
   296  		// Create a mock provider
   297  		mockP := new(ca.MockProvider)
   298  		client, _ := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   299  			Name: &ProviderPlugin{Impl: mockP},
   300  		})
   301  		defer client.Close()
   302  
   303  		// Request the provider
   304  		raw, err := client.Dispense(Name)
   305  		require.NoError(t, err)
   306  		provider := raw.(ca.Provider)
   307  
   308  		// Call the test function
   309  		f(t, mockP, provider)
   310  	})
   311  }