github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/util.go (about) 1 package sqlite 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "strings" 8 ) 9 10 // DropAll deletes all the data from a database. 11 // 12 // The schemaName parameter follows the SQLite PRAMGA schema-name conventions: 13 // https://sqlite.org/pragma.html#syntax 14 func DropAll(ctx context.Context, conn *sql.Conn, schemaName string) (err error) { 15 defer func() { 16 if err != nil { 17 err = fmt.Errorf("sqlitedb.DropAll: %w", err) 18 } 19 }() 20 21 if schemaName == "" { 22 schemaName = "main" 23 } 24 25 var indexes, tables, triggers, views []string 26 27 // Filter on sql to avoid auto indexes. 28 // See https://www.sqlite.org/schematab.html for sqlite_schema docs. 29 rows, err := conn.QueryContext(ctx, fmt.Sprintf("SELECT name, type FROM %q.sqlite_schema WHERE sql != ''", schemaName)) 30 if err != nil { 31 return err 32 } 33 defer rows.Close() 34 for rows.Next() { 35 var name, sqlType string 36 if err := rows.Scan(&name, &sqlType); err != nil { 37 return err 38 } 39 switch sqlType { 40 case "index": 41 indexes = append(indexes, name) 42 case "table": 43 tables = append(tables, name) 44 case "trigger": 45 triggers = append(triggers, name) 46 case "view": 47 views = append(views, name) 48 default: 49 return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name) 50 } 51 } 52 rows.Close() 53 if err := rows.Err(); err != nil { 54 return err 55 } 56 57 for _, name := range indexes { 58 if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX %q.%q", schemaName, name)); err != nil { 59 return err 60 } 61 } 62 for _, name := range triggers { 63 if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TRIGGER %q.%q", schemaName, name)); err != nil { 64 return err 65 } 66 } 67 for _, name := range views { 68 if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP VIEW %q.%q", schemaName, name)); err != nil { 69 return err 70 } 71 } 72 for _, name := range tables { 73 if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE %q.%q", schemaName, name)); err != nil { 74 return err 75 } 76 } 77 return nil 78 } 79 80 // CopyAll copies the contents of one database to another. 81 // 82 // Traditionally this is done in sqlite by closing the database and copying 83 // the file. However it can be useful to do it online: a single exclusive 84 // transaction can cross multiple databases, and if multiple processes are 85 // using a file, this lets one replace the database without first 86 // communicating with the other processes, asking them to close the DB first. 87 // 88 // The dstSchemaName and srcSchemaName parameters follow the SQLite PRAMGA 89 // schema-name conventions: https://sqlite.org/pragma.html#syntax 90 func CopyAll(ctx context.Context, conn *sql.Conn, dstSchemaName, srcSchemaName string) (err error) { 91 defer func() { 92 if err != nil { 93 err = fmt.Errorf("sqlitedb.CopyAll: %w", err) 94 } 95 }() 96 if dstSchemaName == "" { 97 dstSchemaName = "main" 98 } 99 if srcSchemaName == "" { 100 srcSchemaName = "main" 101 } 102 if dstSchemaName == srcSchemaName { 103 return fmt.Errorf("source matches destination: %q", srcSchemaName) 104 } 105 // Filter on sql to avoid auto indexes. 106 // See https://www.sqlite.org/schematab.html for sqlite_schema docs. 107 rows, err := conn.QueryContext(ctx, fmt.Sprintf("SELECT name, type, sql FROM %q.sqlite_schema WHERE sql != ''", srcSchemaName)) 108 if err != nil { 109 return err 110 } 111 defer rows.Close() 112 for rows.Next() { 113 var name, sqlType, sqlText string 114 if err := rows.Scan(&name, &sqlType, &sqlText); err != nil { 115 return err 116 } 117 // Regardless of the case or whitespace used in the original 118 // create statement (or whether or not "if not exists" is used), 119 // the SQL text in the sqlite_schema table always reads: 120 // "CREATE (TABLE|VIEW|INDEX|TRIGGER) name". 121 // We take advantage of that here to rewrite the create 122 // statement for a different schema. 123 switch sqlType { 124 case "index": 125 sqlText = strings.TrimPrefix(sqlText, "CREATE INDEX ") 126 sqlText = fmt.Sprintf("CREATE INDEX %q.%s", dstSchemaName, sqlText) 127 if _, err := conn.ExecContext(ctx, sqlText); err != nil { 128 return err 129 } 130 case "table": 131 sqlText = strings.TrimPrefix(sqlText, "CREATE TABLE ") 132 sqlText = fmt.Sprintf("CREATE TABLE %q.%s", dstSchemaName, sqlText) 133 if _, err := conn.ExecContext(ctx, sqlText); err != nil { 134 return err 135 } 136 if _, err := conn.ExecContext(ctx, fmt.Sprintf("INSERT INTO %q.%q SELECT * FROM %q.%q;", dstSchemaName, name, srcSchemaName, name)); err != nil { 137 return err 138 } 139 case "trigger": 140 sqlText = strings.TrimPrefix(sqlText, "CREATE TRIGGER ") 141 sqlText = fmt.Sprintf("CREATE TRIGGER %q.%s", dstSchemaName, sqlText) 142 if _, err := conn.ExecContext(ctx, sqlText); err != nil { 143 return err 144 } 145 case "view": 146 sqlText = strings.TrimPrefix(sqlText, "CREATE VIEW ") 147 sqlText = fmt.Sprintf("CREATE VIEW %q.%s", dstSchemaName, sqlText) 148 if _, err := conn.ExecContext(ctx, sqlText); err != nil { 149 return err 150 } 151 default: 152 return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name) 153 } 154 } 155 return rows.Err() 156 }