github.com/friesencr/pop/v6@v6.1.6/dialect_cockroach.go (about) 1 package pop 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "io" 8 "net/url" 9 "os" 10 "os/exec" 11 "path/filepath" 12 "regexp" 13 "strings" 14 "sync" 15 16 "github.com/friesencr/pop/v6/columns" 17 "github.com/friesencr/pop/v6/internal/defaults" 18 "github.com/friesencr/pop/v6/logging" 19 "github.com/gobuffalo/fizz" 20 "github.com/gobuffalo/fizz/translators" 21 "github.com/gofrs/uuid/v5" 22 _ "github.com/jackc/pgx/v4/stdlib" // Import PostgreSQL driver 23 "github.com/jmoiron/sqlx" 24 ) 25 26 const nameCockroach = "cockroach" 27 const portCockroach = "26257" 28 29 const selectTablesQueryCockroach = "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_name <> ? and table_catalog = ?" 30 const selectTablesQueryCockroachV1 = "select table_name from information_schema.tables where table_name <> ? and table_schema = ?" 31 32 func init() { 33 AvailableDialects = append(AvailableDialects, nameCockroach) 34 dialectSynonyms["cockroachdb"] = nameCockroach 35 dialectSynonyms["crdb"] = nameCockroach 36 finalizer[nameCockroach] = finalizerCockroach 37 newConnection[nameCockroach] = newCockroach 38 } 39 40 var _ dialect = &cockroach{} 41 42 // ServerInfo holds informational data about connected database server. 43 type cockroachInfo struct { 44 VersionString string `db:"version"` 45 product string `db:"-"` 46 license string `db:"-"` 47 version string `db:"-"` 48 buildInfo string `db:"-"` 49 client string `db:"-"` 50 } 51 52 type cockroach struct { 53 commonDialect 54 translateCache map[string]string 55 mu sync.Mutex 56 info cockroachInfo 57 } 58 59 func (p *cockroach) Name() string { 60 return nameCockroach 61 } 62 63 func (p *cockroach) DefaultDriver() string { 64 return "pgx" 65 } 66 67 func (p *cockroach) Details() *ConnectionDetails { 68 return p.ConnectionDetails 69 } 70 71 func (p *cockroach) Create(c *Connection, model *Model, cols columns.Columns) error { 72 keyType, err := model.PrimaryKeyType() 73 if err != nil { 74 return err 75 } 76 switch keyType { 77 case "int", "int64": 78 cols.Remove(model.IDField()) 79 w := cols.Writeable() 80 var query string 81 if len(w.Cols) > 0 { 82 query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField()) 83 } else { 84 query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", p.Quote(model.TableName()), model.IDField()) 85 } 86 txlog(logging.SQL, c, query, model.Value) 87 rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value) 88 if err != nil { 89 return fmt.Errorf("named insert: %w", err) 90 } 91 defer rows.Close() 92 if !rows.Next() { 93 if err := rows.Err(); err != nil { 94 return fmt.Errorf("named insert: next: %w", err) 95 } 96 return fmt.Errorf("named insert: %w", sql.ErrNoRows) 97 } 98 var id interface{} 99 if err := rows.Scan(&id); err != nil { 100 return fmt.Errorf("named insert: scan: %w", err) 101 } 102 if err := rows.Close(); err != nil { 103 return fmt.Errorf("named insert: close: %w", err) 104 } 105 model.setID(id) 106 return nil 107 108 case "UUID": 109 var query string 110 if model.ID() == emptyUUID { 111 cols.Remove(model.IDField()) 112 w := cols.Writeable() 113 if len(w.Cols) > 0 { 114 query = fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (gen_random_uuid(), %s) RETURNING %s", p.Quote(model.TableName()), model.IDField(), w.QuotedString(p), w.SymbolizedString(), model.IDField()) 115 } else { 116 query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (gen_random_uuid()) RETURNING %s", p.Quote(model.TableName()), model.IDField(), model.IDField()) 117 } 118 } else { 119 w := cols.Writeable() 120 w.Add(model.IDField()) 121 query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField()) 122 } 123 txlog(logging.SQL, c, query, model.Value) 124 rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value) 125 if err != nil { 126 return fmt.Errorf("named insert: %w", err) 127 } 128 defer rows.Close() 129 if !rows.Next() { 130 if err := rows.Err(); err != nil { 131 return fmt.Errorf("named insert: next: %w", err) 132 } 133 return fmt.Errorf("named insert: %w", sql.ErrNoRows) 134 } 135 var id uuid.UUID 136 if err := rows.Scan(&id); err != nil { 137 return fmt.Errorf("named insert: scan: %w", err) 138 } 139 if err := rows.Close(); err != nil { 140 return fmt.Errorf("named insert: close: %w", err) 141 } 142 model.setID(id) 143 return nil 144 } 145 return genericCreate(c, model, cols, p) 146 } 147 148 func (p *cockroach) Update(c *Connection, model *Model, cols columns.Columns) error { 149 return genericUpdate(c, model, cols, p) 150 } 151 152 func (p *cockroach) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) { 153 return genericUpdateQuery(c, model, cols, p, query, sqlx.DOLLAR) 154 } 155 156 func (p *cockroach) Destroy(c *Connection, model *Model) error { 157 stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s AS %s WHERE %s", p.Quote(model.TableName()), model.Alias(), model.WhereID())) 158 _, err := genericExec(c, stmt, model.ID()) 159 return err 160 } 161 162 func (p *cockroach) Delete(c *Connection, model *Model, query Query) error { 163 return genericDelete(c, model, query) 164 } 165 166 func (p *cockroach) SelectOne(c *Connection, model *Model, query Query) error { 167 return genericSelectOne(c, model, query) 168 } 169 170 func (p *cockroach) SelectMany(c *Connection, models *Model, query Query) error { 171 return genericSelectMany(c, models, query) 172 } 173 174 func (p *cockroach) CreateDB() error { 175 // createdb -h db -p 5432 -U cockroach enterprise_development 176 deets := p.ConnectionDetails 177 178 db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb()) 179 if err != nil { 180 return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err) 181 } 182 defer db.Close() 183 query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database)) 184 log(logging.SQL, query) 185 186 _, err = db.Exec(query) 187 if err != nil { 188 return fmt.Errorf("error creating Cockroach database %s: %w", deets.Database, err) 189 } 190 191 log(logging.Info, "created database %s", deets.Database) 192 return nil 193 } 194 195 func (p *cockroach) DropDB() error { 196 deets := p.ConnectionDetails 197 198 db, err := openPotentiallyInstrumentedConnection(p, p.urlWithoutDb()) 199 if err != nil { 200 return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err) 201 } 202 defer db.Close() 203 query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database)) 204 log(logging.SQL, query) 205 206 _, err = db.Exec(query) 207 if err != nil { 208 return fmt.Errorf("error dropping Cockroach database %s: %w", deets.Database, err) 209 } 210 211 log(logging.Info, "dropped database %s", deets.Database) 212 return nil 213 } 214 215 func (p *cockroach) URL() string { 216 c := p.ConnectionDetails 217 if c.URL != "" { 218 return c.URL 219 } 220 s := "postgres://%s:%s@%s:%s/%s?%s" 221 return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.Database, c.OptionsString("")) 222 } 223 224 func (p *cockroach) urlWithoutDb() string { 225 c := p.ConnectionDetails 226 s := "postgres://%s:%s@%s:%s/?%s" 227 return fmt.Sprintf(s, c.User, url.QueryEscape(c.Password), c.Host, c.Port, c.OptionsString("")) 228 } 229 230 func (p *cockroach) MigrationURL() string { 231 return p.URL() 232 } 233 234 func (p *cockroach) TranslateSQL(sql string) string { 235 defer p.mu.Unlock() 236 p.mu.Lock() 237 238 if csql, ok := p.translateCache[sql]; ok { 239 return csql 240 } 241 csql := sqlx.Rebind(sqlx.DOLLAR, sql) 242 243 p.translateCache[sql] = csql 244 return csql 245 } 246 247 func (p *cockroach) FizzTranslator() fizz.Translator { 248 return translators.NewCockroach(p.URL(), p.Details().Database) 249 } 250 251 func (p *cockroach) DumpSchema(w io.Writer) error { 252 cmd := exec.Command("cockroach", "sql", "-e", "SHOW CREATE ALL TABLES", "-d", p.Details().Database, "--format", "raw") 253 254 c := p.ConnectionDetails 255 if defaults.String(c.option("sslmode"), "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") { 256 cmd.Args = append(cmd.Args, "--insecure") 257 } 258 return cockroachDumpSchema(p.Details(), cmd, w) 259 } 260 261 func cockroachDumpSchema(deets *ConnectionDetails, cmd *exec.Cmd, w io.Writer) error { 262 log(logging.SQL, strings.Join(cmd.Args, " ")) 263 264 var bb bytes.Buffer 265 266 cmd.Stdout = &bb 267 cmd.Stderr = os.Stderr 268 269 err := cmd.Run() 270 if err != nil { 271 return err 272 } 273 274 // --format raw returns comments prefixed with # which is invalid, so we make it a valid SQL comment. 275 result := regexp.MustCompile("(?m)^#").ReplaceAll(bb.Bytes(), []byte("-- #")) 276 277 if _, err := w.Write(result); err != nil { 278 return err 279 } 280 281 x := bytes.TrimSpace(result) 282 if len(x) == 0 { 283 return fmt.Errorf("unable to dump schema for %s", deets.Database) 284 } 285 286 log(logging.Info, "dumped schema for %s", deets.Database) 287 return nil 288 } 289 290 func (p *cockroach) LoadSchema(r io.Reader) error { 291 return genericLoadSchema(p, r) 292 } 293 294 func (p *cockroach) TruncateAll(tx *Connection) error { 295 type table struct { 296 TableName string `db:"table_name"` 297 } 298 299 tableQuery := p.tablesQuery() 300 301 var tables []table 302 if err := tx.RawQuery(tableQuery, tx.MigrationTableName(), tx.Dialect.Details().Database).All(&tables); err != nil { 303 return err 304 } 305 306 if len(tables) == 0 { 307 return nil 308 } 309 310 tableNames := make([]string, len(tables)) 311 for i, t := range tables { 312 tableNames[i] = t.TableName 313 //! work around for current limitation of DDL and DML at the same transaction. 314 // it should be fixed when cockroach support it or with other approach. 315 // https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions 316 if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil { 317 return err 318 } 319 } 320 return nil 321 // TODO! 322 // return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec() 323 } 324 325 func (p *cockroach) AfterOpen(c *Connection) error { 326 if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil { 327 return err 328 } 329 if s := strings.Split(p.info.VersionString, " "); len(s) > 3 { 330 p.info.product = s[0] 331 p.info.license = s[1] 332 p.info.version = s[2] 333 p.info.buildInfo = s[3] 334 } 335 log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version) 336 337 return nil 338 } 339 340 func newCockroach(deets *ConnectionDetails) (dialect, error) { 341 deets.Dialect = "postgres" 342 d := &cockroach{ 343 commonDialect: commonDialect{ConnectionDetails: deets}, 344 translateCache: map[string]string{}, 345 mu: sync.Mutex{}, 346 } 347 d.info.client = deets.option("application_name") 348 return d, nil 349 } 350 351 func finalizerCockroach(cd *ConnectionDetails) { 352 appName := filepath.Base(os.Args[0]) 353 cd.setOptionWithDefault("application_name", cd.option("application_name"), appName) 354 cd.Port = defaults.String(cd.Port, portCockroach) 355 if cd.URL != "" { 356 cd.URL = "postgres://" + trimCockroachPrefix(cd.URL) 357 } 358 } 359 360 func trimCockroachPrefix(u string) string { 361 parts := strings.Split(u, "://") 362 if len(parts) != 2 { 363 return u 364 } 365 return parts[1] 366 } 367 368 func (p *cockroach) tablesQuery() string { 369 // See https://www.cockroachlabs.com/docs/stable/information-schema.html for more info about information schema changes 370 tableQuery := selectTablesQueryCockroach 371 if strings.HasPrefix(p.info.version, "v1.") { 372 tableQuery = selectTablesQueryCockroachV1 373 } 374 return tableQuery 375 }