github.com/dkishere/pop/v6@v6.103.1/dialect_cockroach.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os"
     7  	"os/exec"
     8  	"path/filepath"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/gobuffalo/fizz"
    13  	"github.com/gobuffalo/fizz/translators"
    14  	"github.com/dkishere/pop/v6/columns"
    15  	"github.com/dkishere/pop/v6/internal/defaults"
    16  	"github.com/dkishere/pop/v6/logging"
    17  	_ "github.com/jackc/pgx/v4/stdlib" // Import PostgreSQL driver
    18  	"github.com/jmoiron/sqlx"
    19  )
    20  
    21  const nameCockroach = "cockroach"
    22  const portCockroach = "26257"
    23  
    24  const selectTablesQueryCockroach = "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_name <> ? and table_catalog = ?"
    25  const selectTablesQueryCockroachV1 = "select table_name from information_schema.tables where table_name <> ? and table_schema = ?"
    26  
    27  func init() {
    28  	AvailableDialects = append(AvailableDialects, nameCockroach)
    29  	dialectSynonyms["cockroachdb"] = nameCockroach
    30  	dialectSynonyms["crdb"] = nameCockroach
    31  	finalizer[nameCockroach] = finalizerCockroach
    32  	newConnection[nameCockroach] = newCockroach
    33  }
    34  
    35  var _ dialect = &cockroach{}
    36  
    37  // ServerInfo holds informational data about connected database server.
    38  type cockroachInfo struct {
    39  	VersionString string `db:"version"`
    40  	product       string `db:"-"`
    41  	license       string `db:"-"`
    42  	version       string `db:"-"`
    43  	buildInfo     string `db:"-"`
    44  	client        string `db:"-"`
    45  }
    46  
// cockroach is the pop dialect implementation for CockroachDB.
type cockroach struct {
	// commonDialect supplies the shared dialect behavior and carries the
	// ConnectionDetails this dialect was built with.
	commonDialect
	// translateCache memoizes TranslateSQL results, keyed by the input SQL.
	// It is only read and written while mu is held.
	translateCache map[string]string
	// mu guards translateCache.
	mu sync.Mutex
	// info describes the connected server; populated by newCockroach and AfterOpen.
	info cockroachInfo
}
    53  
    54  func (p *cockroach) Name() string {
    55  	return nameCockroach
    56  }
    57  
    58  func (p *cockroach) DefaultDriver() string {
    59  	return "pgx"
    60  }
    61  
    62  func (p *cockroach) Details() *ConnectionDetails {
    63  	return p.ConnectionDetails
    64  }
    65  
    66  func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
    67  	keyType, err := model.PrimaryKeyType()
    68  	if err != nil {
    69  		return err
    70  	}
    71  	switch keyType {
    72  	case "int", "int64":
    73  		cols.Remove(model.IDField())
    74  		w := cols.Writeable()
    75  		var query string
    76  		if len(w.Cols) > 0 {
    77  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    78  		} else {
    79  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
    80  		}
    81  		log(logging.SQL, query, model.Value)
    82  		stmt, err := s.PrepareNamed(query)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		id := map[string]interface{}{}
    87  		err = stmt.QueryRow(model.Value).MapScan(id)
    88  		if err != nil {
    89  			if closeErr := stmt.Close(); closeErr != nil {
    90  				return fmt.Errorf("failed to close prepared statement: %s: %w", closeErr, err)
    91  			}
    92  			return err
    93  		}
    94  		model.setID(id[model.IDField()])
    95  		if err := stmt.Close(); err != nil {
    96  			return fmt.Errorf("failed to close statement: %w", err)
    97  		}
    98  		return nil
    99  	}
   100  	return genericCreate(s, model, cols, p)
   101  }
   102  
   103  func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
   104  	return genericUpdate(s, model, cols, p)
   105  }
   106  
   107  func (p *cockroach) Destroy(s store, model *Model) error {
   108  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
   109  	_, err := genericExec(s, stmt, model.ID())
   110  	return err
   111  }
   112  
   113  func (p *cockroach) Delete(s store, model *Model, query Query) error {
   114  	return genericDelete(s, model, query)
   115  }
   116  
   117  func (p *cockroach) SelectOne(s store, model *Model, query Query) error {
   118  	return genericSelectOne(s, model, query)
   119  }
   120  
   121  func (p *cockroach) SelectMany(s store, models *Model, query Query) error {
   122  	return genericSelectMany(s, models, query)
   123  }
   124  
   125  func (p *cockroach) CreateDB() error {
   126  	// createdb -h db -p 5432 -U cockroach enterprise_development
   127  	deets := p.ConnectionDetails
   128  
   129  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   130  	if err != nil {
   131  		return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err)
   132  	}
   133  	defer db.Close()
   134  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   135  	log(logging.SQL, query)
   136  
   137  	_, err = db.Exec(query)
   138  	if err != nil {
   139  		return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err)
   140  	}
   141  
   142  	log(logging.Info, "created database %s", deets.Database)
   143  	return nil
   144  }
   145  
   146  func (p *cockroach) DropDB() error {
   147  	deets := p.ConnectionDetails
   148  
   149  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   150  	if err != nil {
   151  		return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err)
   152  	}
   153  	defer db.Close()
   154  	query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database))
   155  	log(logging.SQL, query)
   156  
   157  	_, err = db.Exec(query)
   158  	if err != nil {
   159  		return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err)
   160  	}
   161  
   162  	log(logging.Info, "dropped database %s", deets.Database)
   163  	return nil
   164  }
   165  
   166  func (p *cockroach) URL() string {
   167  	c := p.ConnectionDetails
   168  	if c.URL != "" {
   169  		return c.URL
   170  	}
   171  	s := "postgres://%s:%s@%s:%s/%s?%s"
   172  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.Database, c.OptionsString(""))
   173  }
   174  
   175  func (p *cockroach) urlWithoutDb() string {
   176  	c := p.ConnectionDetails
   177  	s := "postgres://%s:%s@%s:%s/?%s"
   178  	return fmt.Sprintf(s, c.User, c.Password, c.Host, c.Port, c.OptionsString(""))
   179  }
   180  
   181  func (p *cockroach) MigrationURL() string {
   182  	return p.URL()
   183  }
   184  
   185  func (p *cockroach) TranslateSQL(sql string) string {
   186  	defer p.mu.Unlock()
   187  	p.mu.Lock()
   188  
   189  	if csql, ok := p.translateCache[sql]; ok {
   190  		return csql
   191  	}
   192  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   193  
   194  	p.translateCache[sql] = csql
   195  	return csql
   196  }
   197  
   198  func (p *cockroach) FizzTranslator() fizz.Translator {
   199  	return translators.NewCockroach(p.URL(), p.Details().Database)
   200  }
   201  
   202  func (p *cockroach) DumpSchema(w io.Writer) error {
   203  	cmd := exec.Command("cockroach", "dump", p.Details().Database, "--dump-mode=schema")
   204  
   205  	c := p.ConnectionDetails
   206  	if defaults.String(c.Options["sslmode"], "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") {
   207  		cmd.Args = append(cmd.Args, "--insecure")
   208  	}
   209  	return genericDumpSchema(p.Details(), cmd, w)
   210  }
   211  
   212  func (p *cockroach) LoadSchema(r io.Reader) error {
   213  	return genericLoadSchema(p, r)
   214  }
   215  
   216  func (p *cockroach) TruncateAll(tx *Connection) error {
   217  	type table struct {
   218  		TableName string `db:"table_name"`
   219  	}
   220  
   221  	tableQuery := p.tablesQuery()
   222  
   223  	var tables []table
   224  	if err := tx.RawQuery(tableQuery, tx.MigrationTableName(), tx.Dialect.Details().Database).All(&tables); err != nil {
   225  		return err
   226  	}
   227  
   228  	if len(tables) == 0 {
   229  		return nil
   230  	}
   231  
   232  	tableNames := make([]string, len(tables))
   233  	for i, t := range tables {
   234  		tableNames[i] = t.TableName
   235  		//! work around for current limitation of DDL and DML at the same transaction.
   236  		//  it should be fixed when cockroach support it or with other approach.
   237  		//  https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions
   238  		if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil {
   239  			return err
   240  		}
   241  	}
   242  	return nil
   243  	// TODO!
   244  	// return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec()
   245  }
   246  
   247  func (p *cockroach) AfterOpen(c *Connection) error {
   248  	if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil {
   249  		return err
   250  	}
   251  	if s := strings.Split(p.info.VersionString, " "); len(s) > 3 {
   252  		p.info.product = s[0]
   253  		p.info.license = s[1]
   254  		p.info.version = s[2]
   255  		p.info.buildInfo = s[3]
   256  	}
   257  	log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version)
   258  
   259  	return nil
   260  }
   261  
   262  func newCockroach(deets *ConnectionDetails) (dialect, error) {
   263  	deets.Dialect = "postgres"
   264  	d := &cockroach{
   265  		commonDialect:  commonDialect{ConnectionDetails: deets},
   266  		translateCache: map[string]string{},
   267  		mu:             sync.Mutex{},
   268  	}
   269  	d.info.client = deets.Options["application_name"]
   270  	return d, nil
   271  }
   272  
   273  func finalizerCockroach(cd *ConnectionDetails) {
   274  	appName := filepath.Base(os.Args[0])
   275  	cd.Options["application_name"] = defaults.String(cd.Options["application_name"], appName)
   276  	cd.Port = defaults.String(cd.Port, portCockroach)
   277  	if cd.URL != "" {
   278  		cd.URL = "postgres://" + trimCockroachPrefix(cd.URL)
   279  	}
   280  }
   281  
   282  func trimCockroachPrefix(u string) string {
   283  	parts := strings.Split(u, "://")
   284  	if len(parts) != 2 {
   285  		return u
   286  	}
   287  	return parts[1]
   288  }
   289  
   290  func (p *cockroach) tablesQuery() string {
   291  	// See https://www.cockroachlabs.com/docs/stable/information-schema.html for more info about information schema changes
   292  	tableQuery := selectTablesQueryCockroach
   293  	if strings.HasPrefix(p.info.version, "v1.") {
   294  		tableQuery = selectTablesQueryCockroachV1
   295  	}
   296  	return tableQuery
   297  }