github.com/decred/dcrlnd@v0.7.6/kvdb/postgres/fixture.go (about) 1 //go:build kvdb_postgres 2 // +build kvdb_postgres 3 4 package postgres 5 6 import ( 7 "context" 8 "crypto/rand" 9 "database/sql" 10 "encoding/hex" 11 "fmt" 12 "strings" 13 "time" 14 15 "github.com/btcsuite/btcwallet/walletdb" 16 embeddedpostgres "github.com/fergusstrange/embedded-postgres" 17 ) 18 19 const ( 20 testDsnTemplate = "postgres://postgres:postgres@localhost:9876/%v?sslmode=disable" 21 prefix = "test" 22 ) 23 24 func getTestDsn(dbName string) string { 25 return fmt.Sprintf(testDsnTemplate, dbName) 26 } 27 28 var testPostgres *embeddedpostgres.EmbeddedPostgres 29 30 const testMaxConnections = 50 31 32 // StartEmbeddedPostgres starts an embedded postgres instance. This only needs 33 // to be done once, because NewFixture will create random new databases on every 34 // call. It returns a stop closure that stops the database if called. 35 func StartEmbeddedPostgres() (func() error, error) { 36 Init(testMaxConnections) 37 38 postgres := embeddedpostgres.NewDatabase( 39 embeddedpostgres.DefaultConfig(). 40 Port(9876)) 41 42 err := postgres.Start() 43 if err != nil { 44 return nil, err 45 } 46 47 testPostgres = postgres 48 49 return testPostgres.Stop, nil 50 } 51 52 // NewFixture returns a new postgres test database. The database name is 53 // randomly generated. 54 func NewFixture(dbName string) (*fixture, error) { 55 if dbName == "" { 56 // Create random database name. 57 randBytes := make([]byte, 8) 58 _, err := rand.Read(randBytes) 59 if err != nil { 60 return nil, err 61 } 62 63 dbName = "test_" + hex.EncodeToString(randBytes) 64 } 65 66 // Create database if it doesn't exist yet. 67 dbConn, err := sql.Open("pgx", getTestDsn("postgres")) 68 if err != nil { 69 return nil, err 70 } 71 defer dbConn.Close() 72 73 _, err = dbConn.ExecContext( 74 context.Background(), "CREATE DATABASE "+dbName, 75 ) 76 if err != nil && !strings.Contains(err.Error(), "already exists") { 77 return nil, err 78 } 79 80 // Open database 81 dsn := getTestDsn(dbName) 82 db, err := newPostgresBackend( 83 context.Background(), 84 &Config{ 85 Dsn: dsn, 86 Timeout: time.Minute, 87 }, 88 prefix, 89 ) 90 if err != nil { 91 return nil, err 92 } 93 94 return &fixture{ 95 Dsn: dsn, 96 Db: db, 97 }, nil 98 } 99 100 type fixture struct { 101 Dsn string 102 Db walletdb.DB 103 } 104 105 func (b *fixture) DB() walletdb.DB { 106 return b.Db 107 } 108 109 // Dump returns the raw contents of the database. 110 func (b *fixture) Dump() (map[string]interface{}, error) { 111 dbConn, err := sql.Open("pgx", b.Dsn) 112 if err != nil { 113 return nil, err 114 } 115 116 rows, err := dbConn.Query( 117 "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'", 118 ) 119 if err != nil { 120 return nil, err 121 } 122 123 var tables []string 124 for rows.Next() { 125 var table string 126 err := rows.Scan(&table) 127 if err != nil { 128 return nil, err 129 } 130 131 tables = append(tables, table) 132 } 133 134 result := make(map[string]interface{}) 135 136 for _, table := range tables { 137 rows, err := dbConn.Query("SELECT * FROM " + table) 138 if err != nil { 139 return nil, err 140 } 141 142 cols, err := rows.Columns() 143 if err != nil { 144 return nil, err 145 } 146 colCount := len(cols) 147 148 var tableRows []map[string]interface{} 149 for rows.Next() { 150 values := make([]interface{}, colCount) 151 valuePtrs := make([]interface{}, colCount) 152 for i := range values { 153 valuePtrs[i] = &values[i] 154 } 155 156 err := rows.Scan(valuePtrs...) 157 if err != nil { 158 return nil, err 159 } 160 161 tableData := make(map[string]interface{}) 162 for i, v := range values { 163 // Cast byte slices to string to keep the 164 // expected database contents in test code more 165 // readable. 166 if ar, ok := v.([]uint8); ok { 167 v = string(ar) 168 } 169 tableData[cols[i]] = v 170 } 171 172 tableRows = append(tableRows, tableData) 173 } 174 175 result[table] = tableRows 176 } 177 178 return result, nil 179 }