github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/util.go (about) 1 package sqlitepool 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/tailscale/sqlite/sqliteh" 8 ) 9 10 // CopyAll copies the contents of one database to another. 11 // 12 // Traditionally this is done in sqlite by closing the database and copying 13 // the file. However it can be useful to do it online: a single exclusive 14 // transaction can cross multiple databases, and if multiple processes are 15 // using a file, this lets one replace the database without first 16 // communicating with the other processes, asking them to close the DB first. 17 // 18 // The dstSchemaName and srcSchemaName parameters follow the SQLite PRAMGA 19 // schema-name conventions: https://sqlite.org/pragma.html#syntax 20 func CopyAll(db sqliteh.DB, dstSchemaName, srcSchemaName string) (err error) { 21 defer func() { 22 if err != nil { 23 err = fmt.Errorf("sqlitepool.CopyAll: %w", err) 24 } 25 }() 26 if dstSchemaName == "" { 27 dstSchemaName = "main" 28 } 29 if srcSchemaName == "" { 30 srcSchemaName = "main" 31 } 32 if dstSchemaName == srcSchemaName { 33 return fmt.Errorf("source matches destination: %q", srcSchemaName) 34 } 35 // Filter on sql to avoid auto indexes. 36 // See https://www.sqlite.org/schematab.html for sqlite_schema docs. 37 rows, err := Query(db, fmt.Sprintf("SELECT name, type, sql FROM %q.sqlite_schema WHERE sql != ''", srcSchemaName)) 38 if err != nil { 39 return err 40 } 41 defer rows.Close() 42 for rows.Next() { 43 var name, sqlType, sqlText string 44 if err := rows.Scan(&name, &sqlType, &sqlText); err != nil { 45 return err 46 } 47 // Regardless of the case or whitespace used in the original 48 // create statement (or whether or not "if not exists" is used), 49 // the SQL text in the sqlite_schema table always reads: 50 // "CREATE (TABLE|VIEW|INDEX|TRIGGER) name". 51 // We take advantage of that here to rewrite the create 52 // statement for a different schema. 53 switch sqlType { 54 case "index": 55 sqlText = strings.TrimPrefix(sqlText, "CREATE INDEX ") 56 sqlText = fmt.Sprintf("CREATE INDEX %q.%s", dstSchemaName, sqlText) 57 if err := ExecScript(db, sqlText); err != nil { 58 return err 59 } 60 case "table": 61 sqlText = strings.TrimPrefix(sqlText, "CREATE TABLE ") 62 sqlText = fmt.Sprintf("CREATE TABLE %q.%s", dstSchemaName, sqlText) 63 if err := ExecScript(db, sqlText); err != nil { 64 return err 65 } 66 if err := ExecScript(db, fmt.Sprintf("INSERT INTO %q.%q SELECT * FROM %q.%q;", dstSchemaName, name, srcSchemaName, name)); err != nil { 67 return err 68 } 69 case "trigger": 70 sqlText = strings.TrimPrefix(sqlText, "CREATE TRIGGER ") 71 sqlText = fmt.Sprintf("CREATE TRIGGER %q.%s", dstSchemaName, sqlText) 72 if err := ExecScript(db, sqlText); err != nil { 73 return err 74 } 75 case "view": 76 sqlText = strings.TrimPrefix(sqlText, "CREATE VIEW ") 77 sqlText = fmt.Sprintf("CREATE VIEW %q.%s", dstSchemaName, sqlText) 78 if err := ExecScript(db, sqlText); err != nil { 79 return err 80 } 81 default: 82 return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name) 83 } 84 } 85 return rows.Err() 86 }