github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/dialect_cockroach.go (about)

     1  package pop
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"net/url"
     9  	"os"
    10  	"os/exec"
    11  	"path/filepath"
    12  	"regexp"
    13  	"strings"
    14  	"sync"
    15  
    16  	"github.com/Accefy/pop/internal/defaults"
    17  	"github.com/gobuffalo/fizz"
    18  	"github.com/gobuffalo/fizz/translators"
    19  	"github.com/gobuffalo/pop/v6/columns"
    20  	"github.com/gobuffalo/pop/v6/logging"
    21  	"github.com/gofrs/uuid"
    22  	_ "github.com/jackc/pgx/v4/stdlib" // Import PostgreSQL driver
    23  	"github.com/jmoiron/sqlx"
    24  )
    25  
    26  const nameCockroach = "cockroach"
    27  const portCockroach = "26257"
    28  
    29  const selectTablesQueryCockroach = "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_name <> ? and table_catalog = ?"
    30  const selectTablesQueryCockroachV1 = "select table_name from information_schema.tables where table_name <> ? and table_schema = ?"
    31  
    32  func init() {
    33  	AvailableDialects = append(AvailableDialects, nameCockroach)
    34  	dialectSynonyms["cockroachdb"] = nameCockroach
    35  	dialectSynonyms["crdb"] = nameCockroach
    36  	finalizer[nameCockroach] = finalizerCockroach
    37  	newConnection[nameCockroach] = newCockroach
    38  }
    39  
    40  var _ dialect = &cockroach{}
    41  
    42  // ServerInfo holds informational data about connected database server.
    43  type cockroachInfo struct {
    44  	VersionString string `db:"version"`
    45  	product       string `db:"-"`
    46  	license       string `db:"-"`
    47  	version       string `db:"-"`
    48  	buildInfo     string `db:"-"`
    49  	client        string `db:"-"`
    50  }
    51  
// cockroach is the CockroachDB dialect. It connects through the pgx
// PostgreSQL driver and embeds commonDialect for shared behavior.
type cockroach struct {
	commonDialect
	// translateCache memoizes TranslateSQL results, keyed by the input SQL.
	translateCache map[string]string
	// mu guards translateCache.
	mu             sync.Mutex
	// info describes the connected server; it is populated after the
	// connection is opened.
	info           cockroachInfo
}
    58  
    59  func (p *cockroach) Name() string {
    60  	return nameCockroach
    61  }
    62  
    63  func (p *cockroach) DefaultDriver() string {
    64  	return "pgx"
    65  }
    66  
    67  func (p *cockroach) Details() *ConnectionDetails {
    68  	return p.ConnectionDetails
    69  }
    70  
    71  func (p *cockroach) Create(c *Connection, model *Model, cols columns.Columns) error {
    72  	keyType, err := model.PrimaryKeyType()
    73  	if err != nil {
    74  		return err
    75  	}
    76  	switch keyType {
    77  	case "int", "int64":
    78  		cols.Remove(model.IDField())
    79  		w := cols.Writeable()
    80  		var query string
    81  		if len(w.Cols) > 0 {
    82  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
    83  		} else {
    84  			query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", p.Quote(model.TableName()), model.IDField())
    85  		}
    86  		txlog(logging.SQL, c, query, model.Value)
    87  		rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
    88  		if err != nil {
    89  			return fmt.Errorf("named insert: %w", err)
    90  		}
    91  		defer rows.Close()
    92  		if !rows.Next() {
    93  			if err := rows.Err(); err != nil {
    94  				return fmt.Errorf("named insert: next: %w", err)
    95  			}
    96  			return fmt.Errorf("named insert: %w", sql.ErrNoRows)
    97  		}
    98  		var id interface{}
    99  		if err := rows.Scan(&id); err != nil {
   100  			return fmt.Errorf("named insert: scan: %w", err)
   101  		}
   102  		if err := rows.Close(); err != nil {
   103  			return fmt.Errorf("named insert: close: %w", err)
   104  		}
   105  		model.setID(id)
   106  		return nil
   107  
   108  	case "UUID":
   109  		var query string
   110  		if model.ID() == emptyUUID {
   111  			cols.Remove(model.IDField())
   112  			w := cols.Writeable()
   113  			if len(w.Cols) > 0 {
   114  				query = fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (gen_random_uuid(), %s) RETURNING %s", p.Quote(model.TableName()), model.IDField(), w.QuotedString(p), w.SymbolizedString(), model.IDField())
   115  			} else {
   116  				query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (gen_random_uuid()) RETURNING %s", p.Quote(model.TableName()), model.IDField(), model.IDField())
   117  			}
   118  		} else {
   119  			w := cols.Writeable()
   120  			w.Add(model.IDField())
   121  			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
   122  		}
   123  		txlog(logging.SQL, c, query, model.Value)
   124  		rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
   125  		if err != nil {
   126  			return fmt.Errorf("named insert: %w", err)
   127  		}
   128  		defer rows.Close()
   129  		if !rows.Next() {
   130  			if err := rows.Err(); err != nil {
   131  				return fmt.Errorf("named insert: next: %w", err)
   132  			}
   133  			return fmt.Errorf("named insert: %w", sql.ErrNoRows)
   134  		}
   135  		var id uuid.UUID
   136  		if err := rows.Scan(&id); err != nil {
   137  			return fmt.Errorf("named insert: scan: %w", err)
   138  		}
   139  		if err := rows.Close(); err != nil {
   140  			return fmt.Errorf("named insert: close: %w", err)
   141  		}
   142  		model.setID(id)
   143  		return nil
   144  	}
   145  	return genericCreate(c, model, cols, p)
   146  }
   147  
   148  func (p *cockroach) Update(c *Connection, model *Model, cols columns.Columns) error {
   149  	return genericUpdate(c, model, cols, p)
   150  }
   151  
   152  func (p *cockroach) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) {
   153  	return genericUpdateQuery(c, model, cols, p, query, sqlx.DOLLAR)
   154  }
   155  
   156  func (p *cockroach) Destroy(c *Connection, model *Model) error {
   157  	stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID()))
   158  	_, err := genericExec(c, stmt, model.ID())
   159  	return err
   160  }
   161  
   162  func (p *cockroach) Delete(c *Connection, model *Model, query Query) error {
   163  	return genericDelete(c, model, query)
   164  }
   165  
   166  func (p *cockroach) SelectOne(c *Connection, model *Model, query Query) error {
   167  	return genericSelectOne(c, model, query)
   168  }
   169  
   170  func (p *cockroach) SelectMany(c *Connection, models *Model, query Query) error {
   171  	return genericSelectMany(c, models, query)
   172  }
   173  
   174  func (p *cockroach) CreateDB() error {
   175  	// createdb -h db -p 5432 -U cockroach enterprise_development
   176  	deets := p.ConnectionDetails
   177  
   178  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   179  	if err != nil {
   180  		return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err)
   181  	}
   182  	defer db.Close()
   183  	query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
   184  	log(logging.SQL, query)
   185  
   186  	_, err = db.Exec(query)
   187  	if err != nil {
   188  		return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err)
   189  	}
   190  
   191  	log(logging.Info, "created database %s", deets.Database)
   192  	return nil
   193  }
   194  
   195  func (p *cockroach) DropDB() error {
   196  	deets := p.ConnectionDetails
   197  
   198  	db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb())
   199  	if err != nil {
   200  		return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err)
   201  	}
   202  	defer db.Close()
   203  	query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database))
   204  	log(logging.SQL, query)
   205  
   206  	_, err = db.Exec(query)
   207  	if err != nil {
   208  		return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err)
   209  	}
   210  
   211  	log(logging.Info, "dropped database %s", deets.Database)
   212  	return nil
   213  }
   214  
   215  func (p *cockroach) URL() string {
   216  	c := p.ConnectionDetails
   217  	if c.URL != "" {
   218  		return c.URL
   219  	}
   220  	s := "postgres://%s:%s@%s:%s/%s?%s"
   221  	return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.Database, c.OptionsString(""))
   222  }
   223  
   224  func (p *cockroach) urlWithoutDb() string {
   225  	c := p.ConnectionDetails
   226  	s := "postgres://%s:%s@%s:%s/?%s"
   227  	return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.OptionsString(""))
   228  }
   229  
   230  func (p *cockroach) MigrationURL() string {
   231  	return p.URL()
   232  }
   233  
   234  func (p *cockroach) TranslateSQL(sql string) string {
   235  	defer p.mu.Unlock()
   236  	p.mu.Lock()
   237  
   238  	if csql, ok := p.translateCache[sql]; ok {
   239  		return csql
   240  	}
   241  	csql := sqlx.Rebind(sqlx.DOLLAR, sql)
   242  
   243  	p.translateCache[sql] = csql
   244  	return csql
   245  }
   246  
   247  func (p *cockroach) FizzTranslator() fizz.Translator {
   248  	return translators.NewCockroach(p.URL(), p.Details().Database)
   249  }
   250  
   251  func (p *cockroach) DumpSchema(w io.Writer) error {
   252  	cmd := exec.Command("cockroach", "sql", "-e", "SHOW CREATE ALL TABLES", "-d", p.Details().Database, "--format", "raw")
   253  
   254  	c := p.ConnectionDetails
   255  	if defaults.String(c.option("sslmode"), "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") {
   256  		cmd.Args = append(cmd.Args, "--insecure")
   257  	}
   258  	return cockroachDumpSchema(p.Details(), cmd, w)
   259  }
   260  
   261  func cockroachDumpSchema(deets *ConnectionDetails, cmd *exec.Cmd, w io.Writer) error {
   262  	log(logging.SQL, strings.Join(cmd.Args, " "))
   263  
   264  	var bb bytes.Buffer
   265  
   266  	cmd.Stdout = &bb
   267  	cmd.Stderr = os.Stderr
   268  
   269  	err := cmd.Run()
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	// --format raw returns comments prefixed with # which is invalid, so we make it a valid SQL comment.
   275  	result := regexp.MustCompile("(?m)^#").ReplaceAll(bb.Bytes(), []byte("-- #"))
   276  
   277  	if _, err := w.Write(result); err != nil {
   278  		return err
   279  	}
   280  
   281  	x := bytes.TrimSpace(result)
   282  	if len(x) == 0 {
   283  		return fmt.Errorf("unable to dump schema for %s", deets.Database)
   284  	}
   285  
   286  	log(logging.Info, "dumped schema for %s", deets.Database)
   287  	return nil
   288  }
   289  
   290  func (p *cockroach) LoadSchema(r io.Reader) error {
   291  	return genericLoadSchema(p, r)
   292  }
   293  
   294  func (p *cockroach) TruncateAll(tx *Connection) error {
   295  	type table struct {
   296  		TableName string `db:"table_name"`
   297  	}
   298  
   299  	tableQuery := p.tablesQuery()
   300  
   301  	var tables []table
   302  	if err := tx.RawQuery(tableQuery, tx.MigrationTableName(), tx.Dialect.Details().Database).All(&tables); err != nil {
   303  		return err
   304  	}
   305  
   306  	if len(tables) == 0 {
   307  		return nil
   308  	}
   309  
   310  	tableNames := make([]string, len(tables))
   311  	for i, t := range tables {
   312  		tableNames[i] = t.TableName
   313  		//! work around for current limitation of DDL and DML at the same transaction.
   314  		//  it should be fixed when cockroach support it or with other approach.
   315  		//  https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions
   316  		if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil {
   317  			return err
   318  		}
   319  	}
   320  	return nil
   321  	// TODO!
   322  	// return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec()
   323  }
   324  
   325  func (p *cockroach) AfterOpen(c *Connection) error {
   326  	if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil {
   327  		return err
   328  	}
   329  	if s := strings.Split(p.info.VersionString, " "); len(s) > 3 {
   330  		p.info.product = s[0]
   331  		p.info.license = s[1]
   332  		p.info.version = s[2]
   333  		p.info.buildInfo = s[3]
   334  	}
   335  	log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version)
   336  
   337  	return nil
   338  }
   339  
   340  func newCockroach(deets *ConnectionDetails) (dialect, error) {
   341  	deets.Dialect = "postgres"
   342  	d := &cockroach{
   343  		commonDialect:  commonDialect{ConnectionDetails: deets},
   344  		translateCache: map[string]string{},
   345  		mu:             sync.Mutex{},
   346  	}
   347  	d.info.client = deets.option("application_name")
   348  	return d, nil
   349  }
   350  
   351  func finalizerCockroach(cd *ConnectionDetails) {
   352  	appName := filepath.Base(os.Args[0])
   353  	cd.setOptionWithDefault("application_name", cd.option("application_name"), appName)
   354  	cd.Port = defaults.String(cd.Port, portCockroach)
   355  	if cd.URL != "" {
   356  		cd.URL = "postgres://" + trimCockroachPrefix(cd.URL)
   357  	}
   358  }
   359  
   360  func trimCockroachPrefix(u string) string {
   361  	parts := strings.Split(u, "://")
   362  	if len(parts) != 2 {
   363  		return u
   364  	}
   365  	return parts[1]
   366  }
   367  
   368  func (p *cockroach) tablesQuery() string {
   369  	// See https://www.cockroachlabs.com/docs/stable/information-schema.html for more info about information schema changes
   370  	tableQuery := selectTablesQueryCockroach
   371  	if strings.HasPrefix(p.info.version, "v1.") {
   372  		tableQuery = selectTablesQueryCockroachV1
   373  	}
   374  	return tableQuery
   375  }