github.com/status-im/status-go@v1.1.0/t/helpers/db.go (about)

     1  package helpers
     2  
import (
	"database/sql"
	"io/ioutil"
	"os"
	"strings"

	"github.com/status-im/status-go/common/dbsetup"
	"github.com/status-im/status-go/multiaccounts"
)
    11  
    12  const kdfIterationsNumberForTests = 1
    13  
    14  // SetupTestSQLDB creates a temporary sqlite database file, initialises and then returns with a teardown func
    15  func SetupTestSQLDB(dbInit dbsetup.DatabaseInitializer, prefix string) (*sql.DB, func() error, error) {
    16  	tmpfile, err := ioutil.TempFile("", prefix)
    17  	if err != nil {
    18  		return nil, nil, err
    19  	}
    20  
    21  	db, err := dbInit.Initialize(tmpfile.Name(), "password", kdfIterationsNumberForTests)
    22  	if err != nil {
    23  		return nil, nil, err
    24  	}
    25  
    26  	return db, func() error {
    27  		err := db.Close()
    28  		if err != nil {
    29  			return err
    30  		}
    31  		return os.Remove(tmpfile.Name())
    32  	}, nil
    33  }
    34  
    35  func SetupTestMemorySQLDB(dbInit dbsetup.DatabaseInitializer) (*sql.DB, error) {
    36  	db, err := dbInit.Initialize(dbsetup.InMemoryPath, "password", kdfIterationsNumberForTests)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	return db, nil
    42  }
    43  
    44  func SetupTestMemorySQLAccountsDB(dbInit dbsetup.DatabaseInitializer) (*sql.DB, error) {
    45  	db, err := multiaccounts.InitializeDB(dbsetup.InMemoryPath)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	return db.DB(), nil
    51  }
    52  
    53  func ColumnExists(db *sql.DB, tableName string, columnName string) (bool, error) {
    54  	rows, err := db.Query("PRAGMA table_info(" + tableName + ")")
    55  	if err != nil {
    56  		return false, err
    57  	}
    58  	defer rows.Close()
    59  
    60  	var cid int
    61  	var name string
    62  	var dataType string
    63  	var notNull bool
    64  	var dFLTValue sql.NullString
    65  	var pk int
    66  
    67  	for rows.Next() {
    68  		err := rows.Scan(&cid, &name, &dataType, &notNull, &dFLTValue, &pk)
    69  		if err != nil {
    70  			return false, err
    71  		}
    72  		if name == columnName {
    73  			return true, nil
    74  		}
    75  	}
    76  
    77  	if rows.Err() != nil {
    78  		return false, rows.Err()
    79  	}
    80  
    81  	return false, nil
    82  }