github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/testing/testing.go (about) 1 // Package testing has the database tests. 2 // All database drivers must pass the Test function. 3 // This lives in it's own package so it stays a test dependency. 4 package testing 5 6 import ( 7 "bytes" 8 "errors" 9 "fmt" 10 "io" 11 "testing" 12 "time" 13 14 "github.com/golang-migrate/migrate/v4/database" 15 ) 16 17 // Test runs tests against database implementations. 18 func Test(t *testing.T, d database.Driver, migration []byte) { 19 if migration == nil { 20 t.Fatal("test must provide migration reader") 21 } 22 23 TestNilVersion(t, d) // test first 24 TestLockAndUnlock(t, d) 25 TestRun(t, d, bytes.NewReader(migration)) 26 TestSetVersion(t, d) // also tests Version() 27 // Drop breaks the driver, so test it last. 28 TestDrop(t, d) 29 } 30 31 func TestNilVersion(t *testing.T, d database.Driver) { 32 v, _, err := d.Version() 33 if err != nil { 34 t.Fatal(err) 35 } 36 if v != database.NilVersion { 37 t.Fatalf("Version: expected version to be NilVersion (-1), got %v", v) 38 } 39 } 40 41 func TestLockAndUnlock(t *testing.T, d database.Driver) { 42 // add a timeout, in case there is a deadlock 43 done := make(chan struct{}) 44 errs := make(chan error) 45 46 go func() { 47 timeout := time.After(15 * time.Second) 48 for { 49 select { 50 case <-done: 51 return 52 case <-timeout: 53 errs <- fmt.Errorf("Timeout after 15 seconds. Looks like a deadlock in Lock/UnLock.\n%#v", d) 54 return 55 } 56 } 57 }() 58 59 // run the locking test ... 60 go func() { 61 if err := d.Lock(); err != nil { 62 errs <- err 63 return 64 } 65 66 // try to acquire lock again 67 if err := d.Lock(); err == nil { 68 errs <- errors.New("lock: expected err not to be nil") 69 return 70 } 71 72 // unlock 73 if err := d.Unlock(); err != nil { 74 errs <- err 75 return 76 } 77 78 // try to lock 79 if err := d.Lock(); err != nil { 80 errs <- err 81 return 82 } 83 if err := d.Unlock(); err != nil { 84 errs <- err 85 return 86 } 87 // notify everyone 88 close(done) 89 }() 90 91 // wait for done or any error 92 for { 93 select { 94 case <-done: 95 return 96 case err := <-errs: 97 t.Fatal(err) 98 } 99 } 100 } 101 102 func TestRun(t *testing.T, d database.Driver, migration io.Reader) { 103 if migration == nil { 104 t.Fatal("migration can't be nil") 105 } 106 107 if err := d.Run(migration); err != nil { 108 t.Fatal(err) 109 } 110 } 111 112 func TestDrop(t *testing.T, d database.Driver) { 113 if err := d.Drop(); err != nil { 114 t.Fatal(err) 115 } 116 } 117 118 func TestSetVersion(t *testing.T, d database.Driver) { 119 // nolint:maligned 120 testCases := []struct { 121 name string 122 version int 123 dirty bool 124 expectedErr error 125 expectedReadErr error 126 expectedVersion int 127 expectedDirty bool 128 }{ 129 {name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true}, 130 {name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true}, 131 {name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false}, 132 {name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false}, 133 {name: "last migration dirty", version: database.NilVersion, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true}, 134 {name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false}, 135 } 136 137 for _, tc := range testCases { 138 t.Run(tc.name, func(t *testing.T) { 139 err := d.SetVersion(tc.version, tc.dirty) 140 if err != tc.expectedErr { 141 t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr) 142 } 143 v, dirty, readErr := d.Version() 144 if readErr != tc.expectedReadErr { 145 t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr) 146 } 147 if v != tc.expectedVersion { 148 t.Error("Got unexpected version:", v, "!=", tc.expectedVersion) 149 } 150 if dirty != tc.expectedDirty { 151 t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty) 152 } 153 }) 154 } 155 }