github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/security/puppetsec/puppet_security_test.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package puppetsec
     6  
     7  import (
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"os"
    13  	"path/filepath"
    14  	"runtime"
    15  	"testing"
    16  
    17  	"github.com/choria-io/go-choria/build"
    18  	"github.com/choria-io/go-choria/config"
    19  	"github.com/choria-io/go-choria/inter"
    20  	"github.com/choria-io/go-choria/srvcache"
    21  	"github.com/sirupsen/logrus"
    22  
    23  	"github.com/golang/mock/gomock"
    24  	. "github.com/onsi/ginkgo/v2"
    25  	. "github.com/onsi/gomega"
    26  )
    27  
    28  func TestPuppetSecurity(t *testing.T) {
    29  	RegisterFailHandler(Fail)
    30  	RunSpecs(t, "Providers/Security/Puppet")
    31  }
    32  
    33  var _ = Describe("PuppetSSL", func() {
    34  	var mockctl *gomock.Controller
    35  	var resolver *MockResolver
    36  	var cfg *Config
    37  	var err error
    38  	var prov *PuppetSecurity
    39  	var l *logrus.Logger
    40  
    41  	BeforeEach(func() {
    42  		mockctl = gomock.NewController(GinkgoT())
    43  		resolver = NewMockResolver(mockctl)
    44  		os.Setenv("MCOLLECTIVE_CERTNAME", "rip.mcollective")
    45  
    46  		cfg = &Config{
    47  			SSLDir:       filepath.Join("..", "testdata", "good"),
    48  			Identity:     "rip.mcollective",
    49  			PuppetCAHost: "puppet",
    50  			PuppetCAPort: 8140,
    51  			DisableSRV:   true,
    52  			useFakeUID:   true,
    53  			fakeUID:      500,
    54  		}
    55  
    56  		l = logrus.New()
    57  		l.SetOutput(GinkgoWriter)
    58  
    59  		prov, err = New(WithConfig(cfg), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
    60  		Expect(err).ToNot(HaveOccurred())
    61  	})
    62  
    63  	AfterEach(func() {
    64  		mockctl.Finish()
    65  	})
    66  
    67  	It("Should implement the provider interface", func() {
    68  		f := func(p inter.SecurityProvider) {}
    69  		f(prov)
    70  		Expect(prov.Provider()).To(Equal("puppet"))
    71  	})
    72  
    73  	Describe("WithChoriaConfig", func() {
    74  		It("Should disable SRV when the CA is configured", func() {
    75  			c, err := config.NewConfig(filepath.Join("..", "testdata", "puppetca.cfg"))
    76  			Expect(err).ToNot(HaveOccurred())
    77  
    78  			prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
    79  			Expect(err).ToNot(HaveOccurred())
    80  
    81  			Expect(prov.conf.DisableSRV).To(BeTrue())
    82  		})
    83  
    84  		It("Should support OverrideCertname", func() {
    85  			c := config.NewConfigForTests()
    86  
    87  			c.OverrideCertname = "override.choria"
    88  			prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
    89  			Expect(err).ToNot(HaveOccurred())
    90  
    91  			Expect(prov.conf.Identity).To(Equal("override.choria"))
    92  		})
    93  
    94  		// TODO: windows
    95  		if runtime.GOOS != "windows" {
    96  			It("Should fail when it cannot determine user identity", func() {
    97  				c := config.NewConfigForTests()
    98  				c.OverrideCertname = ""
    99  				v := os.Getenv("USER")
   100  				defer os.Setenv("USER", v)
   101  				os.Unsetenv("USER")
   102  				os.Unsetenv("MCOLLECTIVE_CERTNAME")
   103  				_, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
   104  				Expect(err).To(MatchError("could not determine client identity, ensure USER environment variable is set"))
   105  			})
   106  
   107  			It("Should use the user SSL directory when not configured", func() {
   108  				c, err := config.NewDefaultConfig()
   109  				Expect(err).ToNot(HaveOccurred())
   110  
   111  				prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
   112  				Expect(err).ToNot(HaveOccurred())
   113  
   114  				d, err := userSSlDir()
   115  				Expect(err).ToNot(HaveOccurred())
   116  
   117  				Expect(prov.conf.SSLDir).To(Equal(d))
   118  			})
   119  		}
   120  
   121  		It("Should copy all the relevant settings", func() {
   122  			c, err := config.NewDefaultConfig()
   123  			Expect(err).ToNot(HaveOccurred())
   124  
   125  			c.DisableTLSVerify = true
   126  			c.Choria.SSLDir = "/stub"
   127  			c.Choria.PuppetCAHost = "stubhost"
   128  			c.Choria.PuppetCAPort = 8080
   129  
   130  			prov, err = New(WithChoriaConfig(&build.Info{}, c), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
   131  			Expect(err).ToNot(HaveOccurred())
   132  
   133  			Expect(prov.conf.AllowList).To(Equal([]string{"\\.mcollective$", "\\.choria$"}))
   134  			Expect(prov.conf.PrivilegedUsers).To(Equal([]string{"\\.privileged.mcollective$", "\\.privileged.choria$"}))
   135  			Expect(prov.conf.DisableTLSVerify).To(BeTrue())
   136  			Expect(prov.conf.SSLDir).To(Equal("/stub"))
   137  			Expect(prov.conf.PuppetCAHost).To(Equal("stubhost"))
   138  			Expect(prov.conf.PuppetCAPort).To(Equal(8080))
   139  		})
   140  	})
   141  
   142  	Describe("Validate", func() {
   143  		It("Should handle missing files", func() {
   144  			cfg.SSLDir = filepath.Join("testdata", "allmissing")
   145  			cfg.Identity = "test.mcollective"
   146  			prov, err = New(WithConfig(cfg), WithResolver(resolver), WithLog(l.WithFields(logrus.Fields{})))
   147  
   148  			Expect(err).ToNot(HaveOccurred())
   149  
   150  			errs, ok := prov.Validate()
   151  
   152  			Expect(ok).To(BeFalse())
   153  			Expect(errs).To(HaveLen(3))
   154  			Expect(errs[0]).To(Equal(fmt.Sprintf("public certificate %s does not exist", filepath.Join(cfg.SSLDir, "certs", "test.mcollective.pem"))))
   155  			Expect(errs[1]).To(Equal(fmt.Sprintf("private key %s does not exist", filepath.Join(cfg.SSLDir, "private_keys", "test.mcollective.pem"))))
   156  			Expect(errs[2]).To(Equal(fmt.Sprintf("CA %s does not exist", filepath.Join(cfg.SSLDir, "certs", "ca.pem"))))
   157  		})
   158  
   159  		It("Should accept valid directories", func() {
   160  			cfg.Identity = "rip.mcollective"
   161  			errs, ok := prov.Validate()
   162  			Expect(errs).To(BeEmpty())
   163  			Expect(ok).To(BeTrue())
   164  		})
   165  	})
   166  
   167  	Describe("Identity", func() {
   168  		It("Should support OverrideCertname", func() {
   169  			cfg.Identity = "bob.choria"
   170  			prov.reinit()
   171  
   172  			Expect(prov.Identity()).To(Equal("bob.choria"))
   173  		})
   174  	})
   175  
   176  	Describe("writeCSR", func() {
   177  		It("should not write over existing CSRs", func() {
   178  			cfg.Identity = "na.mcollective"
   179  			prov.reinit()
   180  
   181  			kpath := prov.privateKeyPath()
   182  			csrpath := prov.csrPath()
   183  
   184  			defer os.Remove(kpath)
   185  			defer os.Remove(csrpath)
   186  
   187  			key, err := prov.writePrivateKey()
   188  			Expect(err).ToNot(HaveOccurred())
   189  
   190  			prov.conf.Identity = "rip.mcollective"
   191  			prov.reinit()
   192  			_, err = prov.writeCSR(key, "rip.mcollective", "choria.io")
   193  
   194  			Expect(err).To(MatchError("a certificate request already exist for rip.mcollective"))
   195  		})
   196  
   197  		It("Should create a valid CSR", func() {
   198  			prov.conf.Identity = "na.mcollective"
   199  			prov.reinit()
   200  
   201  			kpath := prov.privateKeyPath()
   202  			csrpath := prov.csrPath()
   203  
   204  			defer os.Remove(kpath)
   205  			defer os.Remove(csrpath)
   206  
   207  			key, err := prov.writePrivateKey()
   208  			Expect(err).ToNot(HaveOccurred())
   209  
   210  			_, err = prov.writeCSR(key, "na.mcollective", "choria.io")
   211  			Expect(err).ToNot(HaveOccurred())
   212  
   213  			csrpem, err := os.ReadFile(csrpath)
   214  			Expect(err).ToNot(HaveOccurred())
   215  
   216  			pb, _ := pem.Decode(csrpem)
   217  
   218  			req, err := x509.ParseCertificateRequest(pb.Bytes)
   219  			Expect(err).ToNot(HaveOccurred())
   220  			Expect(req.Subject.CommonName).To(Equal("na.mcollective"))
   221  			Expect(req.Subject.OrganizationalUnit).To(Equal([]string{"choria.io"}))
   222  		})
   223  	})
   224  
   225  	Describe("writePrivateKey", func() {
   226  		It("Should not write over existing private keys", func() {
   227  			cfg.Identity = "rip.mcollective"
   228  			key, err := prov.writePrivateKey()
   229  			Expect(err).To(MatchError("a private key already exist for rip.mcollective"))
   230  			Expect(key).To(BeNil())
   231  		})
   232  
   233  		It("Should create new keys", func() {
   234  			cfg.Identity = "na.mcollective"
   235  			prov.reinit()
   236  
   237  			path := prov.privateKeyPath()
   238  			defer os.Remove(path)
   239  
   240  			key, err := prov.writePrivateKey()
   241  			Expect(err).ToNot(HaveOccurred())
   242  			Expect(key).ToNot(BeNil())
   243  			Expect(path).To(BeAnExistingFile())
   244  		})
   245  	})
   246  
   247  	Describe("csrExists", func() {
   248  		It("Should detect existing keys", func() {
   249  			cfg.Identity = "rip.mcollective"
   250  			prov.reinit()
   251  
   252  			Expect(prov.csrExists()).To(BeTrue())
   253  		})
   254  
   255  		It("Should detect absent keys", func() {
   256  			cfg.Identity = "na.mcollective"
   257  			prov.reinit()
   258  
   259  			Expect(prov.csrExists()).To(BeFalse())
   260  		})
   261  	})
   262  
   263  	Describe("puppetCA", func() {
   264  		It("Should use supplied config when SRV is disabled", func() {
   265  			cfg.DisableSRV = true
   266  			s := prov.puppetCA()
   267  			Expect(s.Host()).To(Equal("puppet"))
   268  			Expect(s.Port()).To(Equal(uint16(8140)))
   269  			Expect(s.Scheme()).To(Equal("https"))
   270  		})
   271  
   272  		It("Should use supplied config when no srv resolver is given", func() {
   273  			prov, err = New(WithConfig(cfg), WithLog(l.WithFields(logrus.Fields{})))
   274  			Expect(err).ToNot(HaveOccurred())
   275  
   276  			resolver.EXPECT().QuerySrvRecords(gomock.Any()).Times(0)
   277  
   278  			s := prov.puppetCA()
   279  			Expect(s.Host()).To(Equal("puppet"))
   280  			Expect(s.Port()).To(Equal(uint16(8140)))
   281  			Expect(s.Scheme()).To(Equal("https"))
   282  		})
   283  
   284  		It("Should return defaults when SRV fails", func() {
   285  			resolver.EXPECT().QuerySrvRecords([]string{"_x-puppet-ca._tcp", "_x-puppet._tcp"}).Return(srvcache.NewServers(), errors.New("simulated error"))
   286  
   287  			cfg.DisableSRV = false
   288  			s := prov.puppetCA()
   289  			Expect(s.Host()).To(Equal("puppet"))
   290  			Expect(s.Port()).To(Equal(uint16(8140)))
   291  			Expect(s.Scheme()).To(Equal("https"))
   292  		})
   293  
   294  		It("Should use SRV records", func() {
   295  			ans := srvcache.NewServers(
   296  				srvcache.NewServer("p1", 8080, "http"),
   297  				srvcache.NewServer("p2", 8081, "https"),
   298  			)
   299  
   300  			resolver.EXPECT().QuerySrvRecords([]string{"_x-puppet-ca._tcp", "_x-puppet._tcp"}).Return(ans, nil)
   301  			cfg.DisableSRV = false
   302  
   303  			s := prov.puppetCA()
   304  			Expect(s.Host()).To(Equal("p1"))
   305  			Expect(s.Port()).To(Equal(uint16(8080)))
   306  			Expect(s.Scheme()).To(Equal("http"))
   307  
   308  		})
   309  	})
   310  })