github.com/paweljw/pop/v5@v5.4.6/dialect_postgresql.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os/exec"
     7  	"sync"
     8  
     9  	// Load pgx driver
    10  	_ "github.com/jackc/pgx/v4/stdlib"
    11  
    12  	"github.com/gobuffalo/fizz"
    13  	"github.com/gobuffalo/fizz/translators"
    14  	"github.com/jackc/pgconn"
    15  	"github.com/jmoiron/sqlx"
    16  	"github.com/pkg/errors"
    17  
    18  	"github.com/gobuffalo/pop/v5/columns"
    19  	"github.com/gobuffalo/pop/v5/internal/defaults"
    20  	"github.com/gobuffalo/pop/v5/logging"
    21  )
    22  
    23  const namePostgreSQL = "postgres"
    24  const portPostgreSQL = "5432"
    25  
    26  func init() {
    27  	AvailableDialects = append(AvailableDialects, namePostgreSQL)
    28  	dialectSynonyms["postgresql"] = namePostgreSQL
    29  	dialectSynonyms["pg"] = namePostgreSQL
    30  	dialectSynonyms["pgx"] = namePostgreSQL
    31  	urlParser[namePostgreSQL] = urlParserPostgreSQL
    32  	finalizer[namePostgreSQL] = finalizerPostgreSQL
    33  	newConnection[namePostgreSQL] = newPostgreSQL
    34  }
    35  
    36  var _ dialect = &postgresql{}
    37  
    38  type postgresql struct {
    39  	commonDialect
    40  	translateCache map[string]string
    41  	mu             sync.Mutex
    42  }
    43  
    44  func (p *postgresql) Name() string {
    45  	return namePostgreSQL
    46  }
    47  
    48  func (p *postgresql) DefaultDriver() string {
    49  	return "pgx"
    50  }
    51  
    52  func (p *postgresql) Details() *ConnectionDetails {
    53  	return p.ConnectionDetails
    54  }
    55  
    56  func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
    57  	keyType, err := model.PrimaryKeyType()
    58  	if err != nil {
    59  		return err
    60  	}
    61  	switch keyType {
    62  	case "int", "int64":
    63  		cols.Remove(model.IDField())
    64  		w := cols.Writeable()
    65  		var query string
    66  		if len(w.Cols) > 0 {
    67  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    68  		} else {
    69  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
    70  		}
    71  		log(logging.SQL, query)
    72  		stmt, err := s.PrepareNamed(query)
    73  		if err != nil {
    74  			return err
    75  		}
    76  		id := map[string]interface{}{}
    77  		err = stmt.QueryRow(model.Value).MapScan(id)
    78  		if err != nil {
    79  			if closeErr := stmt.Close(); closeErr != nil {
    80  				return errors.Wrapf(err, "failed to close prepared statement: %s", closeErr)
    81  			}
    82  			return err
    83  		}
    84  		model.setID(id[model.IDField()])
    85  		return errors.WithMessage(stmt.Close(), "failed to close statement")
    86  	}
    87  	return genericCreate(s, model, cols, p)
    88  }
    89  
    90  func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
    91  	return genericUpdate(s, model, cols, p)
    92  }
    93  
    94  func (p *postgresql) Destroy(s store, model *Model) error {
    95  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.alias(), model.whereID()))
    96  	_, err := genericExec(s, stmt, model.ID())
    97  	if err != nil {
    98  		return err
    99  	}
   100  	return nil
   101  }
   102  
   103  func (p *postgresql) SelectOne(s store, model *Model, query Query) error {
   104  	return genericSelectOne(s, model, query)
   105  }
   106  
   107  func (p *postgresql) SelectMany(s store, models *Model, query Query) error {
   108  	return genericSelectMany(s, models, query)
   109  }
   110  
   111  func (p *postgresql) CreateDB() error {
   112  	// createdb -h db -p 5432 -U postgres enterprise_development
   113  	deets := p.ConnectionDetails
   114  
   115  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   116  	if err != nil {
   117  		return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
   118  	}
   119  	defer db.Close()
   120  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   121  	log(logging.SQL, query)
   122  
   123  	_, err = db.Exec(query)
   124  	if err != nil {
   125  		return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
   126  	}
   127  
   128  	log(logging.Info, "created database %s", deets.Database)
   129  	return nil
   130  }
   131  
   132  func (p *postgresql) DropDB() error {
   133  	deets := p.ConnectionDetails
   134  
   135  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   136  	if err != nil {
   137  		return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
   138  	}
   139  	defer db.Close()
   140  	query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
   141  	log(logging.SQL, query)
   142  
   143  	_, err = db.Exec(query)
   144  	if err != nil {
   145  		return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
   146  	}
   147  
   148  	log(logging.Info, "dropped database %s", deets.Database)
   149  	return nil
   150  }
   151  
   152  func (p *postgresql) URL() string {
   153  	c := p.ConnectionDetails
   154  	if c.URL != "" {
   155  		return c.URL
   156  	}
   157  	s := "postgres://%s:%s@%s:%s/%s?%s"
   158  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString(""))
   159  }
   160  
   161  func (p *postgresql) urlWithoutDb() string {
   162  	c := p.ConnectionDetails
   163  	// https://github.com/gobuffalo/buffalo/issues/836
   164  	// If the db is not precised, postgresql takes the username as the database to connect on.
   165  	// To avoid a connection problem if the user db is not here, we use the default "postgres"
   166  	// db, just like the other client tools do.
   167  	s := "postgres://%s:%s@%s:%s/postgres?%s"
   168  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString(""))
   169  }
   170  
   171  func (p *postgresql) MigrationURL() string {
   172  	return p.URL()
   173  }
   174  
   175  func (p *postgresql) TranslateSQL(sql string) string {
   176  	defer p.mu.Unlock()
   177  	p.mu.Lock()
   178  
   179  	if csql, ok := p.translateCache[sql]; ok {
   180  		return csql
   181  	}
   182  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   183  
   184  	p.translateCache[sql] = csql
   185  	return csql
   186  }
   187  
   188  func (p *postgresql) FizzTranslator() fizz.Translator {
   189  	return translators.NewPostgres()
   190  }
   191  
   192  func (p *postgresql) DumpSchema(w io.Writer) error {
   193  	cmd := exec.Command("pg_dump", "-s", fmt.Sprintf("--dbname=%s", p.URL()))
   194  	return genericDumpSchema(p.Details(), cmd, w)
   195  }
   196  
   197  // LoadSchema executes a schema sql file against the configured database.
   198  func (p *postgresql) LoadSchema(r io.Reader) error {
   199  	return genericLoadSchema(p, r)
   200  }
   201  
   202  // TruncateAll truncates all tables for the given connection.
   203  func (p *postgresql) TruncateAll(tx *Connection) error {
   204  	return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec()
   205  }
   206  
   207  func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
   208  	cd := &postgresql{
   209  		commonDialect:  commonDialect{ConnectionDetails: deets},
   210  		translateCache: map[string]string{},
   211  		mu:             sync.Mutex{},
   212  	}
   213  	return cd, nil
   214  }
   215  
   216  // urlParserPostgreSQL parses the options the same way jackc/pgconn does:
   217  // https://pkg.go.dev/github.com/jackc/pgconn?tab=doc#ParseConfig
   218  // After parsed, they are set to ConnectionDetails instance
   219  func urlParserPostgreSQL(cd *ConnectionDetails) error {
   220  	conf, err := pgconn.ParseConfig(cd.URL)
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	cd.Database = conf.Database
   226  	cd.Host = conf.Host
   227  	cd.User = conf.User
   228  	cd.Password = conf.Password
   229  	cd.Port = fmt.Sprintf("%d", conf.Port)
   230  
   231  	options := []string{"fallback_application_name"}
   232  	for i := range options {
   233  		if opt, ok := conf.RuntimeParams[options[i]]; ok {
   234  			cd.Options[options[i]] = opt
   235  		}
   236  	}
   237  
   238  	if conf.TLSConfig == nil {
   239  		cd.Options["sslmode"] = "disable"
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func finalizerPostgreSQL(cd *ConnectionDetails) {
   246  	cd.Port = defaults.String(cd.Port, portPostgreSQL)
   247  }
   248  
   249  const pgTruncate = `DO
   250  $func$
   251  DECLARE
   252     _tbl text;
   253     _sch text;
   254  BEGIN
   255     FOR _sch, _tbl IN
   256        SELECT schemaname, tablename
   257        FROM   pg_tables
   258        WHERE  tablename <> '%s' AND schemaname NOT IN ('pg_catalog', 'information_schema') AND tableowner = current_user
   259     LOOP
   260        --RAISE ERROR '%%',
   261        EXECUTE  -- dangerous, test before you execute!
   262           format('TRUNCATE TABLE %%I.%%I CASCADE', _sch, _tbl);
   263     END LOOP;
   264  END
   265  $func$;`