github.com/gogriddy/goose@v0.0.0-20180817174216-2c751e0981c8/lib/goose/dbconf.go (about) 1 package goose 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "net/url" 8 "os" 9 "path/filepath" 10 "strings" 11 12 "github.com/kylelemons/go-gypsy/yaml" 13 ) 14 15 // DBDriver encapsulates the info needed to work with 16 // a specific database driver 17 type DBDriver struct { 18 Name string 19 OpenStr string 20 Import string 21 Dialect SqlDialect 22 } 23 24 type DBConf struct { 25 MigrationsDir string 26 Driver DBDriver 27 } 28 29 var defaultDBConfYaml = ` 30 migrationsDir: $DB_MIGRATIONS_DIR 31 driver: $DB_DRIVER 32 import: $DB_DRIVER_IMPORT 33 dialect: $DB_DIALECT 34 open: $DB_DSN 35 ` 36 37 // findDBConf looks for a dbconf.yaml file starting at the given directory and 38 // walking up in the directory hierarchy. 39 // Returns empty string if not found. 40 func findDBConf(dbDir string) string { 41 dbDir, err := filepath.Abs(dbDir) 42 if err != nil { 43 return "" 44 } 45 46 for { 47 paths := []string{ 48 "dbconf.yaml", 49 "dbconf.yml", 50 filepath.Join("db", "dbconf.yaml"), 51 filepath.Join("db", "dbconf.yml"), 52 } 53 54 for _, path := range paths { 55 path = filepath.Join(dbDir, path) 56 if _, err := os.Stat(path); err == nil { 57 return path 58 } 59 } 60 61 nextDir := filepath.Dir(dbDir) 62 if nextDir == dbDir { 63 // at the root 64 break 65 } 66 dbDir = nextDir 67 } 68 69 return "" 70 } 71 72 func confGet(f *yaml.File, env string, name string) (string, error) { 73 if env != "" { 74 if v, err := f.Get(fmt.Sprintf("%s.%s", env, name)); err == nil { 75 return os.ExpandEnv(v), nil 76 } 77 } 78 v, err := f.Get(name) 79 if err != nil { 80 return "", err 81 } 82 return os.ExpandEnv(v), nil 83 } 84 85 // extract configuration details from the given file 86 func NewDBConf(dbDir, env string) (*DBConf, error) { 87 cfgFile := findDBConf(dbDir) 88 var f *yaml.File 89 if cfgFile == "" { 90 root, _ := yaml.Parse(strings.NewReader(defaultDBConfYaml)) 91 f = &yaml.File{ 92 Root: root, 93 } 94 } else { 95 dbDir = filepath.Dir(cfgFile) 96 97 var err error 98 f, err = yaml.ReadFile(cfgFile) 99 if err != nil { 100 return nil, fmt.Errorf("error loading config file: %s", err) 101 } 102 } 103 104 migrationsDir := filepath.Join(dbDir, "migrations") 105 if md, err := confGet(f, env, "migrationsDir"); err == nil { 106 if filepath.IsAbs(md) { 107 migrationsDir = md 108 } else { 109 migrationsDir = filepath.Join(dbDir, md) 110 } 111 } 112 113 drv, err := confGet(f, env, "driver") 114 if err != nil { 115 return nil, err 116 } 117 var imprt string 118 // see if "driver" param is a full import path 119 if i := strings.LastIndex(drv, "/"); i != -1 { 120 imprt = drv 121 drv = imprt[i+1:] 122 } 123 124 open, _ := confGet(f, env, "open") 125 126 d := newDBDriver(drv, open) 127 128 if imprt != "" { 129 d.Import = imprt 130 } 131 // allow the configuration to override the Import for this driver 132 if imprt, err := confGet(f, env, "import"); err == nil && imprt != "" { 133 d.Import = imprt 134 } 135 136 // allow the configuration to override the Dialect for this driver 137 if dialect, err := confGet(f, env, "dialect"); err == nil && dialect != "" { 138 d.Dialect = dialectByName(dialect) 139 } 140 141 if !d.IsValid() { 142 return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d)) 143 } 144 145 return &DBConf{ 146 MigrationsDir: migrationsDir, 147 Driver: d, 148 }, nil 149 } 150 151 // Create a new DBDriver and populate driver specific 152 // fields for drivers that we know about. 153 // Further customization may be done in NewDBConf 154 func newDBDriver(name, open string) DBDriver { 155 d := DBDriver{ 156 Name: name, 157 OpenStr: open, 158 } 159 160 switch strings.ToLower(name) { 161 case "postgres": 162 d.Name = "postgres" 163 d.Import = "github.com/lib/pq" 164 d.Dialect = &PostgresDialect{} 165 166 case "redshift": 167 d.Name = "postgres" 168 d.Import = "github.com/lib/pq" 169 d.Dialect = &RedshiftDialect{} 170 171 case "mymysql": 172 d.Import = "github.com/ziutek/mymysql/godrv" 173 d.Dialect = &MySqlDialect{} 174 175 case "mysql": 176 d.Import = "github.com/go-sql-driver/mysql" 177 d.Dialect = &MySqlDialect{} 178 179 case "sqlite3": 180 d.Name = "sqlite3" 181 d.Import = "github.com/mattn/go-sqlite3" 182 d.Dialect = &Sqlite3Dialect{} 183 } 184 185 return d 186 } 187 188 // ensure we have enough info about this driver 189 func (drv *DBDriver) IsValid() bool { 190 return len(drv.Import) > 0 && drv.Dialect != nil 191 } 192 193 // OpenDBFromDBConf wraps database/sql.DB.Open() and configures 194 // the newly opened DB based on the given DBConf. 195 // 196 // Callers must Close() the returned DB. 197 func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) { 198 // we depend on time parsing, so make sure it's enabled with the mysql driver 199 if conf.Driver.Name == "mysql" { 200 i := strings.Index(conf.Driver.OpenStr, "?") 201 if i == -1 { 202 i = len(conf.Driver.OpenStr) 203 conf.Driver.OpenStr = conf.Driver.OpenStr + "?" 204 } 205 i++ 206 207 q, err := url.ParseQuery(conf.Driver.OpenStr[i:]) 208 if err != nil { 209 return nil, err 210 } 211 q.Set("parseTime", "true") 212 213 conf.Driver.OpenStr = conf.Driver.OpenStr[:i] + q.Encode() 214 } 215 216 return sql.Open(conf.Driver.Name, conf.Driver.OpenStr) 217 }