github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/database/bootstrap_test.go (about) 1 // Copyright 2022 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package database 5 6 import ( 7 "context" 8 "database/sql" 9 "errors" 10 "fmt" 11 "net" 12 13 "github.com/juju/testing" 14 jc "github.com/juju/testing/checkers" 15 gc "gopkg.in/check.v1" 16 17 "github.com/juju/juju/core/network" 18 "github.com/juju/juju/database/app" 19 "github.com/juju/juju/database/client" 20 ) 21 22 type bootstrapSuite struct { 23 testing.IsolationSuite 24 } 25 26 var _ = gc.Suite(&bootstrapSuite{}) 27 28 func (s *bootstrapSuite) TestBootstrapSuccess(c *gc.C) { 29 mgr := &testNodeManager{c: c} 30 31 // check tests the variadic operation functionality 32 // and ensures that bootstrap applied the DDL. 33 check := func(db *sql.DB) error { 34 rows, err := db.Query("SELECT COUNT(*) FROM lease_type") 35 if err != nil { 36 return err 37 } 38 39 defer func() { _ = rows.Close() }() 40 41 if !rows.Next() { 42 return errors.New("no rows in lease_type") 43 } 44 45 var count int 46 err = rows.Scan(&count) 47 if err != nil { 48 return err 49 } 50 51 if count != 2 { 52 return fmt.Errorf("expected 2 rows, got %d", count) 53 } 54 55 return nil 56 } 57 58 err := BootstrapDqlite(context.Background(), mgr, stubLogger{}, check) 59 c.Assert(err, jc.ErrorIsNil) 60 } 61 62 type testNodeManager struct { 63 c *gc.C 64 dataDir string 65 port int 66 } 67 68 func (f *testNodeManager) EnsureDataDir() (string, error) { 69 if f.dataDir == "" { 70 f.dataDir = f.c.MkDir() 71 } 72 return f.dataDir, nil 73 } 74 75 func (f *testNodeManager) IsLoopbackPreferred() bool { 76 return true 77 } 78 79 func (f *testNodeManager) WithPreferredCloudLocalAddressOption(network.ConfigSource) (app.Option, error) { 80 return f.WithLoopbackAddressOption(), nil 81 } 82 83 func (f *testNodeManager) WithLoopbackAddressOption() app.Option { 84 if f.port == 0 { 85 l, err := net.Listen("tcp", ":0") 86 f.c.Assert(err, jc.ErrorIsNil) 87 f.c.Assert(l.Close(), jc.ErrorIsNil) 88 f.port = l.Addr().(*net.TCPAddr).Port 89 } 90 return app.WithAddress(fmt.Sprintf("127.0.0.1:%d", f.port)) 91 } 92 93 func (f *testNodeManager) WithLogFuncOption() app.Option { 94 return app.WithLogFunc(func(_ client.LogLevel, msg string, args ...interface{}) { 95 f.c.Logf(msg, args...) 96 }) 97 } 98 99 func (f *testNodeManager) WithTracingOption() app.Option { 100 return app.WithTracing(client.LogNone) 101 } 102 103 func (f *testNodeManager) WithTLSOption() (app.Option, error) { 104 return nil, nil 105 }