github.com/royge/pop@v4.13.1+incompatible/dialect_postgresql.go (about)

     1  package pop
     2  
import (
	"database/sql"
	"fmt"
	"io"
	"net/url"
	"os/exec"
	"strings"
	"sync"
	"unicode"

	"github.com/gobuffalo/fizz"
	"github.com/gobuffalo/fizz/translators"
	"github.com/gobuffalo/pop/columns"
	"github.com/gobuffalo/pop/internal/defaults"
	"github.com/gobuffalo/pop/logging"
	"github.com/jmoiron/sqlx"
	pg "github.com/lib/pq"
	"github.com/pkg/errors"
)
    21  
    22  const namePostgreSQL = "postgres"
    23  const portPostgreSQL = "5432"
    24  
    25  func init() {
    26  	AvailableDialects = append(AvailableDialects, namePostgreSQL)
    27  	dialectSynonyms["postgresql"] = namePostgreSQL
    28  	dialectSynonyms["pg"] = namePostgreSQL
    29  	urlParser[namePostgreSQL] = urlParserPostgreSQL
    30  	finalizer[namePostgreSQL] = finalizerPostgreSQL
    31  	newConnection[namePostgreSQL] = newPostgreSQL
    32  }
    33  
    34  var _ dialect = &postgresql{}
    35  
    36  type postgresql struct {
    37  	commonDialect
    38  	translateCache map[string]string
    39  	mu             sync.Mutex
    40  }
    41  
    42  func (p *postgresql) Name() string {
    43  	return namePostgreSQL
    44  }
    45  
    46  func (p *postgresql) Details() *ConnectionDetails {
    47  	return p.ConnectionDetails
    48  }
    49  
    50  func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
    51  	keyType := model.PrimaryKeyType()
    52  	switch keyType {
    53  	case "int", "int64":
    54  		cols.Remove("id")
    55  		id := struct {
    56  			ID int `db:"id"`
    57  		}{}
    58  		w := cols.Writeable()
    59  		var query string
    60  		if len(w.Cols) > 0 {
    61  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
    62  		} else {
    63  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
    64  		}
    65  		log(logging.SQL, query)
    66  		stmt, err := s.PrepareNamed(query)
    67  		if err != nil {
    68  			return err
    69  		}
    70  		err = stmt.Get(&id, model.Value)
    71  		if err != nil {
    72  			if err := stmt.Close(); err != nil {
    73  				return errors.WithMessage(err, "failed to close statement")
    74  			}
    75  			return err
    76  		}
    77  		model.setID(id.ID)
    78  		return errors.WithMessage(stmt.Close(), "failed to close statement")
    79  	}
    80  	return genericCreate(s, model, cols, p)
    81  }
    82  
    83  func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
    84  	return genericUpdate(s, model, cols, p)
    85  }
    86  
    87  func (p *postgresql) Destroy(s store, model *Model) error {
    88  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID()))
    89  	_, err := genericExec(s, stmt, model.ID())
    90  	if err != nil {
    91  		return err
    92  	}
    93  	return nil
    94  }
    95  
    96  func (p *postgresql) SelectOne(s store, model *Model, query Query) error {
    97  	return genericSelectOne(s, model, query)
    98  }
    99  
   100  func (p *postgresql) SelectMany(s store, models *Model, query Query) error {
   101  	return genericSelectMany(s, models, query)
   102  }
   103  
   104  func (p *postgresql) CreateDB() error {
   105  	// createdb -h db -p 5432 -U postgres enterprise_development
   106  	deets := p.ConnectionDetails
   107  	db, err := sql.Open(deets.Dialect, p.urlWithoutDb())
   108  	if err != nil {
   109  		return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
   110  	}
   111  	defer db.Close()
   112  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   113  	log(logging.SQL, query)
   114  
   115  	_, err = db.Exec(query)
   116  	if err != nil {
   117  		return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
   118  	}
   119  
   120  	log(logging.Info, "created database %s", deets.Database)
   121  	return nil
   122  }
   123  
   124  func (p *postgresql) DropDB() error {
   125  	deets := p.ConnectionDetails
   126  	db, err := sql.Open(deets.Dialect, p.urlWithoutDb())
   127  	if err != nil {
   128  		return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
   129  	}
   130  	defer db.Close()
   131  	query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
   132  	log(logging.SQL, query)
   133  
   134  	_, err = db.Exec(query)
   135  	if err != nil {
   136  		return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
   137  	}
   138  
   139  	log(logging.Info, "dropped database %s", deets.Database)
   140  	return nil
   141  }
   142  
   143  func (p *postgresql) URL() string {
   144  	c := p.ConnectionDetails
   145  	if c.URL != "" {
   146  		return c.URL
   147  	}
   148  	s := "postgres://%s:%s@%s:%s/%s?%s"
   149  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString(""))
   150  }
   151  
   152  func (p *postgresql) urlWithoutDb() string {
   153  	c := p.ConnectionDetails
   154  	// https://github.com/gobuffalo/buffalo/issues/836
   155  	// If the db is not precised, postgresql takes the username as the database to connect on.
   156  	// To avoid a connection problem if the user db is not here, we use the default "postgres"
   157  	// db, just like the other client tools do.
   158  	s := "postgres://%s:%s@%s:%s/postgres?%s"
   159  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString(""))
   160  }
   161  
   162  func (p *postgresql) MigrationURL() string {
   163  	return p.URL()
   164  }
   165  
   166  func (p *postgresql) TranslateSQL(sql string) string {
   167  	defer p.mu.Unlock()
   168  	p.mu.Lock()
   169  
   170  	if csql, ok := p.translateCache[sql]; ok {
   171  		return csql
   172  	}
   173  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   174  
   175  	p.translateCache[sql] = csql
   176  	return csql
   177  }
   178  
   179  func (p *postgresql) FizzTranslator() fizz.Translator {
   180  	return translators.NewPostgres()
   181  }
   182  
   183  func (p *postgresql) DumpSchema(w io.Writer) error {
   184  	cmd := exec.Command("pg_dump", "-s", fmt.Sprintf("--dbname=%s", p.URL()))
   185  	return genericDumpSchema(p.Details(), cmd, w)
   186  }
   187  
   188  // LoadSchema executes a schema sql file against the configured database.
   189  func (p *postgresql) LoadSchema(r io.Reader) error {
   190  	return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r)
   191  }
   192  
   193  // TruncateAll truncates all tables for the given connection.
   194  func (p *postgresql) TruncateAll(tx *Connection) error {
   195  	return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec()
   196  }
   197  
   198  func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
   199  	cd := &postgresql{
   200  		commonDialect:  commonDialect{ConnectionDetails: deets},
   201  		translateCache: map[string]string{},
   202  		mu:             sync.Mutex{},
   203  	}
   204  	return cd, nil
   205  }
   206  
   207  // urlParserPostgreSQL parses the options the same way official lib/pg does:
   208  // https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters
   209  // After parsed, they are set to ConnectionDetails instance
   210  func urlParserPostgreSQL(cd *ConnectionDetails) error {
   211  	var err error
   212  	name := cd.URL
   213  	if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
   214  		name, err = pg.ParseURL(name)
   215  		if err != nil {
   216  			return err
   217  		}
   218  	}
   219  
   220  	o := make(values)
   221  	if err := parseOpts(name, o); err != nil {
   222  		return err
   223  	}
   224  
   225  	if dbname, ok := o["dbname"]; ok {
   226  		cd.Database = dbname
   227  	}
   228  	if host, ok := o["host"]; ok {
   229  		cd.Host = host
   230  	}
   231  	if password, ok := o["password"]; ok {
   232  		cd.Password = password
   233  	}
   234  	if user, ok := o["user"]; ok {
   235  		cd.User = user
   236  	}
   237  	if port, ok := o["port"]; ok {
   238  		cd.Port = port
   239  	}
   240  
   241  	options := []string{"sslmode", "fallback_application_name", "connect_timeout", "sslcert", "sslkey", "sslrootcert"}
   242  
   243  	for i := range options {
   244  		if opt, ok := o[options[i]]; ok {
   245  			cd.Options[options[i]] = opt
   246  		}
   247  	}
   248  
   249  	return nil
   250  }
   251  
   252  func finalizerPostgreSQL(cd *ConnectionDetails) {
   253  	cd.Options["sslmode"] = defaults.String(cd.Options["sslmode"], "disable")
   254  	cd.Port = defaults.String(cd.Port, portPostgreSQL)
   255  }
   256  
// pgTruncate is a fmt format string for an anonymous PL/pgSQL block (DO ...).
// The block runs TRUNCATE ... CASCADE on every table listed in pg_tables that
// is owned by current_user, skipping the pg_catalog and information_schema
// schemas.
//
// The single %s is the name of a table to leave untouched; TruncateAll passes
// the migration table name. Each %% becomes a literal % after fmt.Sprintf,
// which PL/pgSQL's format() needs for its %I (quoted identifier)
// placeholders. The "--RAISE" line is an SQL comment and has no effect.
//
// NOTE(review): the exclusion compares tablename only, so a table with the
// same name in any other schema is skipped as well.
const pgTruncate = `DO
$func$
DECLARE
   _tbl text;
   _sch text;
BEGIN
   FOR _sch, _tbl IN
      SELECT schemaname, tablename
      FROM   pg_tables
      WHERE  tablename <> '%s' AND schemaname NOT IN ('pg_catalog', 'information_schema') AND tableowner = current_user
   LOOP
      --RAISE ERROR '%%',
      EXECUTE  -- dangerous, test before you execute!
         format('TRUNCATE TABLE %%I.%%I CASCADE', _sch, _tbl);
   END LOOP;
END
$func$;`
   274  
   275  // Code below is ported from: https://github.com/lib/pq/blob/master/conn.go
   276  type values map[string]string
   277  
   278  // scanner implements a tokenizer for libpq-style option strings.
   279  type scanner struct {
   280  	s []rune
   281  	i int
   282  }
   283  
   284  // newScanner returns a new scanner initialized with the option string s.
   285  func newScanner(s string) *scanner {
   286  	return &scanner{[]rune(s), 0}
   287  }
   288  
   289  // Next returns the next rune.
   290  // It returns 0, false if the end of the text has been reached.
   291  func (s *scanner) Next() (rune, bool) {
   292  	if s.i >= len(s.s) {
   293  		return 0, false
   294  	}
   295  	r := s.s[s.i]
   296  	s.i++
   297  	return r, true
   298  }
   299  
   300  // SkipSpaces returns the next non-whitespace rune.
   301  // It returns 0, false if the end of the text has been reached.
   302  func (s *scanner) SkipSpaces() (rune, bool) {
   303  	r, ok := s.Next()
   304  	for unicode.IsSpace(r) && ok {
   305  		r, ok = s.Next()
   306  	}
   307  	return r, ok
   308  }
   309  
   310  // parseOpts parses the options from name and adds them to the values.
   311  //
   312  // The parsing code is based on conninfo_parse from libpq's fe-connect.c
   313  func parseOpts(name string, o values) error {
   314  	s := newScanner(name)
   315  
   316  	for {
   317  		var (
   318  			keyRunes, valRunes []rune
   319  			r                  rune
   320  			ok                 bool
   321  		)
   322  
   323  		if r, ok = s.SkipSpaces(); !ok {
   324  			break
   325  		}
   326  
   327  		// Scan the key
   328  		for !unicode.IsSpace(r) && r != '=' {
   329  			keyRunes = append(keyRunes, r)
   330  			if r, ok = s.Next(); !ok {
   331  				break
   332  			}
   333  		}
   334  
   335  		// Skip any whitespace if we're not at the = yet
   336  		if r != '=' {
   337  			r, ok = s.SkipSpaces()
   338  		}
   339  
   340  		// The current character should be =
   341  		if r != '=' || !ok {
   342  			return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
   343  		}
   344  
   345  		// Skip any whitespace after the =
   346  		if r, ok = s.SkipSpaces(); !ok {
   347  			// If we reach the end here, the last value is just an empty string as per libpq.
   348  			o[string(keyRunes)] = ""
   349  			break
   350  		}
   351  
   352  		if r != '\'' {
   353  			for !unicode.IsSpace(r) {
   354  				if r == '\\' {
   355  					if r, ok = s.Next(); !ok {
   356  						return fmt.Errorf(`missing character after backslash`)
   357  					}
   358  				}
   359  				valRunes = append(valRunes, r)
   360  
   361  				if r, ok = s.Next(); !ok {
   362  					break
   363  				}
   364  			}
   365  		} else {
   366  		quote:
   367  			for {
   368  				if r, ok = s.Next(); !ok {
   369  					return fmt.Errorf(`unterminated quoted string literal in connection string`)
   370  				}
   371  				switch r {
   372  				case '\'':
   373  					break quote
   374  				case '\\':
   375  					r, _ = s.Next()
   376  					fallthrough
   377  				default:
   378  					valRunes = append(valRunes, r)
   379  				}
   380  			}
   381  		}
   382  
   383  		o[string(keyRunes)] = string(valRunes)
   384  	}
   385  
   386  	return nil
   387  }