github.com/gogriddy/goose@v0.0.0-20180817174216-2c751e0981c8/lib/goose/dbconf_test.go (about)

     1  package goose
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"path/filepath"
     7  	"reflect"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func setupDBConf(t *testing.T, confPath string, extraPath string) (string, string, func()) {
    16  	td, err := ioutil.TempDir("", "goose-test")
    17  	require.NoError(t, err)
    18  	defer func() {
    19  		if t.Failed() {
    20  			os.RemoveAll(td)
    21  		}
    22  	}()
    23  
    24  	confPath = filepath.Join(strings.Split(confPath, "/")...)
    25  	confPath = filepath.Join(td, confPath)
    26  	confDir := filepath.Dir(confPath)
    27  	err = os.MkdirAll(confDir, 0700)
    28  	require.NoError(t, err)
    29  
    30  	err = ioutil.WriteFile(confPath,
    31  		[]byte("\n"),
    32  		0600)
    33  	require.NoError(t, err)
    34  
    35  	extraDir := filepath.Join(strings.Split(extraPath, "/")...)
    36  	extraDir = filepath.Join(td, extraDir)
    37  	err = os.MkdirAll(extraDir, 0700)
    38  	require.NoError(t, err)
    39  
    40  	return confPath, extraDir, func() { os.RemoveAll(td) }
    41  }
    42  
    43  func TestFindDBConf_confDir(t *testing.T) {
    44  	confNames := []string{"db/dbconf.yaml", "db/dbconf.yml", "dbconf.yaml", "dbconf.yml"}
    45  	for _, confName := range confNames {
    46  		confPath, baseDir, clean := setupDBConf(t, confName, "")
    47  		defer clean()
    48  
    49  		path := findDBConf(baseDir)
    50  		assert.Equal(t, confPath, path)
    51  	}
    52  }
    53  
    54  func TestFindDBConf_deepDir(t *testing.T) {
    55  	confPath, deepDir, clean := setupDBConf(t, "db/dbconf.yaml", "a/b/c")
    56  	defer clean()
    57  
    58  	path := findDBConf(deepDir)
    59  	assert.Equal(t, confPath, path)
    60  }
    61  
    62  func TestFindDBConf_pwd(t *testing.T) {
    63  	confPath, deepDir, clean := setupDBConf(t, "dbconf.yaml", "a/b/c")
    64  	defer clean()
    65  
    66  	pwd, err := os.Getwd()
    67  	require.NoError(t, err)
    68  	defer os.Chdir(pwd)
    69  	err = os.Chdir(deepDir)
    70  	require.NoError(t, err)
    71  
    72  	path := findDBConf("")
    73  	assert.Equal(t, confPath, path)
    74  }
    75  
    76  func TestNewDBConf(t *testing.T) {
    77  	confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "dbstuff")
    78  	defer clean()
    79  
    80  	err := ioutil.WriteFile(confPath,
    81  		[]byte(`
    82  myenv:
    83  	driver: mysql
    84  	open: foo
    85  	migrationsDir: dbstuff
    86  `),
    87  		0700)
    88  	require.NoError(t, err)
    89  
    90  	dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv")
    91  	require.NoError(t, err)
    92  
    93  	assert.Equal(t, migrationsDir, dbconf.MigrationsDir)
    94  	assert.Equal(t, "mysql", dbconf.Driver.Name)
    95  	assert.Equal(t, "foo", dbconf.Driver.OpenStr)
    96  }
    97  
    98  func TestNewDBConf_default(t *testing.T) {
    99  	// Since the default uses env vars, and also no environment, this tests
   100  	// these 2 additional configurations as well.
   101  	defer os.Setenv("DB_MIGRATIONS_DIR", os.Getenv("DB_MIGRATIONS_DIR"))
   102  	os.Setenv("DB_MIGRATIONS_DIR", "/migdir")
   103  	defer os.Setenv("DB_DRIVER", os.Getenv("DB_DRIVER"))
   104  	os.Setenv("DB_DRIVER", "mysqlite3")
   105  	defer os.Setenv("DB_DRIVER_IMPORT", os.Getenv("DB_DRIVER_IMPORT"))
   106  	os.Setenv("DB_DRIVER_IMPORT", "github.com/myfork/sqlite3")
   107  	defer os.Setenv("DB_DIALECT", os.Getenv("DB_DIALECT"))
   108  	os.Setenv("DB_DIALECT", "sqlite3")
   109  	defer os.Setenv("DB_DSN", os.Getenv("DB_DSN"))
   110  	os.Setenv("DB_DSN", "foo")
   111  
   112  	dbconf, err := NewDBConf("", "")
   113  	require.NoError(t, err)
   114  
   115  	assert.Equal(t, "/migdir", dbconf.MigrationsDir)
   116  	assert.Equal(t, "mysqlite3", dbconf.Driver.Name)
   117  	assert.Equal(t, "github.com/myfork/sqlite3", dbconf.Driver.Import)
   118  	assert.Equal(t, &Sqlite3Dialect{}, dbconf.Driver.Dialect)
   119  	assert.Equal(t, "foo", dbconf.Driver.OpenStr)
   120  }
   121  
   122  func TestNewDBConf_driverImport(t *testing.T) {
   123  	confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "migrations")
   124  	defer clean()
   125  
   126  	err := ioutil.WriteFile(confPath,
   127  		[]byte(`
   128  myenv:
   129  	driver: github.com/myfork/mysql
   130  	open: foo
   131  `),
   132  		0700)
   133  	require.NoError(t, err)
   134  
   135  	dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv")
   136  	require.NoError(t, err)
   137  
   138  	assert.Equal(t, migrationsDir, dbconf.MigrationsDir)
   139  	assert.Equal(t, "mysql", dbconf.Driver.Name)
   140  	assert.Equal(t, "github.com/myfork/mysql", dbconf.Driver.Import)
   141  	assert.Equal(t, &MySqlDialect{}, dbconf.Driver.Dialect)
   142  	assert.Equal(t, "foo", dbconf.Driver.OpenStr)
   143  }
   144  
   145  func TestNewDBConf_driverDefaults(t *testing.T) {
   146  	tests := []struct {
   147  		names  []string
   148  		driver DBDriver
   149  	}{
   150  		{
   151  			[]string{"postgres", "PoStGrEs"},
   152  			DBDriver{
   153  				Name:    "postgres",
   154  				Import:  "github.com/lib/pq",
   155  				Dialect: &PostgresDialect{},
   156  			},
   157  		},
   158  		{
   159  			[]string{"redshift"},
   160  			DBDriver{
   161  				Name:    "postgres",
   162  				Import:  "github.com/lib/pq",
   163  				Dialect: &RedshiftDialect{},
   164  			},
   165  		},
   166  		{
   167  			[]string{"mymysql"},
   168  			DBDriver{
   169  				Name:    "mymysql",
   170  				Import:  "github.com/ziutek/mymysql/godrv",
   171  				Dialect: &MySqlDialect{},
   172  			},
   173  		},
   174  		{
   175  			[]string{"mysql"},
   176  			DBDriver{
   177  				Name:    "mysql",
   178  				Import:  "github.com/go-sql-driver/mysql",
   179  				Dialect: &MySqlDialect{},
   180  			},
   181  		},
   182  		{
   183  			[]string{"sqlite3"},
   184  			DBDriver{
   185  				Name:    "sqlite3",
   186  				Import:  "github.com/mattn/go-sqlite3",
   187  				Dialect: &Sqlite3Dialect{},
   188  			},
   189  		},
   190  	}
   191  	for _, test := range tests {
   192  		for _, driverName := range test.names {
   193  			confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "migrations")
   194  			defer clean()
   195  
   196  			err := ioutil.WriteFile(confPath,
   197  				[]byte(`
   198  myenv:
   199  	driver: `+driverName+`
   200  `),
   201  				0700)
   202  			require.NoError(t, err)
   203  
   204  			dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv")
   205  			require.NoError(t, err)
   206  
   207  			assert.Equal(t, migrationsDir, dbconf.MigrationsDir)
   208  			assert.Equal(t, test.driver, dbconf.Driver)
   209  		}
   210  	}
   211  }
   212  
   213  func TestImportOverride(t *testing.T) {
   214  	dbconf, err := NewDBConf("../../_example", "customimport")
   215  	if err != nil {
   216  		t.Fatal(err)
   217  	}
   218  
   219  	got := dbconf.Driver.Import
   220  	want := "github.com/custom/driver"
   221  	if got != want {
   222  		t.Errorf("bad custom import. got %v want %v", got, want)
   223  	}
   224  }
   225  
   226  func TestDriverSetFromEnvironmentVariable(t *testing.T) {
   227  	databaseUrlEnvVariableKey := "DB_DRIVER"
   228  	databaseUrlEnvVariableVal := "sqlite3"
   229  	databaseOpenStringKey := "DATABASE_URL"
   230  	databaseOpenStringVal := "db.db"
   231  
   232  	os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal)
   233  	os.Setenv(databaseOpenStringKey, databaseOpenStringVal)
   234  
   235  	dbconf, err := NewDBConf("../../_example", "environment_variable_config")
   236  	if err != nil {
   237  		t.Fatal(err)
   238  	}
   239  
   240  	got := reflect.TypeOf(dbconf.Driver.Dialect)
   241  	want := reflect.TypeOf(&Sqlite3Dialect{})
   242  
   243  	if got != want {
   244  		t.Errorf("Not able to read the driver type from environment variable."+
   245  			"got %v want %v", got, want)
   246  	}
   247  
   248  	gotOpenString := dbconf.Driver.OpenStr
   249  	wantOpenString := databaseOpenStringVal
   250  
   251  	if gotOpenString != wantOpenString {
   252  		t.Errorf("Not able to read the open string from the environment."+
   253  			"got %v want %v", gotOpenString, wantOpenString)
   254  	}
   255  }