github.com/nmanchovski/burrow@v0.25.0/vent/test/db.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"math/rand"
     7  	"syscall"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"os"
    14  
    15  	"github.com/hyperledger/burrow/vent/config"
    16  	"github.com/hyperledger/burrow/vent/logger"
    17  	"github.com/hyperledger/burrow/vent/sqldb"
    18  	"github.com/hyperledger/burrow/vent/types"
    19  )
    20  
    21  var letters = []rune("abcdefghijklmnopqrstuvwxyz")
    22  
    23  func init() {
    24  	rand.Seed(time.Now().UnixNano())
    25  }
    26  
    27  // NewTestDB creates a database connection for testing
    28  func NewTestDB(t *testing.T, cfg *config.VentConfig) (*sqldb.SQLDB, func()) {
    29  	t.Helper()
    30  
    31  	if cfg.DBAdapter != types.SQLiteDB {
    32  		if dbURL, ok := syscall.Getenv("DB_URL"); ok {
    33  			t.Logf("Using DB_URL '%s'", dbURL)
    34  			cfg.DBURL = dbURL
    35  		}
    36  	}
    37  
    38  	connection := types.SQLConnection{
    39  		DBAdapter: cfg.DBAdapter,
    40  		DBURL:     cfg.DBURL,
    41  		DBSchema:  cfg.DBSchema,
    42  
    43  		Log:           logger.NewLogger(""),
    44  		ChainID:       "ID 0123",
    45  		BurrowVersion: "Version 0.0",
    46  	}
    47  
    48  	db, err := sqldb.NewSQLDB(connection)
    49  	if err != nil {
    50  		require.NoError(t, err)
    51  	}
    52  
    53  	return db, func() {
    54  		if cfg.DBAdapter == types.SQLiteDB {
    55  			db.Close()
    56  			os.Remove(connection.DBURL)
    57  			os.Remove(connection.DBURL + "-shm")
    58  			os.Remove(connection.DBURL + "-wal")
    59  		} else {
    60  			destroySchema(db, connection.DBSchema)
    61  			db.Close()
    62  		}
    63  	}
    64  }
    65  
    66  func SqliteVentConfig(grpcAddress string) *config.VentConfig {
    67  	cfg := config.DefaultVentConfig()
    68  	file, err := ioutil.TempFile("", "vent.sqlite")
    69  	if err != nil {
    70  		panic(err)
    71  	}
    72  	err = file.Close()
    73  	if err != nil {
    74  		panic(err)
    75  	}
    76  
    77  	cfg.DBURL = file.Name()
    78  	cfg.DBAdapter = types.SQLiteDB
    79  	cfg.GRPCAddr = grpcAddress
    80  	return cfg
    81  }
    82  
    83  func PostgresVentConfig(grpcAddress string) *config.VentConfig {
    84  	cfg := config.DefaultVentConfig()
    85  	cfg.DBSchema = fmt.Sprintf("test_%s", randString(10))
    86  	cfg.DBAdapter = types.PostgresDB
    87  	cfg.DBURL = config.DefaultPostgresDBURL
    88  	cfg.GRPCAddr = grpcAddress
    89  	return cfg
    90  }
    91  
    92  func destroySchema(db *sqldb.SQLDB, dbSchema string) error {
    93  	db.Log.Info("msg", "Dropping schema")
    94  	query := fmt.Sprintf("DROP SCHEMA %s CASCADE;", dbSchema)
    95  
    96  	db.Log.Info("msg", "Drop schema", "query", query)
    97  
    98  	if _, err := db.DB.Exec(query); err != nil {
    99  		db.Log.Info("msg", "Error dropping schema", "err", err)
   100  		return err
   101  	}
   102  
   103  	return nil
   104  }
   105  
   106  func randString(n int) string {
   107  	b := make([]rune, n)
   108  
   109  	for i := range b {
   110  		b[i] = letters[rand.Intn(len(letters))]
   111  	}
   112  
   113  	return string(b)
   114  }