github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/clickhouse/clickhouse.go (about)

     1  package clickhouse
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"net/url"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/bdollma-te/migrate/v4"
    15  	"github.com/bdollma-te/migrate/v4/database"
    16  	"github.com/bdollma-te/migrate/v4/database/multistmt"
    17  	"github.com/hashicorp/go-multierror"
    18  )
    19  
    20  var (
    21  	multiStmtDelimiter = []byte(";")
    22  
    23  	DefaultMigrationsTable       = "schema_migrations"
    24  	DefaultMigrationsTableEngine = "TinyLog"
    25  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    26  
    27  	ErrNilConfig = fmt.Errorf("no config")
    28  )
    29  
// Config holds the settings of a ClickHouse migration driver.
type Config struct {
	// DatabaseName is the database holding the migrations table. When empty,
	// the driver fills it in with the result of SELECT currentDatabase().
	DatabaseName          string
	// ClusterName, when non-empty, makes the migrations table be created
	// with an ON CLUSTER clause naming this cluster.
	ClusterName           string
	// MigrationsTable is the name of the version-tracking table. When empty,
	// DefaultMigrationsTable is used.
	MigrationsTable       string
	// MigrationsTableEngine is the engine of the version-tracking table. When
	// empty, DefaultMigrationsTableEngine is used. For engines whose name
	// ends in "Tree" an "ORDER BY sequence" clause is appended on creation.
	MigrationsTableEngine string
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// the statements one at a time instead of sending it as a single query.
	MultiStatementEnabled bool
	// MultiStatementMaxSize is the size limit handed to the multi-statement
	// parser (presumably the maximum size of one statement in bytes). Values
	// <= 0 are replaced by DefaultMultiStatementMaxSize.
	MultiStatementMaxSize int
}
    38  
    39  func init() {
    40  	database.Register("clickhouse", &ClickHouse{})
    41  }
    42  
    43  func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
    44  	if config == nil {
    45  		return nil, ErrNilConfig
    46  	}
    47  
    48  	if err := conn.Ping(); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	ch := &ClickHouse{
    53  		conn:   conn,
    54  		config: config,
    55  	}
    56  
    57  	if err := ch.init(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	return ch, nil
    62  }
    63  
// ClickHouse is the migrate database driver for ClickHouse.
type ClickHouse struct {
	// conn is the database handle every query of the driver goes through.
	conn     *sql.DB
	// config holds the driver settings; empty fields are filled with
	// defaults during initialisation.
	config   *Config
	// isLocked is the in-process lock flag flipped by Lock and Unlock.
	isLocked atomic.Bool
}
    69  
    70  func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
    71  	purl, err := url.Parse(dsn)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	q := migrate.FilterCustomQuery(purl)
    76  	q.Scheme = "tcp"
    77  	conn, err := sql.Open("clickhouse", q.String())
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	multiStatementMaxSize := DefaultMultiStatementMaxSize
    83  	if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
    84  		multiStatementMaxSize, err = strconv.Atoi(s)
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  	}
    89  
    90  	migrationsTableEngine := DefaultMigrationsTableEngine
    91  	if s := purl.Query().Get("x-migrations-table-engine"); len(s) > 0 {
    92  		migrationsTableEngine = s
    93  	}
    94  
    95  	ch = &ClickHouse{
    96  		conn: conn,
    97  		config: &Config{
    98  			MigrationsTable:       purl.Query().Get("x-migrations-table"),
    99  			MigrationsTableEngine: migrationsTableEngine,
   100  			DatabaseName:          strings.TrimLeft(purl.Path, "/"),
   101  			ClusterName:           purl.Query().Get("x-cluster-name"),
   102  			MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
   103  			MultiStatementMaxSize: multiStatementMaxSize,
   104  		},
   105  	}
   106  
   107  	if err := ch.init(); err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	return ch, nil
   112  }
   113  
   114  func (ch *ClickHouse) init() error {
   115  	if len(ch.config.DatabaseName) == 0 {
   116  		if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
   117  			return err
   118  		}
   119  	}
   120  
   121  	if len(ch.config.MigrationsTable) == 0 {
   122  		ch.config.MigrationsTable = DefaultMigrationsTable
   123  	}
   124  
   125  	if ch.config.MultiStatementMaxSize <= 0 {
   126  		ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
   127  	}
   128  
   129  	if len(ch.config.MigrationsTableEngine) == 0 {
   130  		ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
   131  	}
   132  
   133  	return ch.ensureVersionTable()
   134  }
   135  
   136  func (ch *ClickHouse) Run(r io.Reader) error {
   137  	if ch.config.MultiStatementEnabled {
   138  		var err error
   139  		if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
   140  			tq := strings.TrimSpace(string(m))
   141  			if tq == "" {
   142  				return true
   143  			}
   144  			if _, e := ch.conn.Exec(string(m)); e != nil {
   145  				err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
   146  				return false
   147  			}
   148  			return true
   149  		}); e != nil {
   150  			return e
   151  		}
   152  		return err
   153  	}
   154  
   155  	migration, err := io.ReadAll(r)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	if _, err := ch.conn.Exec(string(migration)); err != nil {
   161  		return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
   162  	}
   163  
   164  	return nil
   165  }
   166  func (ch *ClickHouse) Version() (int, bool, error) {
   167  	var (
   168  		version int
   169  		dirty   uint8
   170  		query   = "SELECT version, dirty FROM `" + ch.config.MigrationsTable + "` ORDER BY sequence DESC LIMIT 1"
   171  	)
   172  	if err := ch.conn.QueryRow(query).Scan(&version, &dirty); err != nil {
   173  		if err == sql.ErrNoRows {
   174  			return database.NilVersion, false, nil
   175  		}
   176  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   177  	}
   178  	return version, dirty == 1, nil
   179  }
   180  
   181  func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
   182  	var (
   183  		bool = func(v bool) uint8 {
   184  			if v {
   185  				return 1
   186  			}
   187  			return 0
   188  		}
   189  		tx, err = ch.conn.Begin()
   190  	)
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
   196  	stmt, err := tx.Prepare(query)
   197  	if err != nil {
   198  		if errRollback := tx.Rollback(); errRollback != nil {
   199  			return fmt.Errorf("error during prepare statement %w and rollback %s", err, errRollback.Error())
   200  		}
   201  		return err
   202  	}
   203  
   204  	if _, err := stmt.Exec(int64(version), bool(dirty), uint64(time.Now().UnixNano())); err != nil {
   205  		return &database.Error{OrigErr: err, Query: []byte(query)}
   206  	}
   207  
   208  	return tx.Commit()
   209  }
   210  
   211  // ensureVersionTable checks if versions table exists and, if not, creates it.
   212  // Note that this function locks the database, which deviates from the usual
   213  // convention of "caller locks" in the ClickHouse type.
   214  func (ch *ClickHouse) ensureVersionTable() (err error) {
   215  	if err = ch.Lock(); err != nil {
   216  		return err
   217  	}
   218  
   219  	defer func() {
   220  		if e := ch.Unlock(); e != nil {
   221  			if err == nil {
   222  				err = e
   223  			} else {
   224  				err = multierror.Append(err, e)
   225  			}
   226  		}
   227  	}()
   228  
   229  	var (
   230  		table string
   231  		query = "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) + " LIKE '" + ch.config.MigrationsTable + "'"
   232  	)
   233  	// check if migration table exists
   234  	if err := ch.conn.QueryRow(query).Scan(&table); err != nil {
   235  		if err != sql.ErrNoRows {
   236  			return &database.Error{OrigErr: err, Query: []byte(query)}
   237  		}
   238  	} else {
   239  		return nil
   240  	}
   241  
   242  	// if not, create the empty migration table
   243  	if len(ch.config.ClusterName) > 0 {
   244  		query = fmt.Sprintf(`
   245  			CREATE TABLE %s ON CLUSTER %s (
   246  				version    Int64,
   247  				dirty      UInt8,
   248  				sequence   UInt64
   249  			) Engine=%s`, ch.config.MigrationsTable, ch.config.ClusterName, ch.config.MigrationsTableEngine)
   250  	} else {
   251  		query = fmt.Sprintf(`
   252  			CREATE TABLE %s (
   253  				version    Int64,
   254  				dirty      UInt8,
   255  				sequence   UInt64
   256  			) Engine=%s`, ch.config.MigrationsTable, ch.config.MigrationsTableEngine)
   257  	}
   258  
   259  	if strings.HasSuffix(ch.config.MigrationsTableEngine, "Tree") {
   260  		query = fmt.Sprintf(`%s ORDER BY sequence`, query)
   261  	}
   262  
   263  	if _, err := ch.conn.Exec(query); err != nil {
   264  		return &database.Error{OrigErr: err, Query: []byte(query)}
   265  	}
   266  	return nil
   267  }
   268  
   269  func (ch *ClickHouse) Drop() (err error) {
   270  	query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
   271  	tables, err := ch.conn.Query(query)
   272  
   273  	if err != nil {
   274  		return &database.Error{OrigErr: err, Query: []byte(query)}
   275  	}
   276  	defer func() {
   277  		if errClose := tables.Close(); errClose != nil {
   278  			err = multierror.Append(err, errClose)
   279  		}
   280  	}()
   281  
   282  	for tables.Next() {
   283  		var table string
   284  		if err := tables.Scan(&table); err != nil {
   285  			return err
   286  		}
   287  
   288  		query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table)
   289  
   290  		if _, err := ch.conn.Exec(query); err != nil {
   291  			return &database.Error{OrigErr: err, Query: []byte(query)}
   292  		}
   293  	}
   294  	if err := tables.Err(); err != nil {
   295  		return &database.Error{OrigErr: err, Query: []byte(query)}
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  func (ch *ClickHouse) Lock() error {
   302  	if !ch.isLocked.CAS(false, true) {
   303  		return database.ErrLocked
   304  	}
   305  
   306  	return nil
   307  }
   308  func (ch *ClickHouse) Unlock() error {
   309  	if !ch.isLocked.CAS(true, false) {
   310  		return database.ErrNotLocked
   311  	}
   312  
   313  	return nil
   314  }
   315  func (ch *ClickHouse) Close() error { return ch.conn.Close() }
   316  
   317  // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
   318  func quoteIdentifier(name string) string {
   319  	end := strings.IndexRune(name, 0)
   320  	if end > -1 {
   321  		name = name[:end]
   322  	}
   323  	return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
   324  }