github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/util.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"strings"
     8  )
     9  
    // DropAll deletes all the data from a database.
//
// Every index, trigger, view, and table defined by SQL in the schema's
// sqlite_schema table is dropped, in that order (automatic indexes have no
// SQL and are skipped). The internal sqlite_sequence table, which holds
// AUTOINCREMENT counters and which SQLite refuses to drop, is emptied
// instead. If sqlite_schema contains an entry of an unrecognized type, an
// error is returned before anything is dropped.
//
// NOTE(review): with PRAGMA foreign_keys enabled, SQLite performs an implicit
// DELETE when dropping a table, which may fail if a parent table is dropped
// before its children — confirm callers disable foreign keys or don't care.
//
// The schemaName parameter follows the SQLite PRAGMA schema-name conventions:
// https://sqlite.org/pragma.html#syntax
// An empty schemaName means "main".
//
// Errors are prefixed with "sqlitedb.DropAll: " and wrap the underlying error.
func DropAll(ctx context.Context, conn *sql.Conn, schemaName string) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("sqlitedb.DropAll: %w", err)
		}
	}()

	if schemaName == "" {
		schemaName = "main"
	}

	// quote renders s as an SQL identifier. Inside a double-quoted SQLite
	// identifier a literal double quote is written twice; Go's %q verb would
	// emit backslash escapes instead, which SQLite does not recognize.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	schema := quote(schemaName)

	var indexes, tables, triggers, views []string
	clearSequence := false

	// Filter on sql to avoid auto indexes.
	// See https://www.sqlite.org/schematab.html for sqlite_schema docs.
	rows, err := conn.QueryContext(ctx, fmt.Sprintf("SELECT name, type FROM %s.sqlite_schema WHERE sql != ''", schema))
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var name, sqlType string
		if err := rows.Scan(&name, &sqlType); err != nil {
			return err
		}
		switch sqlType {
		case "index":
			indexes = append(indexes, name)
		case "table":
			if strings.EqualFold(name, "sqlite_sequence") {
				// SQLite rejects "DROP TABLE sqlite_sequence"; empty it
				// after the other tables are gone instead.
				clearSequence = true
				continue
			}
			tables = append(tables, name)
		case "trigger":
			triggers = append(triggers, name)
		case "view":
			views = append(views, name)
		default:
			return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name)
		}
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// Release the cursor over sqlite_schema before changing the schema.
	if err := rows.Close(); err != nil {
		return err
	}

	// IF EXISTS because dropping one object can implicitly remove another
	// that was also listed in sqlite_schema (e.g. a virtual table's xDestroy
	// usually drops its shadow tables); that is not an error here.
	var stmts []string
	for _, group := range []struct {
		kind  string
		names []string
	}{
		{"INDEX", indexes},
		{"TRIGGER", triggers},
		{"VIEW", views},
		{"TABLE", tables},
	} {
		for _, name := range group.names {
			stmts = append(stmts, fmt.Sprintf("DROP %s IF EXISTS %s.%s", group.kind, schema, quote(name)))
		}
	}
	if clearSequence {
		stmts = append(stmts, fmt.Sprintf("DELETE FROM %s.sqlite_sequence", schema))
	}
	for _, stmt := range stmts {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}
    79  
    // CopyAll copies the contents of one database to another.
//
// Traditionally this is done in sqlite by closing the database and copying
// the file. However it can be useful to do it online: a single exclusive
// transaction can cross multiple databases, and if multiple processes are
// using a file, this lets one replace the database without first
// communicating with the other processes, asking them to close the DB first.
//
// Every table, index, view, and trigger defined by SQL in the source
// schema's sqlite_schema table is recreated in the destination schema, and
// each table's rows are copied with INSERT ... SELECT *. All tables are
// created and filled before any index, view, or trigger is added, so indexes
// are built once over the copied rows and no copied trigger can fire during
// the copy. The CREATE statements do not use IF NOT EXISTS, so an object
// that already exists in the destination causes an error.
//
// AUTOINCREMENT counters in sqlite_sequence are copied as well. Other
// internal sqlite_* tables (such as sqlite_stat1) and schema entries whose
// SQL has an unrecognized form (such as CREATE VIRTUAL TABLE) are rejected
// with an error before anything is written to the destination.
//
// The dstSchemaName and srcSchemaName parameters follow the SQLite PRAGMA
// schema-name conventions: https://sqlite.org/pragma.html#syntax
// An empty name means "main"; the two names must differ.
//
// Errors are prefixed with "sqlitedb.CopyAll: " and wrap the underlying error.
func CopyAll(ctx context.Context, conn *sql.Conn, dstSchemaName, srcSchemaName string) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("sqlitedb.CopyAll: %w", err)
		}
	}()
	if dstSchemaName == "" {
		dstSchemaName = "main"
	}
	if srcSchemaName == "" {
		srcSchemaName = "main"
	}
	if dstSchemaName == srcSchemaName {
		return fmt.Errorf("source matches destination: %q", srcSchemaName)
	}

	// quote renders s as an SQL identifier. Inside a double-quoted SQLite
	// identifier a literal double quote is written twice; Go's %q verb would
	// emit backslash escapes instead, which SQLite does not recognize.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	dst, src := quote(dstSchemaName), quote(srcSchemaName)

	// retarget rewrites a CREATE statement read from sqlite_schema so that it
	// creates the object in the destination schema. Regardless of the case
	// or whitespace used in the original create statement (or whether or not
	// "if not exists" is used), the SQL text in sqlite_schema starts with a
	// fixed head such as "CREATE TABLE " or "CREATE UNIQUE INDEX " followed
	// directly by the object name, so the schema qualifier is spliced in
	// right after the head. Any other form is rejected rather than turned
	// into malformed SQL.
	retarget := func(name, sqlText string, heads ...string) (string, error) {
		for _, head := range heads {
			if strings.HasPrefix(sqlText, head) {
				return head + dst + "." + sqlText[len(head):], nil
			}
		}
		return "", fmt.Errorf("unsupported schema entry %q: %q", name, sqlText)
	}

	var tableStmts, indexStmts, viewStmts, triggerStmts []string
	copySequence := false

	// Read and validate the whole source schema before writing anything, so
	// an unsupported entry leaves the destination untouched and no statement
	// runs while this cursor is open.
	// Filter on sql to avoid auto indexes.
	// See https://www.sqlite.org/schematab.html for sqlite_schema docs.
	rows, err := conn.QueryContext(ctx, fmt.Sprintf("SELECT name, type, sql FROM %s.sqlite_schema WHERE sql != ''", src))
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var name, sqlType, sqlText string
		if err := rows.Scan(&name, &sqlType, &sqlText); err != nil {
			return err
		}
		switch sqlType {
		case "index":
			stmt, err := retarget(name, sqlText, "CREATE INDEX ", "CREATE UNIQUE INDEX ")
			if err != nil {
				return err
			}
			indexStmts = append(indexStmts, stmt)
		case "table":
			if strings.EqualFold(name, "sqlite_sequence") {
				// SQLite rejects CREATE TABLE for sqlite_* names and creates
				// sqlite_sequence itself along with the first AUTOINCREMENT
				// table, so only its rows are copied.
				copySequence = true
				continue
			}
			if strings.HasPrefix(strings.ToLower(name), "sqlite_") {
				return fmt.Errorf("cannot copy internal table %q", name)
			}
			stmt, err := retarget(name, sqlText, "CREATE TABLE ")
			if err != nil {
				return err
			}
			tableStmts = append(tableStmts, stmt,
				fmt.Sprintf("INSERT INTO %s.%s SELECT * FROM %s.%s;", dst, quote(name), src, quote(name)))
		case "trigger":
			stmt, err := retarget(name, sqlText, "CREATE TRIGGER ")
			if err != nil {
				return err
			}
			triggerStmts = append(triggerStmts, stmt)
		case "view":
			stmt, err := retarget(name, sqlText, "CREATE VIEW ")
			if err != nil {
				return err
			}
			viewStmts = append(viewStmts, stmt)
		default:
			return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name)
		}
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// Release the cursor over sqlite_schema before changing any schema.
	if err := rows.Close(); err != nil {
		return err
	}

	stmts := tableStmts
	if copySequence {
		// The row copies may already have set the destination's counters
		// from the copied rowids; replace them with the source's, which can
		// be higher if rows were deleted.
		stmts = append(stmts,
			fmt.Sprintf("DELETE FROM %s.sqlite_sequence", dst),
			fmt.Sprintf("INSERT INTO %s.sqlite_sequence SELECT * FROM %s.sqlite_sequence", dst, src))
	}
	stmts = append(stmts, indexStmts...)
	stmts = append(stmts, viewStmts...)
	stmts = append(stmts, triggerStmts...)
	for _, stmt := range stmts {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}