github.com/paweljw/pop/v5@v5.4.6/dialect_cockroach.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os"
     7  	"os/exec"
     8  	"path/filepath"
     9  	"strings"
    10  	"sync"
    11  
    12  	// Import PostgreSQL driver
    13  	_ "github.com/jackc/pgx/v4/stdlib"
    14  
    15  	"github.com/gobuffalo/fizz"
    16  	"github.com/gobuffalo/fizz/translators"
    17  	"github.com/gobuffalo/pop/v5/columns"
    18  	"github.com/gobuffalo/pop/v5/internal/defaults"
    19  	"github.com/gobuffalo/pop/v5/logging"
    20  	"github.com/jmoiron/sqlx"
    21  	"github.com/pkg/errors"
    22  )
    23  
    24  const nameCockroach = "cockroach"
    25  const portCockroach = "26257"
    26  
    27  const selectTablesQueryCockroach = "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_name <> ? and table_catalog = ?"
    28  const selectTablesQueryCockroachV1 = "select table_name from information_schema.tables where table_name <> ? and table_schema = ?"
    29  
    30  func init() {
    31  	AvailableDialects = append(AvailableDialects, nameCockroach)
    32  	dialectSynonyms["cockroachdb"] = nameCockroach
    33  	dialectSynonyms["crdb"] = nameCockroach
    34  	finalizer[nameCockroach] = finalizerCockroach
    35  	newConnection[nameCockroach] = newCockroach
    36  }
    37  
    38  var _ dialect = &cockroach{}
    39  
    40  // ServerInfo holds informational data about connected database server.
    41  type cockroachInfo struct {
    42  	VersionString string `db:"version"`
    43  	product       string `db:"-"`
    44  	license       string `db:"-"`
    45  	version       string `db:"-"`
    46  	buildInfo     string `db:"-"`
    47  	client        string `db:"-"`
    48  }
    49  
    50  type cockroach struct {
    51  	commonDialect
    52  	translateCache map[string]string
    53  	mu             sync.Mutex
    54  	info           cockroachInfo
    55  }
    56  
    57  func (p *cockroach) Name() string {
    58  	return nameCockroach
    59  }
    60  
    61  func (p *cockroach) DefaultDriver() string {
    62  	return "pgx"
    63  }
    64  
    65  func (p *cockroach) Details() *ConnectionDetails {
    66  	return p.ConnectionDetails
    67  }
    68  
    69  func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
    70  	keyType, err := model.PrimaryKeyType()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	switch keyType {
    75  	case "int", "int64":
    76  		cols.Remove(model.IDField())
    77  		w := cols.Writeable()
    78  		var query string
    79  		if len(w.Cols) > 0 {
    80  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    81  		} else {
    82  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
    83  		}
    84  		log(logging.SQL, query)
    85  		stmt, err := s.PrepareNamed(query)
    86  		if err != nil {
    87  			return err
    88  		}
    89  		id := map[string]interface{}{}
    90  		err = stmt.QueryRow(model.Value).MapScan(id)
    91  		if err != nil {
    92  			if closeErr := stmt.Close(); closeErr != nil {
    93  				return errors.Wrapf(err, "failed to close prepared statement: %s", closeErr)
    94  			}
    95  			return err
    96  		}
    97  		model.setID(id[model.IDField()])
    98  		return errors.WithMessage(stmt.Close(), "failed to close statement")
    99  	}
   100  	return genericCreate(s, model, cols, p)
   101  }
   102  
   103  func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
   104  	return genericUpdate(s, model, cols, p)
   105  }
   106  
   107  func (p *cockroach) Destroy(s store, model *Model) error {
   108  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.alias(), model.whereID()))
   109  	_, err := genericExec(s, stmt, model.ID())
   110  	return err
   111  }
   112  
   113  func (p *cockroach) SelectOne(s store, model *Model, query Query) error {
   114  	return genericSelectOne(s, model, query)
   115  }
   116  
   117  func (p *cockroach) SelectMany(s store, models *Model, query Query) error {
   118  	return genericSelectMany(s, models, query)
   119  }
   120  
   121  func (p *cockroach) CreateDB() error {
   122  	// createdb -h db -p 5432 -U cockroach enterprise_development
   123  	deets := p.ConnectionDetails
   124  
   125  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   126  	if err != nil {
   127  		return errors.Wrapf(err, "error creating Cockroach database %s", deets.Database)
   128  	}
   129  	defer db.Close()
   130  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   131  	log(logging.SQL, query)
   132  
   133  	_, err = db.Exec(query)
   134  	if err != nil {
   135  		return errors.Wrapf(err, "error creating Cockroach database %s", deets.Database)
   136  	}
   137  
   138  	log(logging.Info, "created database %s", deets.Database)
   139  	return nil
   140  }
   141  
   142  func (p *cockroach) DropDB() error {
   143  	deets := p.ConnectionDetails
   144  
   145  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   146  	if err != nil {
   147  		return errors.Wrapf(err, "error dropping Cockroach database %s", deets.Database)
   148  	}
   149  	defer db.Close()
   150  	query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database))
   151  	log(logging.SQL, query)
   152  
   153  	_, err = db.Exec(query)
   154  	if err != nil {
   155  		return errors.Wrapf(err, "error dropping Cockroach database %s", deets.Database)
   156  	}
   157  
   158  	log(logging.Info, "dropped database %s", deets.Database)
   159  	return nil
   160  }
   161  
   162  func (p *cockroach) URL() string {
   163  	c := p.ConnectionDetails
   164  	if c.URL != "" {
   165  		return c.URL
   166  	}
   167  	s := "postgres://%s:%s@%s:%s/%s?%s"
   168  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString(""))
   169  }
   170  
   171  func (p *cockroach) urlWithoutDb() string {
   172  	c := p.ConnectionDetails
   173  	s := "postgres://%s:%s@%s:%s/?%s"
   174  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString(""))
   175  }
   176  
   177  func (p *cockroach) MigrationURL() string {
   178  	return p.URL()
   179  }
   180  
   181  func (p *cockroach) TranslateSQL(sql string) string {
   182  	defer p.mu.Unlock()
   183  	p.mu.Lock()
   184  
   185  	if csql, ok := p.translateCache[sql]; ok {
   186  		return csql
   187  	}
   188  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   189  
   190  	p.translateCache[sql] = csql
   191  	return csql
   192  }
   193  
   194  func (p *cockroach) FizzTranslator() fizz.Translator {
   195  	return translators.NewCockroach(p.URL(), p.Details().Database)
   196  }
   197  
   198  func (p *cockroach) DumpSchema(w io.Writer) error {
   199  	cmd := exec.Command("cockroach", "dump", p.Details().Database, "--dump-mode=schema")
   200  
   201  	c := p.ConnectionDetails
   202  	if defaults.String(c.Options["sslmode"], "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") {
   203  		cmd.Args = append(cmd.Args, "--insecure")
   204  	}
   205  	return genericDumpSchema(p.Details(), cmd, w)
   206  }
   207  
   208  func (p *cockroach) LoadSchema(r io.Reader) error {
   209  	return genericLoadSchema(p, r)
   210  }
   211  
   212  func (p *cockroach) TruncateAll(tx *Connection) error {
   213  	type table struct {
   214  		TableName string `db:"table_name"`
   215  	}
   216  
   217  	tableQuery := p.tablesQuery()
   218  
   219  	var tables []table
   220  	if err := tx.RawQuery(tableQuery, tx.MigrationTableName(), tx.Dialect.Details().Database).All(&tables); err != nil {
   221  		return err
   222  	}
   223  
   224  	if len(tables) == 0 {
   225  		return nil
   226  	}
   227  
   228  	tableNames := make([]string, len(tables))
   229  	for i, t := range tables {
   230  		tableNames[i] = t.TableName
   231  		//! work around for current limitation of DDL and DML at the same transaction.
   232  		//  it should be fixed when cockroach support it or with other approach.
   233  		//  https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions
   234  		if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil {
   235  			return err
   236  		}
   237  	}
   238  	return nil
   239  	// TODO!
   240  	// return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec()
   241  }
   242  
   243  func (p *cockroach) AfterOpen(c *Connection) error {
   244  	if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil {
   245  		return err
   246  	}
   247  	if s := strings.Split(p.info.VersionString, " "); len(s) > 3 {
   248  		p.info.product = s[0]
   249  		p.info.license = s[1]
   250  		p.info.version = s[2]
   251  		p.info.buildInfo = s[3]
   252  	}
   253  	log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version)
   254  
   255  	return nil
   256  }
   257  
   258  func newCockroach(deets *ConnectionDetails) (dialect, error) {
   259  	deets.Dialect = "postgres"
   260  	d := &cockroach{
   261  		commonDialect:  commonDialect{ConnectionDetails: deets},
   262  		translateCache: map[string]string{},
   263  		mu:             sync.Mutex{},
   264  	}
   265  	d.info.client = deets.Options["application_name"]
   266  	return d, nil
   267  }
   268  
   269  func finalizerCockroach(cd *ConnectionDetails) {
   270  	appName := filepath.Base(os.Args[0])
   271  	cd.Options["application_name"] = defaults.String(cd.Options["application_name"], appName)
   272  	cd.Port = defaults.String(cd.Port, portCockroach)
   273  	if cd.URL != "" {
   274  		cd.URL = "postgres://" + trimCockroachPrefix(cd.URL)
   275  	}
   276  }
   277  
   278  func trimCockroachPrefix(u string) string {
   279  	parts := strings.Split(u, "://")
   280  	if len(parts) != 2 {
   281  		return u
   282  	}
   283  	return parts[1]
   284  }
   285  
   286  func (p *cockroach) tablesQuery() string {
   287  	// See https://www.cockroachlabs.com/docs/stable/information-schema.html for more info about information schema changes
   288  	tableQuery := selectTablesQueryCockroach
   289  	if strings.HasPrefix(p.info.version, "v1.") {
   290  		tableQuery = selectTablesQueryCockroachV1
   291  	}
   292  	return tableQuery
   293  }