github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/security/puppetsec/puppet_security_test.go (about) 1 // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package puppetsec 6 7 import ( 8 "crypto/x509" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "os" 13 "path/filepath" 14 "runtime" 15 "testing" 16 17 "github.com/choria-io/go-choria/build" 18 "github.com/choria-io/go-choria/config" 19 "github.com/choria-io/go-choria/inter" 20 "github.com/choria-io/go-choria/srvcache" 21 "github.com/sirupsen/logrus" 22 23 "github.com/golang/mock/gomock" 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 ) 27 28 func TestPuppetSecurity(t *testing.T) { 29 RegisterFailHandler(Fail) 30 RunSpecs(t, "Providers/Security/Puppet") 31 } 32 33 var _ = Describe("PuppetSSL", func() { 34 var mockctl *gomock.Controller 35 var resolver *MockResolver 36 var cfg *Config 37 var err error 38 var prov *PuppetSecurity 39 var l *logrus.Logger 40 41 BeforeEach(func() { 42 mockctl = gomock.NewController(GinkgoT()) 43 resolver = NewMockResolver(mockctl) 44 os.Setenv("MCOLLECTIVE_CERTNAME", "rip.mcollective") 45 46 cfg = &Config{ 47 SSLDir: filepath.Join("..", "testdata", "good"), 48 Identity: "rip.mcollective", 49 PuppetCAHost: "puppet", 50 PuppetCAPort: 8140, 51 DisableSRV: true, 52 useFakeUID: true, 53 fakeUID: 500, 54 } 55 56 l = logrus.New() 57 l.SetOutput(GinkgoWriter) 58 59 prov, err = New(WithConfig(cfg), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 60 Expect(err).ToNot(HaveOccurred()) 61 }) 62 63 AfterEach(func() { 64 mockctl.Finish() 65 }) 66 67 It("Should implement the provider interface", func() { 68 f := func(p inter.SecurityProvider) {} 69 f(prov) 70 Expect(prov.Provider()).To(Equal("puppet")) 71 }) 72 73 Describe("WithChoriaConfig", func() { 74 It("Should disable SRV when the CA is configured", func() { 75 c, err := config.NewConfig(filepath.Join("..", "testdata", "puppetca.cfg")) 76 Expect(err).ToNot(HaveOccurred()) 77 78 prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 79 Expect(err).ToNot(HaveOccurred()) 80 81 Expect(prov.conf.DisableSRV).To(BeTrue()) 82 }) 83 84 It("Should support OverrideCertname", func() { 85 c := config.NewConfigForTests() 86 87 c.OverrideCertname = "override.choria" 88 prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 89 Expect(err).ToNot(HaveOccurred()) 90 91 Expect(prov.conf.Identity).To(Equal("override.choria")) 92 }) 93 94 // TODO: windows 95 if runtime.GOOS != "windows" { 96 It("Should fail when it cannot determine user identity", func() { 97 c := config.NewConfigForTests() 98 c.OverrideCertname = "" 99 v := os.Getenv("USER") 100 defer os.Setenv("USER", v) 101 os.Unsetenv("USER") 102 os.Unsetenv("MCOLLECTIVE_CERTNAME") 103 _, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 104 Expect(err).To(MatchError("could not determine client identity, ensure USER environment variable is set")) 105 }) 106 107 It("Should use the user SSL directory when not configured", func() { 108 c, err := config.NewDefaultConfig() 109 Expect(err).ToNot(HaveOccurred()) 110 111 prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 112 Expect(err).ToNot(HaveOccurred()) 113 114 d, err := userSSlDir() 115 Expect(err).ToNot(HaveOccurred()) 116 117 Expect(prov.conf.SSLDir).To(Equal(d)) 118 }) 119 } 120 121 It("Should copy all the relevant settings", func() { 122 c, err := config.NewDefaultConfig() 123 Expect(err).ToNot(HaveOccurred()) 124 125 c.DisableTLSVerify = true 126 c.Choria.SSLDir = "/stub" 127 c.Choria.PuppetCAHost = "stubhost" 128 c.Choria.PuppetCAPort = 8080 129 130 prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 131 Expect(err).ToNot(HaveOccurred()) 132 133 Expect(prov.conf.AllowList).To(Equal([]string{"\\.mcollective$", "\\.choria$"})) 134 Expect(prov.conf.PrivilegedUsers).To(Equal([]string{"\\.privileged.mcollective$", "\\.privileged.choria$"})) 135 Expect(prov.conf.DisableTLSVerify).To(BeTrue()) 136 Expect(prov.conf.SSLDir).To(Equal("/stub")) 137 Expect(prov.conf.PuppetCAHost).To(Equal("stubhost")) 138 Expect(prov.conf.PuppetCAPort).To(Equal(8080)) 139 }) 140 }) 141 142 Describe("Validate", func() { 143 It("Should handle missing files", func() { 144 cfg.SSLDir = filepath.Join("testdata", "allmissing") 145 cfg.Identity = "test.mcollective" 146 prov, err = New(WithConfig(cfg), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{}))) 147 148 Expect(err).ToNot(HaveOccurred()) 149 150 errs, ok := prov.Validate() 151 152 Expect(ok).To(BeFalse()) 153 Expect(errs).To(HaveLen(3)) 154 Expect(errs[0]).To(Equal(fmt.Sprintf("public certificate %s does not exist", filepath.Join(cfg.SSLDir, "certs", "test.mcollective.pem")))) 155 Expect(errs[1]).To(Equal(fmt.Sprintf("private key %s does not exist", filepath.Join(cfg.SSLDir, "private_keys", "test.mcollective.pem")))) 156 Expect(errs[2]).To(Equal(fmt.Sprintf("CA %s does not exist", filepath.Join(cfg.SSLDir, "certs", "ca.pem")))) 157 }) 158 159 It("Should accept valid directories", func() { 160 cfg.Identity = "rip.mcollective" 161 errs, ok := prov.Validate() 162 Expect(errs).To(BeEmpty()) 163 Expect(ok).To(BeTrue()) 164 }) 165 }) 166 167 Describe("Identity", func() { 168 It("Should support OverrideCertname", func() { 169 cfg.Identity = "bob.choria" 170 prov.reinit() 171 172 Expect(prov.Identity()).To(Equal("bob.choria")) 173 }) 174 }) 175 176 Describe("writeCSR", func() { 177 It("should not write over existing CSRs", func() { 178 cfg.Identity = "na.mcollective" 179 prov.reinit() 180 181 kpath := prov.privateKeyPath() 182 csrpath := prov.csrPath() 183 184 defer os.Remove(kpath) 185 defer os.Remove(csrpath) 186 187 key, err := prov.writePrivateKey() 188 Expect(err).ToNot(HaveOccurred()) 189 190 prov.conf.Identity = "rip.mcollective" 191 prov.reinit() 192 _, err = prov.writeCSR(key, "rip.mcollective", "choria.io") 193 194 Expect(err).To(MatchError("a certificate request already exist for rip.mcollective")) 195 }) 196 197 It("Should create a valid CSR", func() { 198 prov.conf.Identity = "na.mcollective" 199 prov.reinit() 200 201 kpath := prov.privateKeyPath() 202 csrpath := prov.csrPath() 203 204 defer os.Remove(kpath) 205 defer os.Remove(csrpath) 206 207 key, err := prov.writePrivateKey() 208 Expect(err).ToNot(HaveOccurred()) 209 210 _, err = prov.writeCSR(key, "na.mcollective", "choria.io") 211 Expect(err).ToNot(HaveOccurred()) 212 213 csrpem, err := os.ReadFile(csrpath) 214 Expect(err).ToNot(HaveOccurred()) 215 216 pb, _ := pem.Decode(csrpem) 217 218 req, err := x509.ParseCertificateRequest(pb.Bytes) 219 Expect(err).ToNot(HaveOccurred()) 220 Expect(req.Subject.CommonName).To(Equal("na.mcollective")) 221 Expect(req.Subject.OrganizationalUnit).To(Equal([]string{"choria.io"})) 222 }) 223 }) 224 225 Describe("writePrivateKey", func() { 226 It("Should not write over existing private keys", func() { 227 cfg.Identity = "rip.mcollective" 228 key, err := prov.writePrivateKey() 229 Expect(err).To(MatchError("a private key already exist for rip.mcollective")) 230 Expect(key).To(BeNil()) 231 }) 232 233 It("Should create new keys", func() { 234 cfg.Identity = "na.mcollective" 235 prov.reinit() 236 237 path := prov.privateKeyPath() 238 defer os.Remove(path) 239 240 key, err := prov.writePrivateKey() 241 Expect(err).ToNot(HaveOccurred()) 242 Expect(key).ToNot(BeNil()) 243 Expect(path).To(BeAnExistingFile()) 244 }) 245 }) 246 247 Describe("csrExists", func() { 248 It("Should detect existing keys", func() { 249 cfg.Identity = "rip.mcollective" 250 prov.reinit() 251 252 Expect(prov.csrExists()).To(BeTrue()) 253 }) 254 255 It("Should detect absent keys", func() { 256 cfg.Identity = "na.mcollective" 257 prov.reinit() 258 259 Expect(prov.csrExists()).To(BeFalse()) 260 }) 261 }) 262 263 Describe("puppetCA", func() { 264 It("Should use supplied config when SRV is disabled", func() { 265 cfg.DisableSRV = true 266 s := prov.puppetCA() 267 Expect(s.Host()).To(Equal("puppet")) 268 Expect(s.Port()).To(Equal(uint16(8140))) 269 Expect(s.Scheme()).To(Equal("https")) 270 }) 271 272 It("Should use supplied config when no srv resolver is given", func() { 273 prov, err = New(WithConfig(cfg), WithLog(l.WithFields(logrus.Fields{}))) 274 Expect(err).ToNot(HaveOccurred()) 275 276 resolver.EXPECT().QuerySrvRecords(gomock.Any()).Times(0) 277 278 s := prov.puppetCA() 279 Expect(s.Host()).To(Equal("puppet")) 280 Expect(s.Port()).To(Equal(uint16(8140))) 281 Expect(s.Scheme()).To(Equal("https")) 282 }) 283 284 It("Should return defaults when SRV fails", func() { 285 resolver.EXPECT().QuerySrvRecords([]string{"_x-puppet-ca._tcp", "_x-puppet._tcp"}).Return(srvcache.NewServers(), errors.New("simulated error")) 286 287 cfg.DisableSRV = false 288 s := prov.puppetCA() 289 Expect(s.Host()).To(Equal("puppet")) 290 Expect(s.Port()).To(Equal(uint16(8140))) 291 Expect(s.Scheme()).To(Equal("https")) 292 }) 293 294 It("Should use SRV records", func() { 295 ans := srvcache.NewServers( 296 srvcache.NewServer("p1", 8080, "http"), 297 srvcache.NewServer("p2", 8081, "https"), 298 ) 299 300 resolver.EXPECT().QuerySrvRecords([]string{"_x-puppet-ca._tcp", "_x-puppet._tcp"}).Return(ans, nil) 301 cfg.DisableSRV = false 302 303 s := prov.puppetCA() 304 Expect(s.Host()).To(Equal("p1")) 305 Expect(s.Port()).To(Equal(uint16(8080))) 306 Expect(s.Scheme()).To(Equal("http")) 307 308 }) 309 }) 310 })