github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/dialect_postgresql.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"net/url"
     8  	"os/exec"
     9  	"sync"
    10  
    11  	"github.com/Accefy/pop/internal/defaults"
    12  	"github.com/gobuffalo/fizz"
    13  	"github.com/gobuffalo/fizz/translators"
    14  	"github.com/gobuffalo/pop/v6/columns"
    15  	"github.com/gobuffalo/pop/v6/logging"
    16  	"github.com/jackc/pgconn"
    17  	_ "github.com/jackc/pgx/v4/stdlib" // Load pgx driver
    18  	"github.com/jmoiron/sqlx"
    19  )
    20  
    21  const namePostgreSQL = "postgres"
    22  const portPostgreSQL = "5432"
    23  
    24  func init() {
    25  	AvailableDialects = append(AvailableDialects, namePostgreSQL)
    26  	dialectSynonyms["postgresql"] = namePostgreSQL
    27  	dialectSynonyms["pg"] = namePostgreSQL
    28  	dialectSynonyms["pgx"] = namePostgreSQL
    29  	urlParser[namePostgreSQL] = urlParserPostgreSQL
    30  	finalizer[namePostgreSQL] = finalizerPostgreSQL
    31  	newConnection[namePostgreSQL] = newPostgreSQL
    32  }
    33  
// Compile-time assertion that *postgresql satisfies the dialect interface.
var _ dialect = &postgresql{}

// postgresql is the PostgreSQL dialect. Shared behaviour and the connection
// details come from the embedded commonDialect.
type postgresql struct {
	commonDialect
	// translateCache memoizes TranslateSQL output, keyed by the input SQL.
	translateCache map[string]string
	// mu guards translateCache; TranslateSQL holds it for every lookup/insert.
	mu             sync.Mutex
}
    41  
    42  func (p *postgresql) Name() string {
    43  	return namePostgreSQL
    44  }
    45  
    46  func (p *postgresql) DefaultDriver() string {
    47  	return "pgx"
    48  }
    49  
    50  func (p *postgresql) Details() *ConnectionDetails {
    51  	return p.ConnectionDetails
    52  }
    53  
    54  func (p *postgresql) Create(c *Connection, model *Model, cols columns.Columns) error {
    55  	keyType, err := model.PrimaryKeyType()
    56  	if err != nil {
    57  		return err
    58  	}
    59  	switch keyType {
    60  	case "int", "int64":
    61  		cols.Remove(model.IDField())
    62  		w := cols.Writeable()
    63  		var query string
    64  		if len(w.Cols) > 0 {
    65  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    66  		} else {
    67  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", p.Quote(model.TableName()), model.IDField())
    68  		}
    69  		txlog(logging.SQL, c, query, model.Value)
    70  		rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
    71  		if err != nil {
    72  			return fmt.Errorf("named insert: %w", err)
    73  		}
    74  		defer rows.Close()
    75  		if !rows.Next() {
    76  			if err := rows.Err(); err != nil {
    77  				return fmt.Errorf("named insert: next: %w", err)
    78  			}
    79  			return fmt.Errorf("named insert: %w", sql.ErrNoRows)
    80  		}
    81  		var id interface{}
    82  		if err := rows.Scan(&id); err != nil {
    83  			return fmt.Errorf("named insert: scan: %w", err)
    84  		}
    85  		if err := rows.Close(); err != nil {
    86  			return fmt.Errorf("named insert: close: %w", err)
    87  		}
    88  		model.setID(id)
    89  		return nil
    90  	}
    91  	return genericCreate(c, model, cols, p)
    92  }
    93  
    94  func (p *postgresql) Update(c *Connection, model *Model, cols columns.Columns) error {
    95  	return genericUpdate(c, model, cols, p)
    96  }
    97  
    98  func (p *postgresql) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) {
    99  	return genericUpdateQuery(c, model, cols, p, query, sqlx.DOLLAR)
   100  }
   101  
   102  func (p *postgresql) Destroy(c *Connection, model *Model) error {
   103  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
   104  	_, err := genericExec(c, stmt, model.ID())
   105  	if err != nil {
   106  		return err
   107  	}
   108  	return nil
   109  }
   110  
   111  func (p *postgresql) Delete(c *Connection, model *Model, query Query) error {
   112  	return genericDelete(c, model, query)
   113  }
   114  
   115  func (p *postgresql) SelectOne(c *Connection, model *Model, query Query) error {
   116  	return genericSelectOne(c, model, query)
   117  }
   118  
   119  func (p *postgresql) SelectMany(c *Connection, models *Model, query Query) error {
   120  	return genericSelectMany(c, models, query)
   121  }
   122  
   123  func (p *postgresql) CreateDB() error {
   124  	// createdb -h db -p 5432 -U postgres enterprise_development
   125  	deets := p.ConnectionDetails
   126  
   127  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   128  	if err != nil {
   129  		return fmt.Errorf("error creating PostgreSQL database %s: %w", deets.Database, err)
   130  	}
   131  	defer db.Close()
   132  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   133  	log(logging.SQL, query)
   134  
   135  	_, err = db.Exec(query)
   136  	if err != nil {
   137  		return fmt.Errorf("error creating PostgreSQL database %s: %w", deets.Database, err)
   138  	}
   139  
   140  	log(logging.Info, "created database %s", deets.Database)
   141  	return nil
   142  }
   143  
   144  func (p *postgresql) DropDB() error {
   145  	deets := p.ConnectionDetails
   146  
   147  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   148  	if err != nil {
   149  		return fmt.Errorf("error dropping PostgreSQL database %s: %w", deets.Database, err)
   150  	}
   151  	defer db.Close()
   152  	query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
   153  	log(logging.SQL, query)
   154  
   155  	_, err = db.Exec(query)
   156  	if err != nil {
   157  		return fmt.Errorf("error dropping PostgreSQL database %s: %w", deets.Database, err)
   158  	}
   159  
   160  	log(logging.Info, "dropped database %s", deets.Database)
   161  	return nil
   162  }
   163  
   164  func (p *postgresql) URL() string {
   165  	c := p.ConnectionDetails
   166  	if c.URL != "" {
   167  		return c.URL
   168  	}
   169  	s := "postgres://%s:%s@%s:%s/%s?%s"
   170  	return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.Database, c.OptionsString(""))
   171  }
   172  
   173  func (p *postgresql) urlWithoutDb() string {
   174  	c := p.ConnectionDetails
   175  	// https://github.com/gobuffalo/buffalo/issues/836
   176  	// If the db is not precised, postgresql takes the username as the database to connect on.
   177  	// To avoid a connection problem if the user db is not here, we use the default "postgres"
   178  	// db, just like the other client tools do.
   179  	s := "postgres://%s:%s@%s:%s/postgres?%s"
   180  	return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.OptionsString(""))
   181  }
   182  
   183  func (p *postgresql) MigrationURL() string {
   184  	return p.URL()
   185  }
   186  
   187  func (p *postgresql) TranslateSQL(sql string) string {
   188  	defer p.mu.Unlock()
   189  	p.mu.Lock()
   190  
   191  	if csql, ok := p.translateCache[sql]; ok {
   192  		return csql
   193  	}
   194  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   195  
   196  	p.translateCache[sql] = csql
   197  	return csql
   198  }
   199  
   200  func (p *postgresql) FizzTranslator() fizz.Translator {
   201  	return translators.NewPostgres()
   202  }
   203  
   204  func (p *postgresql) DumpSchema(w io.Writer) error {
   205  	cmd := exec.Command("pg_dump", "-s", fmt.Sprintf("--dbname=%s", p.URL()))
   206  	return genericDumpSchema(p.Details(), cmd, w)
   207  }
   208  
   209  // LoadSchema executes a schema sql file against the configured database.
   210  func (p *postgresql) LoadSchema(r io.Reader) error {
   211  	return genericLoadSchema(p, r)
   212  }
   213  
   214  // TruncateAll truncates all tables for the given connection.
   215  func (p *postgresql) TruncateAll(tx *Connection) error {
   216  	return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec()
   217  }
   218  
   219  func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
   220  	cd := &postgresql{
   221  		commonDialect:  commonDialect{ConnectionDetails: deets},
   222  		translateCache: map[string]string{},
   223  		mu:             sync.Mutex{},
   224  	}
   225  	return cd, nil
   226  }
   227  
   228  // urlParserPostgreSQL parses the options the same way jackc/pgconn does:
   229  // https://pkg.go.dev/github.com/jackc/pgconn?tab=doc#ParseConfig
   230  // After parsed, they are set to ConnectionDetails instance
   231  func urlParserPostgreSQL(cd *ConnectionDetails) error {
   232  	conf, err := pgconn.ParseConfig(cd.URL)
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	cd.Database = conf.Database
   238  	cd.Host = conf.Host
   239  	cd.User = conf.User
   240  	cd.Password = conf.Password
   241  	cd.Port = fmt.Sprintf("%d", conf.Port)
   242  
   243  	options := []string{"fallback_application_name"}
   244  	for i := range options {
   245  		if opt, ok := conf.RuntimeParams[options[i]]; ok {
   246  			cd.setOption(options[i], opt)
   247  		}
   248  	}
   249  
   250  	if conf.TLSConfig == nil {
   251  		cd.setOption("sslmode", "disable")
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  func finalizerPostgreSQL(cd *ConnectionDetails) {
   258  	cd.Port = defaults.String(cd.Port, portPostgreSQL)
   259  }
   260  
// pgTruncate is the anonymous PL/pgSQL (DO) block that TruncateAll executes.
// It is a fmt format string:
//   - the single %s receives the migration table name, which is excluded from
//     truncation;
//   - each %% becomes the literal % needed by the SQL format() call's %I
//     placeholders.
//
// The block runs TRUNCATE TABLE schema.table CASCADE on every table in
// pg_tables owned by current_user, outside pg_catalog and information_schema.
//
// NOTE(review): the table name is spliced into a single-quoted SQL literal
// without escaping, so a name containing ' would break the statement.
const pgTruncate = `DO
$func$
DECLARE
   _tbl text;
   _sch text;
BEGIN
   FOR _sch, _tbl IN
      SELECT schemaname, tablename
      FROM   pg_tables
      WHERE  tablename <> '%s' AND schemaname NOT IN ('pg_catalog', 'information_schema') AND tableowner = current_user
   LOOP
      --RAISE ERROR '%%',
      EXECUTE  -- dangerous, test before you execute!
         format('TRUNCATE TABLE %%I.%%I CASCADE', _sch, _tbl);
   END LOOP;
END
$func$;`