github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/postgres/postgres.go (about) 1 package postgres 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "io" 8 "net/url" 9 "runtime" 10 "strings" 11 12 "github.com/amacneil/dbmate/pkg/dbmate" 13 "github.com/amacneil/dbmate/pkg/dbutil" 14 15 "github.com/lib/pq" 16 ) 17 18 func init() { 19 dbmate.RegisterDriver(NewDriver, "postgres") 20 dbmate.RegisterDriver(NewDriver, "postgresql") 21 } 22 23 // Driver provides top level database functions 24 type Driver struct { 25 migrationsTableName string 26 databaseURL *url.URL 27 log io.Writer 28 } 29 30 // NewDriver initializes the driver 31 func NewDriver(config dbmate.DriverConfig) dbmate.Driver { 32 return &Driver{ 33 migrationsTableName: config.MigrationsTableName, 34 databaseURL: config.DatabaseURL, 35 log: config.Log, 36 } 37 } 38 39 func connectionString(u *url.URL) string { 40 hostname := u.Hostname() 41 port := u.Port() 42 query := u.Query() 43 44 // support socket parameter for consistency with mysql 45 if query.Get("socket") != "" { 46 query.Set("host", query.Get("socket")) 47 query.Del("socket") 48 } 49 50 // default hostname 51 if hostname == "" && query.Get("host") == "" { 52 switch runtime.GOOS { 53 case "linux": 54 query.Set("host", "/var/run/postgresql") 55 case "darwin", "freebsd", "dragonfly", "openbsd", "netbsd": 56 query.Set("host", "/tmp") 57 default: 58 hostname = "localhost" 59 } 60 } 61 62 // host param overrides url hostname 63 if query.Get("host") != "" { 64 hostname = "" 65 } 66 67 // always specify a port 68 if query.Get("port") != "" { 69 port = query.Get("port") 70 query.Del("port") 71 } 72 if port == "" { 73 port = "5432" 74 } 75 76 // generate output URL 77 out, _ := url.Parse(u.String()) 78 out.Host = fmt.Sprintf("%s:%s", hostname, port) 79 out.RawQuery = query.Encode() 80 81 return out.String() 82 } 83 84 func connectionArgsForDump(u *url.URL) []string { 85 u = dbutil.MustParseURL(connectionString(u)) 86 87 // find schemas from search_path 88 query := u.Query() 89 schemas := strings.Split(query.Get("search_path"), ",") 90 query.Del("search_path") 91 u.RawQuery = query.Encode() 92 93 out := []string{} 94 for _, schema := range schemas { 95 schema = strings.TrimSpace(schema) 96 if schema != "" { 97 out = append(out, "--schema", schema) 98 } 99 } 100 out = append(out, u.String()) 101 102 return out 103 } 104 105 // Open creates a new database connection 106 func (drv *Driver) Open() (*sql.DB, error) { 107 return sql.Open("postgres", connectionString(drv.databaseURL)) 108 } 109 110 func (drv *Driver) openPostgresDB() (*sql.DB, error) { 111 // clone databaseURL 112 postgresURL, err := url.Parse(connectionString(drv.databaseURL)) 113 if err != nil { 114 return nil, err 115 } 116 117 // connect to postgres database 118 postgresURL.Path = "postgres" 119 120 return sql.Open("postgres", postgresURL.String()) 121 } 122 123 // CreateDatabase creates the specified database 124 func (drv *Driver) CreateDatabase() error { 125 name := dbutil.DatabaseName(drv.databaseURL) 126 fmt.Fprintf(drv.log, "Creating: %s\n", name) 127 128 db, err := drv.openPostgresDB() 129 if err != nil { 130 return err 131 } 132 defer dbutil.MustClose(db) 133 134 _, err = db.Exec(fmt.Sprintf("create database %s", 135 pq.QuoteIdentifier(name))) 136 137 return err 138 } 139 140 // DropDatabase drops the specified database (if it exists) 141 func (drv *Driver) DropDatabase() error { 142 name := dbutil.DatabaseName(drv.databaseURL) 143 fmt.Fprintf(drv.log, "Dropping: %s\n", name) 144 145 db, err := drv.openPostgresDB() 146 if err != nil { 147 return err 148 } 149 defer dbutil.MustClose(db) 150 151 _, err = db.Exec(fmt.Sprintf("drop database if exists %s", 152 pq.QuoteIdentifier(name))) 153 154 return err 155 } 156 157 func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) { 158 migrationsTable, err := drv.quotedMigrationsTableName(db) 159 if err != nil { 160 return nil, err 161 } 162 163 // load applied migrations 164 migrations, err := dbutil.QueryColumn(db, 165 "select quote_literal(version) from "+migrationsTable+" order by version asc") 166 if err != nil { 167 return nil, err 168 } 169 170 // build migrations table data 171 var buf bytes.Buffer 172 buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n") 173 174 if len(migrations) > 0 { 175 buf.WriteString("INSERT INTO " + migrationsTable + " (version) VALUES\n (" + 176 strings.Join(migrations, "),\n (") + 177 ");\n") 178 } 179 180 return buf.Bytes(), nil 181 } 182 183 // DumpSchema returns the current database schema 184 func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) { 185 // load schema 186 args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only", 187 "--no-privileges", "--no-owner"}, connectionArgsForDump(drv.databaseURL)...) 188 schema, err := dbutil.RunCommand("pg_dump", args...) 189 if err != nil { 190 return nil, err 191 } 192 193 migrations, err := drv.schemaMigrationsDump(db) 194 if err != nil { 195 return nil, err 196 } 197 198 schema = append(schema, migrations...) 199 return dbutil.TrimLeadingSQLComments(schema) 200 } 201 202 // DatabaseExists determines whether the database exists 203 func (drv *Driver) DatabaseExists() (bool, error) { 204 name := dbutil.DatabaseName(drv.databaseURL) 205 206 db, err := drv.openPostgresDB() 207 if err != nil { 208 return false, err 209 } 210 defer dbutil.MustClose(db) 211 212 exists := false 213 err = db.QueryRow("select true from pg_database where datname = $1", name). 214 Scan(&exists) 215 if err == sql.ErrNoRows { 216 return false, nil 217 } 218 219 return exists, err 220 } 221 222 // MigrationsTableExists checks if the schema_migrations table exists 223 func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) { 224 schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db) 225 if err != nil { 226 return false, err 227 } 228 229 exists := false 230 err = db.QueryRow("SELECT 1 FROM information_schema.tables "+ 231 "WHERE table_schema = $1 "+ 232 "AND table_name = $2", 233 schema, migrationsTable). 234 Scan(&exists) 235 if err == sql.ErrNoRows { 236 return false, nil 237 } 238 239 return exists, err 240 } 241 242 // CreateMigrationsTable creates the schema_migrations table 243 func (drv *Driver) CreateMigrationsTable(db *sql.DB) error { 244 schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db) 245 if err != nil { 246 return err 247 } 248 249 // first attempt at creating migrations table 250 createTableStmt := fmt.Sprintf( 251 "create table if not exists %s.%s (version varchar(128) primary key)", 252 schema, migrationsTable) 253 _, err = db.Exec(createTableStmt) 254 if err == nil { 255 // table exists or created successfully 256 return nil 257 } 258 259 // catch 'schema does not exist' error 260 pqErr, ok := err.(*pq.Error) 261 if !ok || pqErr.Code != "3F000" { 262 // unknown error 263 return err 264 } 265 266 // in theory we could attempt to create the schema every time, but we avoid that 267 // in case the user doesn't have permissions to create schemas 268 fmt.Fprintf(drv.log, "Creating schema: %s\n", schema) 269 _, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema)) 270 if err != nil { 271 return err 272 } 273 274 // second and final attempt at creating migrations table 275 _, err = db.Exec(createTableStmt) 276 return err 277 } 278 279 // SelectMigrations returns a list of applied migrations 280 // with an optional limit (in descending order) 281 func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { 282 migrationsTable, err := drv.quotedMigrationsTableName(db) 283 if err != nil { 284 return nil, err 285 } 286 287 query := "select version from " + migrationsTable + " order by version desc" 288 if limit >= 0 { 289 query = fmt.Sprintf("%s limit %d", query, limit) 290 } 291 rows, err := db.Query(query) 292 if err != nil { 293 return nil, err 294 } 295 296 defer dbutil.MustClose(rows) 297 298 migrations := map[string]bool{} 299 for rows.Next() { 300 var version string 301 if err := rows.Scan(&version); err != nil { 302 return nil, err 303 } 304 305 migrations[version] = true 306 } 307 308 if err = rows.Err(); err != nil { 309 return nil, err 310 } 311 312 return migrations, nil 313 } 314 315 // InsertMigration adds a new migration record 316 func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error { 317 migrationsTable, err := drv.quotedMigrationsTableName(db) 318 if err != nil { 319 return err 320 } 321 322 _, err = db.Exec("insert into "+migrationsTable+" (version) values ($1)", version) 323 324 return err 325 } 326 327 // DeleteMigration removes a migration record 328 func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error { 329 migrationsTable, err := drv.quotedMigrationsTableName(db) 330 if err != nil { 331 return err 332 } 333 334 _, err = db.Exec("delete from "+migrationsTable+" where version = $1", version) 335 336 return err 337 } 338 339 // Ping verifies a connection to the database server. It does not verify whether the 340 // specified database exists. 341 func (drv *Driver) Ping() error { 342 // attempt connection to primary database, not "postgres" database 343 // to support servers with no "postgres" database 344 // (see https://github.com/amacneil/dbmate/issues/78) 345 db, err := drv.Open() 346 if err != nil { 347 return err 348 } 349 defer dbutil.MustClose(db) 350 351 err = db.Ping() 352 if err == nil { 353 return nil 354 } 355 356 // ignore 'database does not exist' error 357 pqErr, ok := err.(*pq.Error) 358 if ok && pqErr.Code == "3D000" { 359 return nil 360 } 361 362 return err 363 } 364 365 func (drv *Driver) quotedMigrationsTableName(db dbutil.Transaction) (string, error) { 366 schema, name, err := drv.quotedMigrationsTableNameParts(db) 367 if err != nil { 368 return "", err 369 } 370 371 return schema + "." + name, nil 372 } 373 374 func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string, string, error) { 375 schema := "" 376 tableNameParts := strings.Split(drv.migrationsTableName, ".") 377 if len(tableNameParts) > 1 { 378 // schema specified as part of table name 379 schema, tableNameParts = tableNameParts[0], tableNameParts[1:] 380 } 381 382 if schema == "" { 383 // no schema specified with table name, try URL search path if available 384 searchPath := strings.Split(drv.databaseURL.Query().Get("search_path"), ",") 385 schema = strings.TrimSpace(searchPath[0]) 386 } 387 388 var err error 389 if schema == "" { 390 // if no URL available, use current schema 391 // this is a hack because we don't always have the URL context available 392 schema, err = dbutil.QueryValue(db, "select current_schema()") 393 if err != nil { 394 return "", "", err 395 } 396 } 397 398 // fall back to public schema as last resort 399 if schema == "" { 400 schema = "public" 401 } 402 403 // quote all parts 404 // use server rather than client to do this to avoid unnecessary quotes 405 // (which would change schema.sql diff) 406 tableNameParts = append([]string{schema}, tableNameParts...) 407 quotedNameParts, err := dbutil.QueryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts)) 408 if err != nil { 409 return "", "", err 410 } 411 412 // if more than one part, we already have a schema 413 return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil 414 }