github.com/decred/dcrlnd@v0.7.6/kvdb/postgres/fixture.go (about)

     1  //go:build kvdb_postgres
     2  // +build kvdb_postgres
     3  
     4  package postgres
     5  
     6  import (
     7  	"context"
     8  	"crypto/rand"
     9  	"database/sql"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/btcsuite/btcwallet/walletdb"
    16  	embeddedpostgres "github.com/fergusstrange/embedded-postgres"
    17  )
    18  
    19  const (
    20  	testDsnTemplate = "postgres://postgres:postgres@localhost:9876/%v?sslmode=disable"
    21  	prefix          = "test"
    22  )
    23  
    24  func getTestDsn(dbName string) string {
    25  	return fmt.Sprintf(testDsnTemplate, dbName)
    26  }
    27  
    28  var testPostgres *embeddedpostgres.EmbeddedPostgres
    29  
    30  const testMaxConnections = 50
    31  
    32  // StartEmbeddedPostgres starts an embedded postgres instance. This only needs
    33  // to be done once, because NewFixture will create random new databases on every
    34  // call. It returns a stop closure that stops the database if called.
    35  func StartEmbeddedPostgres() (func() error, error) {
    36  	Init(testMaxConnections)
    37  
    38  	postgres := embeddedpostgres.NewDatabase(
    39  		embeddedpostgres.DefaultConfig().
    40  			Port(9876))
    41  
    42  	err := postgres.Start()
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	testPostgres = postgres
    48  
    49  	return testPostgres.Stop, nil
    50  }
    51  
    52  // NewFixture returns a new postgres test database. The database name is
    53  // randomly generated.
    54  func NewFixture(dbName string) (*fixture, error) {
    55  	if dbName == "" {
    56  		// Create random database name.
    57  		randBytes := make([]byte, 8)
    58  		_, err := rand.Read(randBytes)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  
    63  		dbName = "test_" + hex.EncodeToString(randBytes)
    64  	}
    65  
    66  	// Create database if it doesn't exist yet.
    67  	dbConn, err := sql.Open("pgx", getTestDsn("postgres"))
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	defer dbConn.Close()
    72  
    73  	_, err = dbConn.ExecContext(
    74  		context.Background(), "CREATE DATABASE "+dbName,
    75  	)
    76  	if err != nil && !strings.Contains(err.Error(), "already exists") {
    77  		return nil, err
    78  	}
    79  
    80  	// Open database
    81  	dsn := getTestDsn(dbName)
    82  	db, err := newPostgresBackend(
    83  		context.Background(),
    84  		&Config{
    85  			Dsn:     dsn,
    86  			Timeout: time.Minute,
    87  		},
    88  		prefix,
    89  	)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return &fixture{
    95  		Dsn: dsn,
    96  		Db:  db,
    97  	}, nil
    98  }
    99  
   100  type fixture struct {
   101  	Dsn string
   102  	Db  walletdb.DB
   103  }
   104  
   105  func (b *fixture) DB() walletdb.DB {
   106  	return b.Db
   107  }
   108  
   109  // Dump returns the raw contents of the database.
   110  func (b *fixture) Dump() (map[string]interface{}, error) {
   111  	dbConn, err := sql.Open("pgx", b.Dsn)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	rows, err := dbConn.Query(
   117  		"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
   118  	)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	var tables []string
   124  	for rows.Next() {
   125  		var table string
   126  		err := rows.Scan(&table)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  
   131  		tables = append(tables, table)
   132  	}
   133  
   134  	result := make(map[string]interface{})
   135  
   136  	for _, table := range tables {
   137  		rows, err := dbConn.Query("SELECT * FROM " + table)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  
   142  		cols, err := rows.Columns()
   143  		if err != nil {
   144  			return nil, err
   145  		}
   146  		colCount := len(cols)
   147  
   148  		var tableRows []map[string]interface{}
   149  		for rows.Next() {
   150  			values := make([]interface{}, colCount)
   151  			valuePtrs := make([]interface{}, colCount)
   152  			for i := range values {
   153  				valuePtrs[i] = &values[i]
   154  			}
   155  
   156  			err := rows.Scan(valuePtrs...)
   157  			if err != nil {
   158  				return nil, err
   159  			}
   160  
   161  			tableData := make(map[string]interface{})
   162  			for i, v := range values {
   163  				// Cast byte slices to string to keep the
   164  				// expected database contents in test code more
   165  				// readable.
   166  				if ar, ok := v.([]uint8); ok {
   167  					v = string(ar)
   168  				}
   169  				tableData[cols[i]] = v
   170  			}
   171  
   172  			tableRows = append(tableRows, tableData)
   173  		}
   174  
   175  		result[table] = tableRows
   176  	}
   177  
   178  	return result, nil
   179  }