github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/clickhouse/clickhouse.go (about) 1 package clickhouse 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "net/url" 8 "strconv" 9 "strings" 10 "time" 11 12 "go.uber.org/atomic" 13 14 "github.com/bdollma-te/migrate/v4" 15 "github.com/bdollma-te/migrate/v4/database" 16 "github.com/bdollma-te/migrate/v4/database/multistmt" 17 "github.com/hashicorp/go-multierror" 18 ) 19 20 var ( 21 multiStmtDelimiter = []byte(";") 22 23 DefaultMigrationsTable = "schema_migrations" 24 DefaultMigrationsTableEngine = "TinyLog" 25 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 26 27 ErrNilConfig = fmt.Errorf("no config") 28 ) 29 30 type Config struct { 31 DatabaseName string 32 ClusterName string 33 MigrationsTable string 34 MigrationsTableEngine string 35 MultiStatementEnabled bool 36 MultiStatementMaxSize int 37 } 38 39 func init() { 40 database.Register("clickhouse", &ClickHouse{}) 41 } 42 43 func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) { 44 if config == nil { 45 return nil, ErrNilConfig 46 } 47 48 if err := conn.Ping(); err != nil { 49 return nil, err 50 } 51 52 ch := &ClickHouse{ 53 conn: conn, 54 config: config, 55 } 56 57 if err := ch.init(); err != nil { 58 return nil, err 59 } 60 61 return ch, nil 62 } 63 64 type ClickHouse struct { 65 conn *sql.DB 66 config *Config 67 isLocked atomic.Bool 68 } 69 70 func (ch *ClickHouse) Open(dsn string) (database.Driver, error) { 71 purl, err := url.Parse(dsn) 72 if err != nil { 73 return nil, err 74 } 75 q := migrate.FilterCustomQuery(purl) 76 q.Scheme = "tcp" 77 conn, err := sql.Open("clickhouse", q.String()) 78 if err != nil { 79 return nil, err 80 } 81 82 multiStatementMaxSize := DefaultMultiStatementMaxSize 83 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 84 multiStatementMaxSize, err = strconv.Atoi(s) 85 if err != nil { 86 return nil, err 87 } 88 } 89 90 migrationsTableEngine := DefaultMigrationsTableEngine 91 if s := purl.Query().Get("x-migrations-table-engine"); len(s) > 0 { 92 migrationsTableEngine = s 93 } 94 95 ch = &ClickHouse{ 96 conn: conn, 97 config: &Config{ 98 MigrationsTable: purl.Query().Get("x-migrations-table"), 99 MigrationsTableEngine: migrationsTableEngine, 100 DatabaseName: strings.TrimLeft(purl.Path, "/"), 101 ClusterName: purl.Query().Get("x-cluster-name"), 102 MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true", 103 MultiStatementMaxSize: multiStatementMaxSize, 104 }, 105 } 106 107 if err := ch.init(); err != nil { 108 return nil, err 109 } 110 111 return ch, nil 112 } 113 114 func (ch *ClickHouse) init() error { 115 if len(ch.config.DatabaseName) == 0 { 116 if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil { 117 return err 118 } 119 } 120 121 if len(ch.config.MigrationsTable) == 0 { 122 ch.config.MigrationsTable = DefaultMigrationsTable 123 } 124 125 if ch.config.MultiStatementMaxSize <= 0 { 126 ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize 127 } 128 129 if len(ch.config.MigrationsTableEngine) == 0 { 130 ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine 131 } 132 133 return ch.ensureVersionTable() 134 } 135 136 func (ch *ClickHouse) Run(r io.Reader) error { 137 if ch.config.MultiStatementEnabled { 138 var err error 139 if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool { 140 tq := strings.TrimSpace(string(m)) 141 if tq == "" { 142 return true 143 } 144 if _, e := ch.conn.Exec(string(m)); e != nil { 145 err = database.Error{OrigErr: e, Err: "migration failed", Query: m} 146 return false 147 } 148 return true 149 }); e != nil { 150 return e 151 } 152 return err 153 } 154 155 migration, err := io.ReadAll(r) 156 if err != nil { 157 return err 158 } 159 160 if _, err := ch.conn.Exec(string(migration)); err != nil { 161 return database.Error{OrigErr: err, Err: "migration failed", Query: migration} 162 } 163 164 return nil 165 } 166 func (ch *ClickHouse) Version() (int, bool, error) { 167 var ( 168 version int 169 dirty uint8 170 query = "SELECT version, dirty FROM `" + ch.config.MigrationsTable + "` ORDER BY sequence DESC LIMIT 1" 171 ) 172 if err := ch.conn.QueryRow(query).Scan(&version, &dirty); err != nil { 173 if err == sql.ErrNoRows { 174 return database.NilVersion, false, nil 175 } 176 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 177 } 178 return version, dirty == 1, nil 179 } 180 181 func (ch *ClickHouse) SetVersion(version int, dirty bool) error { 182 var ( 183 bool = func(v bool) uint8 { 184 if v { 185 return 1 186 } 187 return 0 188 } 189 tx, err = ch.conn.Begin() 190 ) 191 if err != nil { 192 return err 193 } 194 195 query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)" 196 stmt, err := tx.Prepare(query) 197 if err != nil { 198 if errRollback := tx.Rollback(); errRollback != nil { 199 return fmt.Errorf("error during prepare statement %w and rollback %s", err, errRollback.Error()) 200 } 201 return err 202 } 203 204 if _, err := stmt.Exec(int64(version), bool(dirty), uint64(time.Now().UnixNano())); err != nil { 205 return &database.Error{OrigErr: err, Query: []byte(query)} 206 } 207 208 return tx.Commit() 209 } 210 211 // ensureVersionTable checks if versions table exists and, if not, creates it. 212 // Note that this function locks the database, which deviates from the usual 213 // convention of "caller locks" in the ClickHouse type. 214 func (ch *ClickHouse) ensureVersionTable() (err error) { 215 if err = ch.Lock(); err != nil { 216 return err 217 } 218 219 defer func() { 220 if e := ch.Unlock(); e != nil { 221 if err == nil { 222 err = e 223 } else { 224 err = multierror.Append(err, e) 225 } 226 } 227 }() 228 229 var ( 230 table string 231 query = "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) + " LIKE '" + ch.config.MigrationsTable + "'" 232 ) 233 // check if migration table exists 234 if err := ch.conn.QueryRow(query).Scan(&table); err != nil { 235 if err != sql.ErrNoRows { 236 return &database.Error{OrigErr: err, Query: []byte(query)} 237 } 238 } else { 239 return nil 240 } 241 242 // if not, create the empty migration table 243 if len(ch.config.ClusterName) > 0 { 244 query = fmt.Sprintf(` 245 CREATE TABLE %s ON CLUSTER %s ( 246 version Int64, 247 dirty UInt8, 248 sequence UInt64 249 ) Engine=%s`, ch.config.MigrationsTable, ch.config.ClusterName, ch.config.MigrationsTableEngine) 250 } else { 251 query = fmt.Sprintf(` 252 CREATE TABLE %s ( 253 version Int64, 254 dirty UInt8, 255 sequence UInt64 256 ) Engine=%s`, ch.config.MigrationsTable, ch.config.MigrationsTableEngine) 257 } 258 259 if strings.HasSuffix(ch.config.MigrationsTableEngine, "Tree") { 260 query = fmt.Sprintf(`%s ORDER BY sequence`, query) 261 } 262 263 if _, err := ch.conn.Exec(query); err != nil { 264 return &database.Error{OrigErr: err, Query: []byte(query)} 265 } 266 return nil 267 } 268 269 func (ch *ClickHouse) Drop() (err error) { 270 query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) 271 tables, err := ch.conn.Query(query) 272 273 if err != nil { 274 return &database.Error{OrigErr: err, Query: []byte(query)} 275 } 276 defer func() { 277 if errClose := tables.Close(); errClose != nil { 278 err = multierror.Append(err, errClose) 279 } 280 }() 281 282 for tables.Next() { 283 var table string 284 if err := tables.Scan(&table); err != nil { 285 return err 286 } 287 288 query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table) 289 290 if _, err := ch.conn.Exec(query); err != nil { 291 return &database.Error{OrigErr: err, Query: []byte(query)} 292 } 293 } 294 if err := tables.Err(); err != nil { 295 return &database.Error{OrigErr: err, Query: []byte(query)} 296 } 297 298 return nil 299 } 300 301 func (ch *ClickHouse) Lock() error { 302 if !ch.isLocked.CAS(false, true) { 303 return database.ErrLocked 304 } 305 306 return nil 307 } 308 func (ch *ClickHouse) Unlock() error { 309 if !ch.isLocked.CAS(true, false) { 310 return database.ErrNotLocked 311 } 312 313 return nil 314 } 315 func (ch *ClickHouse) Close() error { return ch.conn.Close() } 316 317 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 318 func quoteIdentifier(name string) string { 319 end := strings.IndexRune(name, 0) 320 if end > -1 { 321 name = name[:end] 322 } 323 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 324 }