github.com/dkishere/pop/v6@v6.103.1/dialect_postgresql.go (about)

     1  package pop
     2  
import (
	"fmt"
	"io"
	"net/url"
	"os/exec"
	"sync"

	"github.com/dkishere/pop/v6/columns"
	"github.com/dkishere/pop/v6/internal/defaults"
	"github.com/dkishere/pop/v6/logging"
	"github.com/gobuffalo/fizz"
	"github.com/gobuffalo/fizz/translators"
	"github.com/jackc/pgconn"
	_ "github.com/jackc/pgx/v4/stdlib" // Load pgx driver
	"github.com/jmoiron/sqlx"
)
    18  
    19  const namePostgreSQL = "postgres"
    20  const portPostgreSQL = "5432"
    21  
    22  func init() {
    23  	AvailableDialects = append(AvailableDialects, namePostgreSQL)
    24  	dialectSynonyms["postgresql"] = namePostgreSQL
    25  	dialectSynonyms["pg"] = namePostgreSQL
    26  	dialectSynonyms["pgx"] = namePostgreSQL
    27  	urlParser[namePostgreSQL] = urlParserPostgreSQL
    28  	finalizer[namePostgreSQL] = finalizerPostgreSQL
    29  	newConnection[namePostgreSQL] = newPostgreSQL
    30  }
    31  
    32  var _ dialect = &postgresql{}
    33  
    34  type postgresql struct {
    35  	commonDialect
    36  	translateCache map[string]string
    37  	mu             sync.Mutex
    38  }
    39  
    40  func (p *postgresql) Name() string {
    41  	return namePostgreSQL
    42  }
    43  
    44  func (p *postgresql) DefaultDriver() string {
    45  	return "pgx"
    46  }
    47  
    48  func (p *postgresql) Details() *ConnectionDetails {
    49  	return p.ConnectionDetails
    50  }
    51  
    52  func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
    53  	keyType, err := model.PrimaryKeyType()
    54  	if err != nil {
    55  		return err
    56  	}
    57  	switch keyType {
    58  	case "int", "int64":
    59  		cols.Remove(model.IDField())
    60  		w := cols.Writeable()
    61  		var query string
    62  		if len(w.Cols) > 0 {
    63  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    64  		} else {
    65  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
    66  		}
    67  		log(logging.SQL, query, model.Value)
    68  		stmt, err := s.PrepareNamed(query)
    69  		if err != nil {
    70  			return err
    71  		}
    72  		id := map[string]interface{}{}
    73  		err = stmt.QueryRow(model.Value).MapScan(id)
    74  		if err != nil {
    75  			if closeErr := stmt.Close(); closeErr != nil {
    76  				return fmt.Errorf("failed to close prepared statement: %s: %w", closeErr, err)
    77  			}
    78  			return err
    79  		}
    80  		model.setID(id[model.IDField()])
    81  		if closeErr := stmt.Close(); closeErr != nil {
    82  			return fmt.Errorf("failed to close statement: %w", closeErr)
    83  		}
    84  		return nil
    85  	}
    86  	return genericCreate(s, model, cols, p)
    87  }
    88  
    89  func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
    90  	return genericUpdate(s, model, cols, p)
    91  }
    92  
    93  func (p *postgresql) Destroy(s store, model *Model) error {
    94  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
    95  	_, err := genericExec(s, stmt, model.ID())
    96  	if err != nil {
    97  		return err
    98  	}
    99  	return nil
   100  }
   101  
   102  func (p *postgresql) Delete(s store, model *Model, query Query) error {
   103  	return genericDelete(s, model, query)
   104  }
   105  
   106  func (p *postgresql) SelectOne(s store, model *Model, query Query) error {
   107  	return genericSelectOne(s, model, query)
   108  }
   109  
   110  func (p *postgresql) SelectMany(s store, models *Model, query Query) error {
   111  	return genericSelectMany(s, models, query)
   112  }
   113  
   114  func (p *postgresql) CreateDB() error {
   115  	// createdb -h db -p 5432 -U postgres enterprise_development
   116  	deets := p.ConnectionDetails
   117  
   118  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   119  	if err != nil {
   120  		return fmt.Errorf("error creating PostgreSQL database %s: %w", deets.Database, err)
   121  	}
   122  	defer db.Close()
   123  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   124  	log(logging.SQL, query)
   125  
   126  	_, err = db.Exec(query)
   127  	if err != nil {
   128  		return fmt.Errorf("error creating PostgreSQL database %s: %w", deets.Database, err)
   129  	}
   130  
   131  	log(logging.Info, "created database %s", deets.Database)
   132  	return nil
   133  }
   134  
   135  func (p *postgresql) DropDB() error {
   136  	deets := p.ConnectionDetails
   137  
   138  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   139  	if err != nil {
   140  		return fmt.Errorf("error dropping PostgreSQL database %s: %w", deets.Database, err)
   141  	}
   142  	defer db.Close()
   143  	query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
   144  	log(logging.SQL, query)
   145  
   146  	_, err = db.Exec(query)
   147  	if err != nil {
   148  		return fmt.Errorf("error dropping PostgreSQL database %s: %w", deets.Database, err)
   149  	}
   150  
   151  	log(logging.Info, "dropped database %s", deets.Database)
   152  	return nil
   153  }
   154  
   155  func (p *postgresql) URL() string {
   156  	c := p.ConnectionDetails
   157  	if c.URL != "" {
   158  		return c.URL
   159  	}
   160  	s := "postgres://%s:%s@%s:%s/%s?%s"
   161  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString(""))
   162  }
   163  
   164  func (p *postgresql) urlWithoutDb() string {
   165  	c := p.ConnectionDetails
   166  	// https://github.com/gobuffalo/buffalo/issues/836
   167  	// If the db is not precised, postgresql takes the username as the database to connect on.
   168  	// To avoid a connection problem if the user db is not here, we use the default "postgres"
   169  	// db, just like the other client tools do.
   170  	s := "postgres://%s:%s@%s:%s/postgres?%s"
   171  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString(""))
   172  }
   173  
   174  func (p *postgresql) MigrationURL() string {
   175  	return p.URL()
   176  }
   177  
   178  func (p *postgresql) TranslateSQL(sql string) string {
   179  	defer p.mu.Unlock()
   180  	p.mu.Lock()
   181  
   182  	if csql, ok := p.translateCache[sql]; ok {
   183  		return csql
   184  	}
   185  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   186  
   187  	p.translateCache[sql] = csql
   188  	return csql
   189  }
   190  
   191  func (p *postgresql) FizzTranslator() fizz.Translator {
   192  	return translators.NewPostgres()
   193  }
   194  
   195  func (p *postgresql) DumpSchema(w io.Writer) error {
   196  	cmd := exec.Command("pg_dump", "-s", fmt.Sprintf("--dbname=%s", p.URL()))
   197  	return genericDumpSchema(p.Details(), cmd, w)
   198  }
   199  
   200  // LoadSchema executes a schema sql file against the configured database.
   201  func (p *postgresql) LoadSchema(r io.Reader) error {
   202  	return genericLoadSchema(p, r)
   203  }
   204  
   205  // TruncateAll truncates all tables for the given connection.
   206  func (p *postgresql) TruncateAll(tx *Connection) error {
   207  	return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec()
   208  }
   209  
   210  func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
   211  	cd := &postgresql{
   212  		commonDialect:  commonDialect{ConnectionDetails: deets},
   213  		translateCache: map[string]string{},
   214  		mu:             sync.Mutex{},
   215  	}
   216  	return cd, nil
   217  }
   218  
   219  // urlParserPostgreSQL parses the options the same way jackc/pgconn does:
   220  // https://pkg.go.dev/github.com/jackc/pgconn?tab=doc#ParseConfig
   221  // After parsed, they are set to ConnectionDetails instance
   222  func urlParserPostgreSQL(cd *ConnectionDetails) error {
   223  	conf, err := pgconn.ParseConfig(cd.URL)
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	cd.Database = conf.Database
   229  	cd.Host = conf.Host
   230  	cd.User = conf.User
   231  	cd.Password = conf.Password
   232  	cd.Port = fmt.Sprintf("%d", conf.Port)
   233  
   234  	options := []string{"fallback_application_name"}
   235  	for i := range options {
   236  		if opt, ok := conf.RuntimeParams[options[i]]; ok {
   237  			cd.Options[options[i]] = opt
   238  		}
   239  	}
   240  
   241  	if conf.TLSConfig == nil {
   242  		cd.Options["sslmode"] = "disable"
   243  	}
   244  
   245  	return nil
   246  }
   247  
   248  func finalizerPostgreSQL(cd *ConnectionDetails) {
   249  	cd.Port = defaults.String(cd.Port, portPostgreSQL)
   250  }
   251  
// pgTruncate is the PL/pgSQL DO block executed by TruncateAll.
//
// It is a fmt.Sprintf template. The single %s receives the migration table
// name, which is excluded from truncation. Every %% collapses to a literal %,
// so PL/pgSQL's format() sees %I (identifier-quoted) placeholders.
//
// The block loops over pg_tables, skipping pg_catalog, information_schema
// and tables not owned by current_user. It runs TRUNCATE ... CASCADE on each
// remaining schema.table.
//
// NOTE(review): the migration table name is spliced unescaped into a quoted
// SQL literal; a name containing a single quote would break the block.
const pgTruncate = `DO
$func$
DECLARE
   _tbl text;
   _sch text;
BEGIN
   FOR _sch, _tbl IN
      SELECT schemaname, tablename
      FROM   pg_tables
      WHERE  tablename <> '%s' AND schemaname NOT IN ('pg_catalog', 'information_schema') AND tableowner = current_user
   LOOP
      --RAISE ERROR '%%',
      EXECUTE  -- dangerous, test before you execute!
         format('TRUNCATE TABLE %%I.%%I CASCADE', _sch, _tbl);
   END LOOP;
END
$func$;`