github.com/outbrain/consul@v1.4.5/agent/connect/ca/plugin/plugin_test.go (about) 1 package plugin 2 3 import ( 4 "crypto/x509" 5 "encoding/pem" 6 "errors" 7 "testing" 8 9 "github.com/hashicorp/consul/agent/connect" 10 "github.com/hashicorp/consul/agent/connect/ca" 11 "github.com/hashicorp/go-plugin" 12 "github.com/stretchr/testify/mock" 13 "github.com/stretchr/testify/require" 14 ) 15 16 func TestProvider_Configure(t *testing.T) { 17 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 18 require := require.New(t) 19 20 // Basic configure 21 m.On("Configure", "foo", false, map[string]interface{}{ 22 "string": "bar", 23 "number": float64(42), // because json 24 }).Once().Return(nil) 25 require.NoError(p.Configure("foo", false, map[string]interface{}{ 26 "string": "bar", 27 "number": float64(42), 28 })) 29 m.AssertExpectations(t) 30 31 // Try with an error 32 m.Mock = mock.Mock{} 33 m.On("Configure", "foo", false, map[string]interface{}{}).Once().Return(errors.New("hello world")) 34 err := p.Configure("foo", false, map[string]interface{}{}) 35 require.Error(err) 36 require.Contains(err.Error(), "hello") 37 m.AssertExpectations(t) 38 }) 39 } 40 41 func TestProvider_GenerateRoot(t *testing.T) { 42 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 43 require := require.New(t) 44 45 // Try cleanup with no error 46 m.On("GenerateRoot").Once().Return(nil) 47 require.NoError(p.GenerateRoot()) 48 m.AssertExpectations(t) 49 50 // Try with an error 51 m.Mock = mock.Mock{} 52 m.On("GenerateRoot").Once().Return(errors.New("hello world")) 53 err := p.GenerateRoot() 54 require.Error(err) 55 require.Contains(err.Error(), "hello") 56 m.AssertExpectations(t) 57 }) 58 } 59 60 func TestProvider_ActiveRoot(t *testing.T) { 61 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 62 require := require.New(t) 63 64 // Try cleanup with no error 65 m.On("ActiveRoot").Once().Return("foo", nil) 66 actual, err := p.ActiveRoot() 67 require.NoError(err) 68 require.Equal(actual, "foo") 69 m.AssertExpectations(t) 70 71 // Try with an error 72 m.Mock = mock.Mock{} 73 m.On("ActiveRoot").Once().Return("", errors.New("hello world")) 74 actual, err = p.ActiveRoot() 75 require.Error(err) 76 require.Contains(err.Error(), "hello") 77 m.AssertExpectations(t) 78 }) 79 } 80 81 func TestProvider_GenerateIntermediateCSR(t *testing.T) { 82 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 83 require := require.New(t) 84 85 // Try cleanup with no error 86 m.On("GenerateIntermediateCSR").Once().Return("foo", nil) 87 actual, err := p.GenerateIntermediateCSR() 88 require.NoError(err) 89 require.Equal(actual, "foo") 90 m.AssertExpectations(t) 91 92 // Try with an error 93 m.Mock = mock.Mock{} 94 m.On("GenerateIntermediateCSR").Once().Return("", errors.New("hello world")) 95 actual, err = p.GenerateIntermediateCSR() 96 require.Error(err) 97 require.Contains(err.Error(), "hello") 98 m.AssertExpectations(t) 99 }) 100 } 101 102 func TestProvider_SetIntermediate(t *testing.T) { 103 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 104 require := require.New(t) 105 106 // Try cleanup with no error 107 m.On("SetIntermediate", "foo", "bar").Once().Return(nil) 108 err := p.SetIntermediate("foo", "bar") 109 require.NoError(err) 110 m.AssertExpectations(t) 111 112 // Try with an error 113 m.Mock = mock.Mock{} 114 m.On("SetIntermediate", "foo", "bar").Once().Return(errors.New("hello world")) 115 err = p.SetIntermediate("foo", "bar") 116 require.Error(err) 117 require.Contains(err.Error(), "hello") 118 m.AssertExpectations(t) 119 }) 120 } 121 122 func TestProvider_ActiveIntermediate(t *testing.T) { 123 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 124 require := require.New(t) 125 126 // Try cleanup with no error 127 m.On("ActiveIntermediate").Once().Return("foo", nil) 128 actual, err := p.ActiveIntermediate() 129 require.NoError(err) 130 require.Equal(actual, "foo") 131 m.AssertExpectations(t) 132 133 // Try with an error 134 m.Mock = mock.Mock{} 135 m.On("ActiveIntermediate").Once().Return("", errors.New("hello world")) 136 actual, err = p.ActiveIntermediate() 137 require.Error(err) 138 require.Contains(err.Error(), "hello") 139 m.AssertExpectations(t) 140 }) 141 } 142 143 func TestProvider_GenerateIntermediate(t *testing.T) { 144 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 145 require := require.New(t) 146 147 // Try cleanup with no error 148 m.On("GenerateIntermediate").Once().Return("foo", nil) 149 actual, err := p.GenerateIntermediate() 150 require.NoError(err) 151 require.Equal(actual, "foo") 152 m.AssertExpectations(t) 153 154 // Try with an error 155 m.Mock = mock.Mock{} 156 m.On("GenerateIntermediate").Once().Return("", errors.New("hello world")) 157 actual, err = p.GenerateIntermediate() 158 require.Error(err) 159 require.Contains(err.Error(), "hello") 160 m.AssertExpectations(t) 161 }) 162 } 163 164 func TestProvider_Sign(t *testing.T) { 165 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 166 require := require.New(t) 167 168 // Create a CSR 169 csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "web")) 170 block, _ := pem.Decode([]byte(csrPEM)) 171 csr, err := x509.ParseCertificateRequest(block.Bytes) 172 require.NoError(err) 173 require.NoError(csr.CheckSignature()) 174 175 // No error 176 m.On("Sign", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) { 177 csr := args.Get(0).(*x509.CertificateRequest) 178 require.NoError(csr.CheckSignature()) 179 }) 180 actual, err := p.Sign(csr) 181 require.NoError(err) 182 require.Equal(actual, "foo") 183 m.AssertExpectations(t) 184 185 // Try with an error 186 m.Mock = mock.Mock{} 187 m.On("Sign", mock.Anything).Once().Return("", errors.New("hello world")) 188 actual, err = p.Sign(csr) 189 require.Error(err) 190 require.Contains(err.Error(), "hello") 191 m.AssertExpectations(t) 192 }) 193 } 194 195 func TestProvider_SignIntermediate(t *testing.T) { 196 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 197 require := require.New(t) 198 199 // Create a CSR 200 csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "web")) 201 block, _ := pem.Decode([]byte(csrPEM)) 202 csr, err := x509.ParseCertificateRequest(block.Bytes) 203 require.NoError(err) 204 require.NoError(csr.CheckSignature()) 205 206 // No error 207 m.On("SignIntermediate", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) { 208 csr := args.Get(0).(*x509.CertificateRequest) 209 require.NoError(csr.CheckSignature()) 210 }) 211 actual, err := p.SignIntermediate(csr) 212 require.NoError(err) 213 require.Equal(actual, "foo") 214 m.AssertExpectations(t) 215 216 // Try with an error 217 m.Mock = mock.Mock{} 218 m.On("SignIntermediate", mock.Anything).Once().Return("", errors.New("hello world")) 219 actual, err = p.SignIntermediate(csr) 220 require.Error(err) 221 require.Contains(err.Error(), "hello") 222 m.AssertExpectations(t) 223 }) 224 } 225 226 func TestProvider_CrossSignCA(t *testing.T) { 227 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 228 require := require.New(t) 229 230 // Create a CSR 231 root := connect.TestCA(t, nil) 232 block, _ := pem.Decode([]byte(root.RootCert)) 233 crt, err := x509.ParseCertificate(block.Bytes) 234 require.NoError(err) 235 236 // No error 237 m.On("CrossSignCA", mock.Anything).Once().Return("foo", nil).Run(func(args mock.Arguments) { 238 actual := args.Get(0).(*x509.Certificate) 239 require.True(crt.Equal(actual)) 240 }) 241 actual, err := p.CrossSignCA(crt) 242 require.NoError(err) 243 require.Equal(actual, "foo") 244 m.AssertExpectations(t) 245 246 // Try with an error 247 m.Mock = mock.Mock{} 248 m.On("CrossSignCA", mock.Anything).Once().Return("", errors.New("hello world")) 249 actual, err = p.CrossSignCA(crt) 250 require.Error(err) 251 require.Contains(err.Error(), "hello") 252 m.AssertExpectations(t) 253 }) 254 } 255 256 func TestProvider_Cleanup(t *testing.T) { 257 testPlugin(t, func(t *testing.T, m *ca.MockProvider, p ca.Provider) { 258 require := require.New(t) 259 260 // Try cleanup with no error 261 m.On("Cleanup").Once().Return(nil) 262 require.NoError(p.Cleanup()) 263 m.AssertExpectations(t) 264 265 // Try with an error 266 m.Mock = mock.Mock{} 267 m.On("Cleanup").Once().Return(errors.New("hello world")) 268 err := p.Cleanup() 269 require.Error(err) 270 require.Contains(err.Error(), "hello") 271 m.AssertExpectations(t) 272 }) 273 } 274 275 // testPlugin runs the given test function callback for all supported 276 // transports of the plugin RPC layer. 277 func testPlugin(t *testing.T, f func(t *testing.T, m *ca.MockProvider, actual ca.Provider)) { 278 t.Run("net/rpc", func(t *testing.T) { 279 // Create a mock provider 280 mockP := new(ca.MockProvider) 281 client, _ := plugin.TestPluginRPCConn(t, map[string]plugin.Plugin{ 282 Name: &ProviderPlugin{Impl: mockP}, 283 }, nil) 284 defer client.Close() 285 286 // Request the provider 287 raw, err := client.Dispense(Name) 288 require.NoError(t, err) 289 provider := raw.(ca.Provider) 290 291 // Call the test function 292 f(t, mockP, provider) 293 }) 294 295 t.Run("gRPC", func(t *testing.T) { 296 // Create a mock provider 297 mockP := new(ca.MockProvider) 298 client, _ := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ 299 Name: &ProviderPlugin{Impl: mockP}, 300 }) 301 defer client.Close() 302 303 // Request the provider 304 raw, err := client.Dispense(Name) 305 require.NoError(t, err) 306 provider := raw.(ca.Provider) 307 308 // Call the test function 309 f(t, mockP, provider) 310 }) 311 }