github.com/decred/dcrlnd@v0.7.6/channeldb/migtest/migtest.go (about) 1 package migtest 2 3 import ( 4 "fmt" 5 "io/ioutil" 6 "os" 7 "testing" 8 9 "github.com/decred/dcrlnd/kvdb" 10 ) 11 12 // MakeDB creates a new instance of the ChannelDB for testing purposes. A 13 // callback which cleans up the created temporary directories is also returned 14 // and intended to be executed after the test completes. 15 func MakeDB() (kvdb.Backend, func(), error) { 16 // Create temporary database for mission control. 17 file, err := ioutil.TempFile("", "*.db") 18 if err != nil { 19 return nil, nil, err 20 } 21 22 dbPath := file.Name() 23 db, err := kvdb.Open( 24 kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout, 25 ) 26 if err != nil { 27 return nil, nil, err 28 } 29 30 cleanUp := func() { 31 db.Close() 32 os.RemoveAll(dbPath) 33 } 34 35 return db, cleanUp, nil 36 } 37 38 // ApplyMigration is a helper test function that encapsulates the general steps 39 // which are needed to properly check the result of applying migration function. 40 func ApplyMigration(t *testing.T, 41 beforeMigration, afterMigration, migrationFunc func(tx kvdb.RwTx) error, 42 shouldFail bool) { 43 44 t.Helper() 45 46 cdb, cleanUp, err := MakeDB() 47 defer cleanUp() 48 if err != nil { 49 t.Fatal(err) 50 } 51 52 // beforeMigration usually used for populating the database 53 // with test data. 54 err = kvdb.Update(cdb, beforeMigration, func() {}) 55 if err != nil { 56 t.Fatal(err) 57 } 58 59 defer func() { 60 t.Helper() 61 62 if r := recover(); r != nil { 63 err = newError(r) 64 } 65 66 if err == nil && shouldFail { 67 t.Fatal("error wasn't received on migration stage") 68 } else if err != nil && !shouldFail { 69 t.Fatalf("error was received on migration stage: %v", err) 70 } 71 72 // afterMigration usually used for checking the database state and 73 // throwing the error if something went wrong. 74 err = kvdb.Update(cdb, afterMigration, func() {}) 75 if err != nil { 76 t.Fatal(err) 77 } 78 }() 79 80 // Apply migration. 81 err = kvdb.Update(cdb, migrationFunc, func() {}) 82 if err != nil { 83 t.Logf("migration error: %v", err) 84 } 85 } 86 87 func newError(e interface{}) error { 88 var err error 89 switch e := e.(type) { 90 case error: 91 err = e 92 default: 93 err = fmt.Errorf("%v", e) 94 } 95 96 return err 97 }