github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/util.go (about)

     1  package sqlitepool
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/tailscale/sqlite/sqliteh"
     8  )
     9  
    10  // CopyAll copies the contents of one database to another.
    11  //
    12  // Traditionally this is done in sqlite by closing the database and copying
    13  // the file. However it can be useful to do it online: a single exclusive
    14  // transaction can cross multiple databases, and if multiple processes are
    15  // using a file, this lets one replace the database without first
    16  // communicating with the other processes, asking them to close the DB first.
    17  //
    18  // The dstSchemaName and srcSchemaName parameters follow the SQLite PRAMGA
    19  // schema-name conventions: https://sqlite.org/pragma.html#syntax
    20  func CopyAll(db sqliteh.DB, dstSchemaName, srcSchemaName string) (err error) {
    21  	defer func() {
    22  		if err != nil {
    23  			err = fmt.Errorf("sqlitepool.CopyAll: %w", err)
    24  		}
    25  	}()
    26  	if dstSchemaName == "" {
    27  		dstSchemaName = "main"
    28  	}
    29  	if srcSchemaName == "" {
    30  		srcSchemaName = "main"
    31  	}
    32  	if dstSchemaName == srcSchemaName {
    33  		return fmt.Errorf("source matches destination: %q", srcSchemaName)
    34  	}
    35  	// Filter on sql to avoid auto indexes.
    36  	// See https://www.sqlite.org/schematab.html for sqlite_schema docs.
    37  	rows, err := Query(db, fmt.Sprintf("SELECT name, type, sql FROM %q.sqlite_schema WHERE sql != ''", srcSchemaName))
    38  	if err != nil {
    39  		return err
    40  	}
    41  	defer rows.Close()
    42  	for rows.Next() {
    43  		var name, sqlType, sqlText string
    44  		if err := rows.Scan(&name, &sqlType, &sqlText); err != nil {
    45  			return err
    46  		}
    47  		// Regardless of the case or whitespace used in the original
    48  		// create statement (or whether or not "if not exists" is used),
    49  		// the SQL text in the sqlite_schema table always reads:
    50  		// 	"CREATE (TABLE|VIEW|INDEX|TRIGGER) name".
    51  		// We take advantage of that here to rewrite the create
    52  		// statement for a different schema.
    53  		switch sqlType {
    54  		case "index":
    55  			sqlText = strings.TrimPrefix(sqlText, "CREATE INDEX ")
    56  			sqlText = fmt.Sprintf("CREATE INDEX %q.%s", dstSchemaName, sqlText)
    57  			if err := ExecScript(db, sqlText); err != nil {
    58  				return err
    59  			}
    60  		case "table":
    61  			sqlText = strings.TrimPrefix(sqlText, "CREATE TABLE ")
    62  			sqlText = fmt.Sprintf("CREATE TABLE %q.%s", dstSchemaName, sqlText)
    63  			if err := ExecScript(db, sqlText); err != nil {
    64  				return err
    65  			}
    66  			if err := ExecScript(db, fmt.Sprintf("INSERT INTO %q.%q SELECT * FROM %q.%q;", dstSchemaName, name, srcSchemaName, name)); err != nil {
    67  				return err
    68  			}
    69  		case "trigger":
    70  			sqlText = strings.TrimPrefix(sqlText, "CREATE TRIGGER ")
    71  			sqlText = fmt.Sprintf("CREATE TRIGGER %q.%s", dstSchemaName, sqlText)
    72  			if err := ExecScript(db, sqlText); err != nil {
    73  				return err
    74  			}
    75  		case "view":
    76  			sqlText = strings.TrimPrefix(sqlText, "CREATE VIEW ")
    77  			sqlText = fmt.Sprintf("CREATE VIEW %q.%s", dstSchemaName, sqlText)
    78  			if err := ExecScript(db, sqlText); err != nil {
    79  				return err
    80  			}
    81  		default:
    82  			return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name)
    83  		}
    84  	}
    85  	return rows.Err()
    86  }