github.com/CloudCom/goose@v0.0.0-20151110184009-e03c3249c21b/lib/goose/dbconf.go (about)

     1  package goose
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"net/url"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  
    12  	"github.com/kylelemons/go-gypsy/yaml"
    13  )
    14  
// DBDriver encapsulates the info needed to work with
// a specific database driver
type DBDriver struct {
	// Name is the driver name handed to sql.Open.
	Name string
	// OpenStr is the data source name handed to sql.Open.
	OpenStr string
	// Import is the Go import path of the driver package; it must be
	// non-empty for the driver to be considered valid.
	Import string
	// Dialect is the SQL dialect used with this driver; it must be
	// non-nil for the driver to be considered valid.
	Dialect SqlDialect
}
    23  
// DBConf is the resolved database configuration produced by NewDBConf.
type DBConf struct {
	// MigrationsDir is the directory holding the migrations. NewDBConf
	// stores either an absolute path taken from the configuration or a
	// path joined onto the configuration directory.
	MigrationsDir string
	// Driver describes the database driver to open.
	Driver DBDriver
}
    28  
// defaultDBConfYaml is the configuration NewDBConf falls back to when no
// dbconf file is found. Every value is an environment variable reference,
// which confGet expands with os.ExpandEnv when the key is read.
var defaultDBConfYaml = `
migrationsDir: $DB_MIGRATIONS_DIR
driver: $DB_DRIVER
import: $DB_DRIVER_IMPORT
dialect: $DB_DIALECT
open: $DB_DSN
`
    36  
    37  // findDBConf looks for a dbconf.yaml file starting at the given directory and
    38  // walking up in the directory hierarchy.
    39  // Returns empty string if not found.
    40  func findDBConf(dbDir string) string {
    41  	dbDir, err := filepath.Abs(dbDir)
    42  	if err != nil {
    43  		return ""
    44  	}
    45  
    46  	for {
    47  		paths := []string{
    48  			"dbconf.yaml",
    49  			"dbconf.yml",
    50  			filepath.Join("db", "dbconf.yaml"),
    51  			filepath.Join("db", "dbconf.yml"),
    52  		}
    53  
    54  		for _, path := range paths {
    55  			path = filepath.Join(dbDir, path)
    56  			if _, err := os.Stat(path); err == nil {
    57  				return path
    58  			}
    59  		}
    60  
    61  		nextDir := filepath.Dir(dbDir)
    62  		if nextDir == dbDir {
    63  			// at the root
    64  			break
    65  		}
    66  		dbDir = nextDir
    67  	}
    68  
    69  	return ""
    70  }
    71  
    72  func confGet(f *yaml.File, env string, name string) (string, error) {
    73  	if env != "" {
    74  		if v, err := f.Get(fmt.Sprintf("%s.%s", env, name)); err == nil {
    75  			return os.ExpandEnv(v), nil
    76  		}
    77  	}
    78  	v, err := f.Get(name)
    79  	if err != nil {
    80  		return "", err
    81  	}
    82  	return os.ExpandEnv(v), nil
    83  }
    84  
    85  // extract configuration details from the given file
    86  func NewDBConf(dbDir, env string) (*DBConf, error) {
    87  	cfgFile := findDBConf(dbDir)
    88  	var f *yaml.File
    89  	if cfgFile == "" {
    90  		root, _ := yaml.Parse(strings.NewReader(defaultDBConfYaml))
    91  		f = &yaml.File{
    92  			Root: root,
    93  		}
    94  	} else {
    95  		dbDir = filepath.Dir(cfgFile)
    96  
    97  		var err error
    98  		f, err = yaml.ReadFile(cfgFile)
    99  		if err != nil {
   100  			return nil, fmt.Errorf("error loading config file: %s", err)
   101  		}
   102  	}
   103  
   104  	migrationsDir := filepath.Join(dbDir, "migrations")
   105  	if md, err := confGet(f, env, "migrationsDir"); err == nil {
   106  		if filepath.IsAbs(md) {
   107  			migrationsDir = md
   108  		} else {
   109  			migrationsDir = filepath.Join(dbDir, md)
   110  		}
   111  	}
   112  
   113  	drv, err := confGet(f, env, "driver")
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	var imprt string
   118  	// see if "driver" param is a full import path
   119  	if i := strings.LastIndex(drv, "/"); i != -1 {
   120  		imprt = drv
   121  		drv = imprt[i+1:]
   122  	}
   123  
   124  	open, _ := confGet(f, env, "open")
   125  
   126  	d := newDBDriver(drv, open)
   127  
   128  	if imprt != "" {
   129  		d.Import = imprt
   130  	}
   131  	// allow the configuration to override the Import for this driver
   132  	if imprt, err := confGet(f, env, "import"); err == nil && imprt != "" {
   133  		d.Import = imprt
   134  	}
   135  
   136  	// allow the configuration to override the Dialect for this driver
   137  	if dialect, err := confGet(f, env, "dialect"); err == nil && dialect != "" {
   138  		d.Dialect = dialectByName(dialect)
   139  	}
   140  
   141  	if !d.IsValid() {
   142  		return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
   143  	}
   144  
   145  	return &DBConf{
   146  		MigrationsDir: migrationsDir,
   147  		Driver:        d,
   148  	}, nil
   149  }
   150  
   151  // Create a new DBDriver and populate driver specific
   152  // fields for drivers that we know about.
   153  // Further customization may be done in NewDBConf
   154  func newDBDriver(name, open string) DBDriver {
   155  	d := DBDriver{
   156  		Name:    name,
   157  		OpenStr: open,
   158  	}
   159  
   160  	switch strings.ToLower(name) {
   161  	case "postgres":
   162  		d.Name = "postgres"
   163  		d.Import = "github.com/lib/pq"
   164  		d.Dialect = &PostgresDialect{}
   165  
   166  	case "redshift":
   167  		d.Name = "postgres"
   168  		d.Import = "github.com/lib/pq"
   169  		d.Dialect = &RedshiftDialect{}
   170  
   171  	case "mymysql":
   172  		d.Import = "github.com/ziutek/mymysql/godrv"
   173  		d.Dialect = &MySqlDialect{}
   174  
   175  	case "mysql":
   176  		d.Import = "github.com/go-sql-driver/mysql"
   177  		d.Dialect = &MySqlDialect{}
   178  
   179  	case "sqlite3":
   180  		d.Name = "sqlite3"
   181  		d.Import = "github.com/mattn/go-sqlite3"
   182  		d.Dialect = &Sqlite3Dialect{}
   183  	}
   184  
   185  	return d
   186  }
   187  
   188  // ensure we have enough info about this driver
   189  func (drv *DBDriver) IsValid() bool {
   190  	return len(drv.Import) > 0 && drv.Dialect != nil
   191  }
   192  
   193  // OpenDBFromDBConf wraps database/sql.DB.Open() and configures
   194  // the newly opened DB based on the given DBConf.
   195  //
   196  // Callers must Close() the returned DB.
   197  func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) {
   198  	// we depend on time parsing, so make sure it's enabled with the mysql driver
   199  	if conf.Driver.Name == "mysql" {
   200  		i := strings.Index(conf.Driver.OpenStr, "?")
   201  		if i == -1 {
   202  			i = len(conf.Driver.OpenStr)
   203  			conf.Driver.OpenStr = conf.Driver.OpenStr + "?"
   204  		}
   205  		i++
   206  
   207  		q, err := url.ParseQuery(conf.Driver.OpenStr[i:])
   208  		if err != nil {
   209  			return nil, err
   210  		}
   211  		q.Set("parseTime", "true")
   212  
   213  		conf.Driver.OpenStr = conf.Driver.OpenStr[:i] + q.Encode()
   214  	}
   215  
   216  	return sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
   217  }