github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/testing/testing.go (about)

     1  // Package testing has the database tests.
     2  // All database drivers must pass the Test function.
     3  // This lives in its own package so it stays a test dependency.
     4  package testing
     5  
     6  import (
     7  	"bytes"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  )
    16  
    17  // Test runs tests against database implementations.
    18  func Test(t *testing.T, d database.Driver, migration []byte) {
    19  	if migration == nil {
    20  		t.Fatal("test must provide migration reader")
    21  	}
    22  
    23  	TestNilVersion(t, d) // test first
    24  	TestLockAndUnlock(t, d)
    25  	TestRun(t, d, bytes.NewReader(migration))
    26  	TestSetVersion(t, d) // also tests Version()
    27  	// Drop breaks the driver, so test it last.
    28  	TestDrop(t, d)
    29  }
    30  
    31  func TestNilVersion(t *testing.T, d database.Driver) {
    32  	v, _, err := d.Version()
    33  	if err != nil {
    34  		t.Fatal(err)
    35  	}
    36  	if v != database.NilVersion {
    37  		t.Fatalf("Version: expected version to be NilVersion (-1), got %v", v)
    38  	}
    39  }
    40  
    41  func TestLockAndUnlock(t *testing.T, d database.Driver) {
    42  	// add a timeout, in case there is a deadlock
    43  	done := make(chan struct{})
    44  	errs := make(chan error)
    45  
    46  	go func() {
    47  		timeout := time.After(15 * time.Second)
    48  		for {
    49  			select {
    50  			case <-done:
    51  				return
    52  			case <-timeout:
    53  				errs <- fmt.Errorf("Timeout after 15 seconds. Looks like a deadlock in Lock/UnLock.\n%#v", d)
    54  				return
    55  			}
    56  		}
    57  	}()
    58  
    59  	// run the locking test ...
    60  	go func() {
    61  		if err := d.Lock(); err != nil {
    62  			errs <- err
    63  			return
    64  		}
    65  
    66  		// try to acquire lock again
    67  		if err := d.Lock(); err == nil {
    68  			errs <- errors.New("lock: expected err not to be nil")
    69  			return
    70  		}
    71  
    72  		// unlock
    73  		if err := d.Unlock(); err != nil {
    74  			errs <- err
    75  			return
    76  		}
    77  
    78  		// try to lock
    79  		if err := d.Lock(); err != nil {
    80  			errs <- err
    81  			return
    82  		}
    83  		if err := d.Unlock(); err != nil {
    84  			errs <- err
    85  			return
    86  		}
    87  		// notify everyone
    88  		close(done)
    89  	}()
    90  
    91  	// wait for done or any error
    92  	for {
    93  		select {
    94  		case <-done:
    95  			return
    96  		case err := <-errs:
    97  			t.Fatal(err)
    98  		}
    99  	}
   100  }
   101  
   102  func TestRun(t *testing.T, d database.Driver, migration io.Reader) {
   103  	if migration == nil {
   104  		t.Fatal("migration can't be nil")
   105  	}
   106  
   107  	if err := d.Run(migration); err != nil {
   108  		t.Fatal(err)
   109  	}
   110  }
   111  
   112  func TestDrop(t *testing.T, d database.Driver) {
   113  	if err := d.Drop(); err != nil {
   114  		t.Fatal(err)
   115  	}
   116  }
   117  
   118  func TestSetVersion(t *testing.T, d database.Driver) {
   119  	// nolint:maligned
   120  	testCases := []struct {
   121  		name            string
   122  		version         int
   123  		dirty           bool
   124  		expectedErr     error
   125  		expectedReadErr error
   126  		expectedVersion int
   127  		expectedDirty   bool
   128  	}{
   129  		{name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
   130  		{name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
   131  		{name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
   132  		{name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
   133  		{name: "last migration dirty", version: database.NilVersion, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true},
   134  		{name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false},
   135  	}
   136  
   137  	for _, tc := range testCases {
   138  		t.Run(tc.name, func(t *testing.T) {
   139  			err := d.SetVersion(tc.version, tc.dirty)
   140  			if err != tc.expectedErr {
   141  				t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr)
   142  			}
   143  			v, dirty, readErr := d.Version()
   144  			if readErr != tc.expectedReadErr {
   145  				t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr)
   146  			}
   147  			if v != tc.expectedVersion {
   148  				t.Error("Got unexpected version:", v, "!=", tc.expectedVersion)
   149  			}
   150  			if dirty != tc.expectedDirty {
   151  				t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty)
   152  			}
   153  		})
   154  	}
   155  }