github.com/Odesyuk/pop@v4.13.1+incompatible/dialect_postgresql.go (about) 1 package pop 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "os/exec" 8 "strings" 9 "sync" 10 "unicode" 11 12 "github.com/gobuffalo/fizz" 13 "github.com/gobuffalo/fizz/translators" 14 "github.com/gobuffalo/pop/columns" 15 "github.com/gobuffalo/pop/internal/defaults" 16 "github.com/gobuffalo/pop/logging" 17 "github.com/jmoiron/sqlx" 18 pg "github.com/lib/pq" 19 "github.com/pkg/errors" 20 ) 21 22 const namePostgreSQL = "postgres" 23 const portPostgreSQL = "5432" 24 25 func init() { 26 AvailableDialects = append(AvailableDialects, namePostgreSQL) 27 dialectSynonyms["postgresql"] = namePostgreSQL 28 dialectSynonyms["pg"] = namePostgreSQL 29 urlParser[namePostgreSQL] = urlParserPostgreSQL 30 finalizer[namePostgreSQL] = finalizerPostgreSQL 31 newConnection[namePostgreSQL] = newPostgreSQL 32 } 33 34 var _ dialect = &postgresql{} 35 36 type postgresql struct { 37 commonDialect 38 translateCache map[string]string 39 mu sync.Mutex 40 } 41 42 func (p *postgresql) Name() string { 43 return namePostgreSQL 44 } 45 46 func (p *postgresql) Details() *ConnectionDetails { 47 return p.ConnectionDetails 48 } 49 50 func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error { 51 keyType := model.PrimaryKeyType() 52 switch keyType { 53 case "int", "int64": 54 cols.Remove("id") 55 id := struct { 56 ID int `db:"id"` 57 }{} 58 w := cols.Writeable() 59 var query string 60 if len(w.Cols) > 0 { 61 query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString()) 62 } else { 63 query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName())) 64 } 65 log(logging.SQL, query) 66 stmt, err := s.PrepareNamed(query) 67 if err != nil { 68 return err 69 } 70 err = stmt.Get(&id, model.Value) 71 if err != nil { 72 if err := stmt.Close(); err != nil { 73 return errors.WithMessage(err, "failed to close statement") 74 } 75 return err 76 } 77 model.setID(id.ID) 78 return errors.WithMessage(stmt.Close(), "failed to close statement") 79 } 80 return genericCreate(s, model, cols, p) 81 } 82 83 func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error { 84 return genericUpdate(s, model, cols, p) 85 } 86 87 func (p *postgresql) Destroy(s store, model *Model) error { 88 stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID())) 89 _, err := genericExec(s, stmt, model.ID()) 90 if err != nil { 91 return err 92 } 93 return nil 94 } 95 96 func (p *postgresql) SelectOne(s store, model *Model, query Query) error { 97 return genericSelectOne(s, model, query) 98 } 99 100 func (p *postgresql) SelectMany(s store, models *Model, query Query) error { 101 return genericSelectMany(s, models, query) 102 } 103 104 func (p *postgresql) CreateDB() error { 105 // createdb -h db -p 5432 -U postgres enterprise_development 106 deets := p.ConnectionDetails 107 db, err := sql.Open(deets.Dialect, p.urlWithoutDb()) 108 if err != nil { 109 return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database) 110 } 111 defer db.Close() 112 query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database)) 113 log(logging.SQL, query) 114 115 _, err = db.Exec(query) 116 if err != nil { 117 return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database) 118 } 119 120 log(logging.Info, "created database %s", deets.Database) 121 return nil 122 } 123 124 func (p *postgresql) DropDB() error { 125 deets := p.ConnectionDetails 126 db, err := sql.Open(deets.Dialect, p.urlWithoutDb()) 127 if err != nil { 128 return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database) 129 } 130 defer db.Close() 131 query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database)) 132 log(logging.SQL, query) 133 134 _, err = db.Exec(query) 135 if err != nil { 136 return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database) 137 } 138 139 log(logging.Info, "dropped database %s", deets.Database) 140 return nil 141 } 142 143 func (p *postgresql) URL() string { 144 c := p.ConnectionDetails 145 if c.URL != "" { 146 return c.URL 147 } 148 s := "postgres://%s:%s@%s:%s/%s?%s" 149 return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString("")) 150 } 151 152 func (p *postgresql) urlWithoutDb() string { 153 c := p.ConnectionDetails 154 // https://github.com/gobuffalo/buffalo/issues/836 155 // If the db is not precised, postgresql takes the username as the database to connect on. 156 // To avoid a connection problem if the user db is not here, we use the default "postgres" 157 // db, just like the other client tools do. 158 s := "postgres://%s:%s@%s:%s/postgres?%s" 159 return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString("")) 160 } 161 162 func (p *postgresql) MigrationURL() string { 163 return p.URL() 164 } 165 166 func (p *postgresql) TranslateSQL(sql string) string { 167 defer p.mu.Unlock() 168 p.mu.Lock() 169 170 if csql, ok := p.translateCache[sql]; ok { 171 return csql 172 } 173 csql := sqlx.Rebind(sqlx.DOLLAR, sql) 174 175 p.translateCache[sql] = csql 176 return csql 177 } 178 179 func (p *postgresql) FizzTranslator() fizz.Translator { 180 return translators.NewPostgres() 181 } 182 183 func (p *postgresql) DumpSchema(w io.Writer) error { 184 cmd := exec.Command("pg_dump", "-s", fmt.Sprintf("--dbname=%s", p.URL())) 185 return genericDumpSchema(p.Details(), cmd, w) 186 } 187 188 // LoadSchema executes a schema sql file against the configured database. 189 func (p *postgresql) LoadSchema(r io.Reader) error { 190 return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r) 191 } 192 193 // TruncateAll truncates all tables for the given connection. 194 func (p *postgresql) TruncateAll(tx *Connection) error { 195 return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec() 196 } 197 198 func newPostgreSQL(deets *ConnectionDetails) (dialect, error) { 199 cd := &postgresql{ 200 commonDialect: commonDialect{ConnectionDetails: deets}, 201 translateCache: map[string]string{}, 202 mu: sync.Mutex{}, 203 } 204 return cd, nil 205 } 206 207 // urlParserPostgreSQL parses the options the same way official lib/pg does: 208 // https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters 209 // After parsed, they are set to ConnectionDetails instance 210 func urlParserPostgreSQL(cd *ConnectionDetails) error { 211 var err error 212 name := cd.URL 213 if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { 214 name, err = pg.ParseURL(name) 215 if err != nil { 216 return err 217 } 218 } 219 220 o := make(values) 221 if err := parseOpts(name, o); err != nil { 222 return err 223 } 224 225 if dbname, ok := o["dbname"]; ok { 226 cd.Database = dbname 227 } 228 if host, ok := o["host"]; ok { 229 cd.Host = host 230 } 231 if password, ok := o["password"]; ok { 232 cd.Password = password 233 } 234 if user, ok := o["user"]; ok { 235 cd.User = user 236 } 237 if port, ok := o["port"]; ok { 238 cd.Port = port 239 } 240 241 options := []string{"sslmode", "fallback_application_name", "connect_timeout", "sslcert", "sslkey", "sslrootcert"} 242 243 for i := range options { 244 if opt, ok := o[options[i]]; ok { 245 cd.Options[options[i]] = opt 246 } 247 } 248 249 return nil 250 } 251 252 func finalizerPostgreSQL(cd *ConnectionDetails) { 253 cd.Options["sslmode"] = defaults.String(cd.Options["sslmode"], "disable") 254 cd.Port = defaults.String(cd.Port, portPostgreSQL) 255 } 256 257 const pgTruncate = `DO 258 $func$ 259 DECLARE 260 _tbl text; 261 _sch text; 262 BEGIN 263 FOR _sch, _tbl IN 264 SELECT schemaname, tablename 265 FROM pg_tables 266 WHERE tablename <> '%s' AND schemaname NOT IN ('pg_catalog', 'information_schema') AND tableowner = current_user 267 LOOP 268 --RAISE ERROR '%%', 269 EXECUTE -- dangerous, test before you execute! 270 format('TRUNCATE TABLE %%I.%%I CASCADE', _sch, _tbl); 271 END LOOP; 272 END 273 $func$;` 274 275 // Code below is ported from: https://github.com/lib/pq/blob/master/conn.go 276 type values map[string]string 277 278 // scanner implements a tokenizer for libpq-style option strings. 279 type scanner struct { 280 s []rune 281 i int 282 } 283 284 // newScanner returns a new scanner initialized with the option string s. 285 func newScanner(s string) *scanner { 286 return &scanner{[]rune(s), 0} 287 } 288 289 // Next returns the next rune. 290 // It returns 0, false if the end of the text has been reached. 291 func (s *scanner) Next() (rune, bool) { 292 if s.i >= len(s.s) { 293 return 0, false 294 } 295 r := s.s[s.i] 296 s.i++ 297 return r, true 298 } 299 300 // SkipSpaces returns the next non-whitespace rune. 301 // It returns 0, false if the end of the text has been reached. 302 func (s *scanner) SkipSpaces() (rune, bool) { 303 r, ok := s.Next() 304 for unicode.IsSpace(r) && ok { 305 r, ok = s.Next() 306 } 307 return r, ok 308 } 309 310 // parseOpts parses the options from name and adds them to the values. 311 // 312 // The parsing code is based on conninfo_parse from libpq's fe-connect.c 313 func parseOpts(name string, o values) error { 314 s := newScanner(name) 315 316 for { 317 var ( 318 keyRunes, valRunes []rune 319 r rune 320 ok bool 321 ) 322 323 if r, ok = s.SkipSpaces(); !ok { 324 break 325 } 326 327 // Scan the key 328 for !unicode.IsSpace(r) && r != '=' { 329 keyRunes = append(keyRunes, r) 330 if r, ok = s.Next(); !ok { 331 break 332 } 333 } 334 335 // Skip any whitespace if we're not at the = yet 336 if r != '=' { 337 r, ok = s.SkipSpaces() 338 } 339 340 // The current character should be = 341 if r != '=' || !ok { 342 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 343 } 344 345 // Skip any whitespace after the = 346 if r, ok = s.SkipSpaces(); !ok { 347 // If we reach the end here, the last value is just an empty string as per libpq. 348 o[string(keyRunes)] = "" 349 break 350 } 351 352 if r != '\'' { 353 for !unicode.IsSpace(r) { 354 if r == '\\' { 355 if r, ok = s.Next(); !ok { 356 return fmt.Errorf(`missing character after backslash`) 357 } 358 } 359 valRunes = append(valRunes, r) 360 361 if r, ok = s.Next(); !ok { 362 break 363 } 364 } 365 } else { 366 quote: 367 for { 368 if r, ok = s.Next(); !ok { 369 return fmt.Errorf(`unterminated quoted string literal in connection string`) 370 } 371 switch r { 372 case '\'': 373 break quote 374 case '\\': 375 r, _ = s.Next() 376 fallthrough 377 default: 378 valRunes = append(valRunes, r) 379 } 380 } 381 } 382 383 o[string(keyRunes)] = string(valRunes) 384 } 385 386 return nil 387 }