github.com/CloudCom/goose@v0.0.0-20151110184009-e03c3249c21b/lib/goose/dbconf_test.go (about) 1 package goose 2 3 import ( 4 "io/ioutil" 5 "os" 6 "path/filepath" 7 "reflect" 8 "strings" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 ) 14 15 func setupDBConf(t *testing.T, confPath string, extraPath string) (string, string, func()) { 16 td, err := ioutil.TempDir("", "goose-test") 17 require.NoError(t, err) 18 defer func() { 19 if t.Failed() { 20 os.RemoveAll(td) 21 } 22 }() 23 24 confPath = filepath.Join(strings.Split(confPath, "/")...) 25 confPath = filepath.Join(td, confPath) 26 confDir := filepath.Dir(confPath) 27 err = os.MkdirAll(confDir, 0700) 28 require.NoError(t, err) 29 30 err = ioutil.WriteFile(confPath, 31 []byte("\n"), 32 0600) 33 require.NoError(t, err) 34 35 extraDir := filepath.Join(strings.Split(extraPath, "/")...) 36 extraDir = filepath.Join(td, extraDir) 37 err = os.MkdirAll(extraDir, 0700) 38 require.NoError(t, err) 39 40 return confPath, extraDir, func() { os.RemoveAll(td) } 41 } 42 43 func TestFindDBConf_confDir(t *testing.T) { 44 confNames := []string{"db/dbconf.yaml", "db/dbconf.yml", "dbconf.yaml", "dbconf.yml"} 45 for _, confName := range confNames { 46 confPath, baseDir, clean := setupDBConf(t, confName, "") 47 defer clean() 48 49 path := findDBConf(baseDir) 50 assert.Equal(t, confPath, path) 51 } 52 } 53 54 func TestFindDBConf_deepDir(t *testing.T) { 55 confPath, deepDir, clean := setupDBConf(t, "db/dbconf.yaml", "a/b/c") 56 defer clean() 57 58 path := findDBConf(deepDir) 59 assert.Equal(t, confPath, path) 60 } 61 62 func TestFindDBConf_pwd(t *testing.T) { 63 confPath, deepDir, clean := setupDBConf(t, "dbconf.yaml", "a/b/c") 64 defer clean() 65 66 pwd, err := os.Getwd() 67 require.NoError(t, err) 68 defer os.Chdir(pwd) 69 err = os.Chdir(deepDir) 70 require.NoError(t, err) 71 72 path := findDBConf("") 73 assert.Equal(t, confPath, path) 74 } 75 76 func TestNewDBConf(t *testing.T) { 77 confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "dbstuff") 78 defer clean() 79 80 err := ioutil.WriteFile(confPath, 81 []byte(` 82 myenv: 83 driver: mysql 84 open: foo 85 migrationsDir: dbstuff 86 `), 87 0700) 88 require.NoError(t, err) 89 90 dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv") 91 require.NoError(t, err) 92 93 assert.Equal(t, migrationsDir, dbconf.MigrationsDir) 94 assert.Equal(t, "mysql", dbconf.Driver.Name) 95 assert.Equal(t, "foo", dbconf.Driver.OpenStr) 96 } 97 98 func TestNewDBConf_default(t *testing.T) { 99 // Since the default uses env vars, and also no environment, this tests 100 // these 2 additional configurations as well. 101 defer os.Setenv("DB_MIGRATIONS_DIR", os.Getenv("DB_MIGRATIONS_DIR")) 102 os.Setenv("DB_MIGRATIONS_DIR", "/migdir") 103 defer os.Setenv("DB_DRIVER", os.Getenv("DB_DRIVER")) 104 os.Setenv("DB_DRIVER", "mysqlite3") 105 defer os.Setenv("DB_DRIVER_IMPORT", os.Getenv("DB_DRIVER_IMPORT")) 106 os.Setenv("DB_DRIVER_IMPORT", "github.com/myfork/sqlite3") 107 defer os.Setenv("DB_DIALECT", os.Getenv("DB_DIALECT")) 108 os.Setenv("DB_DIALECT", "sqlite3") 109 defer os.Setenv("DB_DSN", os.Getenv("DB_DSN")) 110 os.Setenv("DB_DSN", "foo") 111 112 dbconf, err := NewDBConf("", "") 113 require.NoError(t, err) 114 115 assert.Equal(t, "/migdir", dbconf.MigrationsDir) 116 assert.Equal(t, "mysqlite3", dbconf.Driver.Name) 117 assert.Equal(t, "github.com/myfork/sqlite3", dbconf.Driver.Import) 118 assert.Equal(t, &Sqlite3Dialect{}, dbconf.Driver.Dialect) 119 assert.Equal(t, "foo", dbconf.Driver.OpenStr) 120 } 121 122 func TestNewDBConf_driverImport(t *testing.T) { 123 confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "migrations") 124 defer clean() 125 126 err := ioutil.WriteFile(confPath, 127 []byte(` 128 myenv: 129 driver: github.com/myfork/mysql 130 open: foo 131 `), 132 0700) 133 require.NoError(t, err) 134 135 dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv") 136 require.NoError(t, err) 137 138 assert.Equal(t, migrationsDir, dbconf.MigrationsDir) 139 assert.Equal(t, "mysql", dbconf.Driver.Name) 140 assert.Equal(t, "github.com/myfork/mysql", dbconf.Driver.Import) 141 assert.Equal(t, &MySqlDialect{}, dbconf.Driver.Dialect) 142 assert.Equal(t, "foo", dbconf.Driver.OpenStr) 143 } 144 145 func TestNewDBConf_driverDefaults(t *testing.T) { 146 tests := []struct { 147 names []string 148 driver DBDriver 149 }{ 150 { 151 []string{"postgres", "PoStGrEs"}, 152 DBDriver{ 153 Name: "postgres", 154 Import: "github.com/lib/pq", 155 Dialect: &PostgresDialect{}, 156 }, 157 }, 158 { 159 []string{"redshift"}, 160 DBDriver{ 161 Name: "postgres", 162 Import: "github.com/lib/pq", 163 Dialect: &RedshiftDialect{}, 164 }, 165 }, 166 { 167 []string{"mymysql"}, 168 DBDriver{ 169 Name: "mymysql", 170 Import: "github.com/ziutek/mymysql/godrv", 171 Dialect: &MySqlDialect{}, 172 }, 173 }, 174 { 175 []string{"mysql"}, 176 DBDriver{ 177 Name: "mysql", 178 Import: "github.com/go-sql-driver/mysql", 179 Dialect: &MySqlDialect{}, 180 }, 181 }, 182 { 183 []string{"sqlite3"}, 184 DBDriver{ 185 Name: "sqlite3", 186 Import: "github.com/mattn/go-sqlite3", 187 Dialect: &Sqlite3Dialect{}, 188 }, 189 }, 190 } 191 for _, test := range tests { 192 for _, driverName := range test.names { 193 confPath, migrationsDir, clean := setupDBConf(t, "dbconf.yaml", "migrations") 194 defer clean() 195 196 err := ioutil.WriteFile(confPath, 197 []byte(` 198 myenv: 199 driver: `+driverName+` 200 `), 201 0700) 202 require.NoError(t, err) 203 204 dbconf, err := NewDBConf(filepath.Dir(confPath), "myenv") 205 require.NoError(t, err) 206 207 assert.Equal(t, migrationsDir, dbconf.MigrationsDir) 208 assert.Equal(t, test.driver, dbconf.Driver) 209 } 210 } 211 } 212 213 func TestImportOverride(t *testing.T) { 214 dbconf, err := NewDBConf("../../_example", "customimport") 215 if err != nil { 216 t.Fatal(err) 217 } 218 219 got := dbconf.Driver.Import 220 want := "github.com/custom/driver" 221 if got != want { 222 t.Errorf("bad custom import. got %v want %v", got, want) 223 } 224 } 225 226 func TestDriverSetFromEnvironmentVariable(t *testing.T) { 227 databaseUrlEnvVariableKey := "DB_DRIVER" 228 databaseUrlEnvVariableVal := "sqlite3" 229 databaseOpenStringKey := "DATABASE_URL" 230 databaseOpenStringVal := "db.db" 231 232 os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal) 233 os.Setenv(databaseOpenStringKey, databaseOpenStringVal) 234 235 dbconf, err := NewDBConf("../../_example", "environment_variable_config") 236 if err != nil { 237 t.Fatal(err) 238 } 239 240 got := reflect.TypeOf(dbconf.Driver.Dialect) 241 want := reflect.TypeOf(&Sqlite3Dialect{}) 242 243 if got != want { 244 t.Errorf("Not able to read the driver type from environment variable."+ 245 "got %v want %v", got, want) 246 } 247 248 gotOpenString := dbconf.Driver.OpenStr 249 wantOpenString := databaseOpenStringVal 250 251 if gotOpenString != wantOpenString { 252 t.Errorf("Not able to read the open string from the environment."+ 253 "got %v want %v", gotOpenString, wantOpenString) 254 } 255 }