github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/clickhouse/clickhouse.go (about)

     1  package clickhouse
     2  
import (
	"bytes"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net/url"
	"regexp"
	"sort"
	"strings"

	"github.com/amacneil/dbmate/pkg/dbmate"
	"github.com/amacneil/dbmate/pkg/dbutil"

	"github.com/ClickHouse/clickhouse-go"
)
    18  
    19  func init() {
    20  	dbmate.RegisterDriver(NewDriver, "clickhouse")
    21  }
    22  
// Driver provides top level database functions for ClickHouse.
// Instances are built by NewDriver from a dbmate.DriverConfig.
type Driver struct {
	// migrationsTableName is the unquoted name of the table that records
	// applied migrations; SQL statements use quotedMigrationsTableName.
	migrationsTableName string
	// databaseURL is the URL supplied by the user. It is converted with
	// connectionString before being handed to sql.Open.
	databaseURL         *url.URL
	// log receives progress messages ("Creating: ...", "Dropping: ...").
	log                 io.Writer
}
    29  
    30  // NewDriver initializes the driver
    31  func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
    32  	return &Driver{
    33  		migrationsTableName: config.MigrationsTableName,
    34  		databaseURL:         config.DatabaseURL,
    35  		log:                 config.Log,
    36  	}
    37  }
    38  
    39  func connectionString(initialURL *url.URL) string {
    40  	u := *initialURL
    41  
    42  	u.Scheme = "tcp"
    43  	host := u.Host
    44  	if u.Port() == "" {
    45  		host = fmt.Sprintf("%s:9000", host)
    46  	}
    47  	u.Host = host
    48  
    49  	query := u.Query()
    50  	if query.Get("username") == "" && u.User.Username() != "" {
    51  		query.Set("username", u.User.Username())
    52  	}
    53  	password, passwordSet := u.User.Password()
    54  	if query.Get("password") == "" && passwordSet {
    55  		query.Set("password", password)
    56  	}
    57  	u.User = nil
    58  
    59  	if query.Get("database") == "" {
    60  		path := strings.Trim(u.Path, "/")
    61  		if path != "" {
    62  			query.Set("database", path)
    63  			u.Path = ""
    64  		}
    65  	}
    66  	u.RawQuery = query.Encode()
    67  
    68  	return u.String()
    69  }
    70  
    71  // Open creates a new database connection
    72  func (drv *Driver) Open() (*sql.DB, error) {
    73  	return sql.Open("clickhouse", connectionString(drv.databaseURL))
    74  }
    75  
    76  func (drv *Driver) openClickHouseDB() (*sql.DB, error) {
    77  	// clone databaseURL
    78  	clickhouseURL, err := url.Parse(connectionString(drv.databaseURL))
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	// connect to clickhouse database
    84  	values := clickhouseURL.Query()
    85  	values.Set("database", "default")
    86  	clickhouseURL.RawQuery = values.Encode()
    87  
    88  	return sql.Open("clickhouse", clickhouseURL.String())
    89  }
    90  
    91  func (drv *Driver) databaseName() string {
    92  	name := dbutil.MustParseURL(connectionString(drv.databaseURL)).Query().Get("database")
    93  	if name == "" {
    94  		name = "default"
    95  	}
    96  	return name
    97  }
    98  
    99  var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
   100  
   101  func (drv *Driver) quoteIdentifier(str string) string {
   102  	if clickhouseValidIdentifier.MatchString(str) {
   103  		return str
   104  	}
   105  
   106  	str = strings.Replace(str, `"`, `""`, -1)
   107  
   108  	return fmt.Sprintf(`"%s"`, str)
   109  }
   110  
   111  // CreateDatabase creates the specified database
   112  func (drv *Driver) CreateDatabase() error {
   113  	name := drv.databaseName()
   114  	fmt.Fprintf(drv.log, "Creating: %s\n", name)
   115  
   116  	db, err := drv.openClickHouseDB()
   117  	if err != nil {
   118  		return err
   119  	}
   120  	defer dbutil.MustClose(db)
   121  
   122  	_, err = db.Exec("create database " + drv.quoteIdentifier(name))
   123  
   124  	return err
   125  }
   126  
   127  // DropDatabase drops the specified database (if it exists)
   128  func (drv *Driver) DropDatabase() error {
   129  	name := drv.databaseName()
   130  	fmt.Fprintf(drv.log, "Dropping: %s\n", name)
   131  
   132  	db, err := drv.openClickHouseDB()
   133  	if err != nil {
   134  		return err
   135  	}
   136  	defer dbutil.MustClose(db)
   137  
   138  	_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
   139  
   140  	return err
   141  }
   142  
   143  func (drv *Driver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
   144  	buf.WriteString("\n--\n-- Database schema\n--\n\n")
   145  
   146  	buf.WriteString("CREATE DATABASE IF NOT EXISTS" + drv.quoteIdentifier(databaseName) + ";\n\n")
   147  
   148  	tables, err := dbutil.QueryColumn(db, "show tables")
   149  	if err != nil {
   150  		return err
   151  	}
   152  	sort.Strings(tables)
   153  
   154  	for _, table := range tables {
   155  		var clause string
   156  		err = db.QueryRow("show create table " + drv.quoteIdentifier(table)).Scan(&clause)
   157  		if err != nil {
   158  			return err
   159  		}
   160  		buf.WriteString(clause + ";\n\n")
   161  	}
   162  	return nil
   163  }
   164  
   165  func (drv *Driver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
   166  	migrationsTable := drv.quotedMigrationsTableName()
   167  
   168  	// load applied migrations
   169  	migrations, err := dbutil.QueryColumn(db,
   170  		fmt.Sprintf("select version from %s final ", migrationsTable)+
   171  			"where applied order by version asc",
   172  	)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	quoter := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
   178  	for i := range migrations {
   179  		migrations[i] = "'" + quoter.Replace(migrations[i]) + "'"
   180  	}
   181  
   182  	// build schema migrations table data
   183  	buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
   184  
   185  	if len(migrations) > 0 {
   186  		buf.WriteString(
   187  			fmt.Sprintf("INSERT INTO %s (version) VALUES\n    (", migrationsTable) +
   188  				strings.Join(migrations, "),\n    (") +
   189  				");\n")
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  // DumpSchema returns the current database schema
   196  func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
   197  	var buf bytes.Buffer
   198  	var err error
   199  
   200  	err = drv.schemaDump(db, &buf, drv.databaseName())
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	err = drv.schemaMigrationsDump(db, &buf)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	return buf.Bytes(), nil
   211  }
   212  
   213  // DatabaseExists determines whether the database exists
   214  func (drv *Driver) DatabaseExists() (bool, error) {
   215  	name := drv.databaseName()
   216  
   217  	db, err := drv.openClickHouseDB()
   218  	if err != nil {
   219  		return false, err
   220  	}
   221  	defer dbutil.MustClose(db)
   222  
   223  	exists := false
   224  	err = db.QueryRow("SELECT 1 FROM system.databases where name = ?", name).
   225  		Scan(&exists)
   226  	if err == sql.ErrNoRows {
   227  		return false, nil
   228  	}
   229  
   230  	return exists, err
   231  }
   232  
   233  // MigrationsTableExists checks if the schema_migrations table exists
   234  func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) {
   235  	exists := false
   236  	err := db.QueryRow(fmt.Sprintf("EXISTS TABLE %s", drv.quotedMigrationsTableName())).
   237  		Scan(&exists)
   238  	if err == sql.ErrNoRows {
   239  		return false, nil
   240  	}
   241  
   242  	return exists, err
   243  }
   244  
   245  // CreateMigrationsTable creates the schema migrations table
   246  func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
   247  	_, err := db.Exec(fmt.Sprintf(`
   248  		create table if not exists %s (
   249  			version String,
   250  			ts DateTime default now(),
   251  			applied UInt8 default 1
   252  		) engine = ReplacingMergeTree(ts)
   253  		primary key version
   254  		order by version
   255  	`, drv.quotedMigrationsTableName()))
   256  
   257  	return err
   258  }
   259  
   260  // SelectMigrations returns a list of applied migrations
   261  // with an optional limit (in descending order)
   262  func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
   263  	query := fmt.Sprintf("select version from %s final where applied order by version desc",
   264  		drv.quotedMigrationsTableName())
   265  
   266  	if limit >= 0 {
   267  		query = fmt.Sprintf("%s limit %d", query, limit)
   268  	}
   269  	rows, err := db.Query(query)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	defer dbutil.MustClose(rows)
   275  
   276  	migrations := map[string]bool{}
   277  	for rows.Next() {
   278  		var version string
   279  		if err := rows.Scan(&version); err != nil {
   280  			return nil, err
   281  		}
   282  
   283  		migrations[version] = true
   284  	}
   285  
   286  	if err = rows.Err(); err != nil {
   287  		return nil, err
   288  	}
   289  
   290  	return migrations, nil
   291  }
   292  
   293  // InsertMigration adds a new migration record
   294  func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
   295  	_, err := db.Exec(
   296  		fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
   297  		version)
   298  
   299  	return err
   300  }
   301  
   302  // DeleteMigration removes a migration record
   303  func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
   304  	_, err := db.Exec(
   305  		fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
   306  			drv.quotedMigrationsTableName()),
   307  		version, false,
   308  	)
   309  
   310  	return err
   311  }
   312  
   313  // Ping verifies a connection to the database server. It does not verify whether the
   314  // specified database exists.
   315  func (drv *Driver) Ping() error {
   316  	// attempt connection to primary database, not "clickhouse" database
   317  	// to support servers with no "clickhouse" database
   318  	// (see https://github.com/amacneil/dbmate/issues/78)
   319  	db, err := drv.Open()
   320  	if err != nil {
   321  		return err
   322  	}
   323  	defer dbutil.MustClose(db)
   324  
   325  	err = db.Ping()
   326  	if err == nil {
   327  		return nil
   328  	}
   329  
   330  	// ignore 'Database foo doesn't exist' error
   331  	chErr, ok := err.(*clickhouse.Exception)
   332  	if ok && chErr.Code == 81 {
   333  		return nil
   334  	}
   335  
   336  	return err
   337  }
   338  
   339  func (drv *Driver) quotedMigrationsTableName() string {
   340  	return drv.quoteIdentifier(drv.migrationsTableName)
   341  }