github.com/hyperledger/burrow@v0.34.5-0.20220512172541-77f09336001d/vent/test/db.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"math/rand"
     7  	"os"
     8  	"syscall"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/hyperledger/burrow/logging"
    13  	"github.com/hyperledger/burrow/vent/config"
    14  	"github.com/hyperledger/burrow/vent/sqldb"
    15  	"github.com/hyperledger/burrow/vent/types"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  const (
    20  	ChainID       = "CHAIN_123"
    21  	BurrowVersion = "1.0.0"
    22  )
    23  
    24  var letters = []rune("abcdefghijklmnopqrstuvwxyz")
    25  
    26  func init() {
    27  	rand.Seed(time.Now().UnixNano())
    28  }
    29  
    30  // NewTestDB creates a database connection for testing
    31  func NewTestDB(t *testing.T, cfg *config.VentConfig) (*sqldb.SQLDB, func()) {
    32  	t.Helper()
    33  
    34  	if cfg.DBAdapter != types.SQLiteDB {
    35  		if dbURL, ok := syscall.Getenv("DB_URL"); ok {
    36  			t.Logf("Using DB_URL '%s'", dbURL)
    37  			cfg.DBURL = dbURL
    38  		}
    39  	}
    40  
    41  	connection := types.SQLConnection{
    42  		DBAdapter: cfg.DBAdapter,
    43  		DBURL:     cfg.DBURL,
    44  		DBSchema:  cfg.DBSchema,
    45  
    46  		Log: logging.NewNoopLogger(),
    47  	}
    48  
    49  	db, err := sqldb.NewSQLDB(connection)
    50  	require.NoError(t, err)
    51  
    52  	err = db.Init(ChainID, BurrowVersion)
    53  	require.NoError(t, err)
    54  
    55  	return db, func() {
    56  		if cfg.DBAdapter == types.SQLiteDB {
    57  			db.Close()
    58  			os.Remove(connection.DBURL)
    59  			os.Remove(connection.DBURL + "-shm")
    60  			os.Remove(connection.DBURL + "-wal")
    61  		} else {
    62  			destroySchema(db, connection.DBSchema)
    63  			db.Close()
    64  		}
    65  	}
    66  }
    67  
    68  func SqliteVentConfig(grpcAddress string) *config.VentConfig {
    69  	cfg := config.DefaultVentConfig()
    70  	file, err := ioutil.TempFile("", "vent.sqlite")
    71  	if err != nil {
    72  		panic(err)
    73  	}
    74  	err = file.Close()
    75  	if err != nil {
    76  		panic(err)
    77  	}
    78  
    79  	cfg.DBURL = file.Name()
    80  	cfg.DBAdapter = types.SQLiteDB
    81  	cfg.ChainAddress = grpcAddress
    82  	return cfg
    83  }
    84  
    85  func PostgresVentConfig(chainAddress string) *config.VentConfig {
    86  	cfg := config.DefaultVentConfig()
    87  	cfg.DBSchema = fmt.Sprintf("test_%d_%s", time.Now().Unix(), randString(10))
    88  	cfg.DBAdapter = types.PostgresDB
    89  	cfg.DBURL = config.DefaultPostgresDBURL
    90  	cfg.ChainAddress = chainAddress
    91  	cfg.AnnounceEvery = time.Millisecond * 100
    92  	return cfg
    93  }
    94  
    95  func destroySchema(db *sqldb.SQLDB, dbSchema string) error {
    96  	db.Log.InfoMsg("Dropping schema")
    97  	query := fmt.Sprintf("DROP SCHEMA %s CASCADE;", dbSchema)
    98  
    99  	db.Log.InfoMsg("Drop schema", "query", query)
   100  
   101  	if _, err := db.DB.Exec(query); err != nil {
   102  		db.Log.InfoMsg("Error dropping schema", "err", err)
   103  		return err
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func randString(n int) string {
   110  	b := make([]rune, n)
   111  
   112  	for i := range b {
   113  		b[i] = letters[rand.Intn(len(letters))]
   114  	}
   115  
   116  	return string(b)
   117  }