github.com/wrgl/wrgl@v0.14.0/pkg/ref/sql/store.go (about)

     1  package refsql
     2  
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/wrgl/wrgl/pkg/ref"
	"github.com/wrgl/wrgl/pkg/sqlutil"
)
    13  
    14  var CreateTableStmts = []string{
    15  	`CREATE TABLE refs (
    16  		name TEXT NOT NULL PRIMARY KEY,
    17  		sum  BLOB NOT NULL
    18  	)`,
    19  	`CREATE TABLE transactions (
    20  		id     BLOB NOT NULL PRIMARY KEY,
    21  		status TEXT NOT NULL,
    22  		begin  DATETIME NOT NULL,
    23  		end    DATETIME
    24  	)`,
    25  	`CREATE TABLE reflogs (
    26  		ref         TEXT NOT NULL,
    27  		ordinal     INTEGER NOT NULL,
    28  		oldoid      BLOB,
    29  		newoid      BLOB NOT NULL,
    30  		authorname  TEXT NOT NULL DEFAULT '',
    31  		authoremail TEXT NOT NULL DEFAULT '',
    32  		time        DATETIME NOT NULL,
    33  		action      TEXT NOT NULL DEFAULT '',
    34  		message     TEXT NOT NULL DEFAULT '',
    35  		txid        BLOB,
    36  		PRIMARY KEY (ref, ordinal),
    37  		FOREIGN KEY (ref) REFERENCES refs(name),
    38  		FOREIGN KEY (txid) REFERENCES transactions(id)
    39  	)`,
    40  }
    41  
    42  type Store struct {
    43  	db *sql.DB
    44  }
    45  
    46  func NewStore(db *sql.DB) *Store {
    47  	s := &Store{
    48  		db: db,
    49  	}
    50  	return s
    51  }
    52  
    53  func (s *Store) Set(key string, sum []byte) error {
    54  	_, err := s.db.Exec(`INSERT INTO refs (name, sum) VALUES (?, ?) ON CONFLICT (name) DO UPDATE SET sum=excluded.sum`, key, sum)
    55  	return err
    56  }
    57  
    58  func (s *Store) Get(key string) ([]byte, error) {
    59  	row := s.db.QueryRow(`SELECT sum FROM refs WHERE name = ?`, key)
    60  	sum := make([]byte, 16)
    61  	if err := row.Scan(&sum); err != nil {
    62  		return nil, ref.ErrKeyNotFound
    63  	}
    64  	return sum, nil
    65  }
    66  
    67  func (s *Store) SetWithLog(key string, sum []byte, rl *ref.Reflog) error {
    68  	return sqlutil.RunInTx(s.db, func(tx *sql.Tx) error {
    69  		row := tx.QueryRow(`SELECT sum FROM refs WHERE name = ?`, key)
    70  		oldSum := make([]byte, 16)
    71  		if err := row.Scan(&oldSum); err != nil {
    72  			oldSum = nil
    73  		}
    74  		if _, err := tx.Exec(
    75  			`INSERT INTO refs (name, sum) VALUES (?, ?) ON CONFLICT (name) DO UPDATE SET sum=excluded.sum`,
    76  			key, sum,
    77  		); err != nil {
    78  			return err
    79  		}
    80  		var txid []byte
    81  		if rl.Txid != nil {
    82  			txid = (*rl.Txid)[:]
    83  		}
    84  		if _, err := tx.Exec(
    85  			`INSERT INTO reflogs (
    86  				ref, ordinal, oldoid, newoid, authorname, authoremail, time, action, message, txid
    87  			) VALUES (
    88  				?, (
    89  					SELECT COUNT(*)+1 FROM reflogs WHERE ref = ?
    90  				), ?, ?, ?, ?, ?, ?, ?, ?
    91  			)`,
    92  			key, key, oldSum, sum, rl.AuthorName, rl.AuthorEmail, rl.Time, rl.Action, rl.Message, txid,
    93  		); err != nil {
    94  			return err
    95  		}
    96  		return nil
    97  	})
    98  }
    99  
   100  func (s *Store) Delete(key string) error {
   101  	return sqlutil.RunInTx(s.db, func(tx *sql.Tx) error {
   102  		if _, err := tx.Exec(`DELETE FROM reflogs WHERE ref = ?`, key); err != nil {
   103  			return err
   104  		}
   105  		_, err := tx.Exec(`DELETE FROM refs WHERE name = ?`, key)
   106  		return err
   107  	})
   108  }
   109  
   110  func filterQuery(q string, prefixes []string, notPrefixes []string) (string, []interface{}) {
   111  	conds := []string{}
   112  	args := []interface{}{}
   113  	for _, s := range prefixes {
   114  		conds = append(conds, "name LIKE ?")
   115  		args = append(args, s+"%")
   116  	}
   117  	if len(conds) > 1 {
   118  		conds = []string{fmt.Sprintf("(%s)", strings.Join(conds, " OR "))}
   119  	}
   120  	for _, s := range notPrefixes {
   121  		conds = append(conds, "name NOT LIKE ?")
   122  		args = append(args, s+"%")
   123  	}
   124  	if len(conds) > 0 {
   125  		q = fmt.Sprintf("%s WHERE %s", q, strings.Join(conds, " AND "))
   126  	}
   127  	return q, args
   128  }
   129  
   130  func (s *Store) Filter(prefixes []string, notPrefixes []string) (m map[string][]byte, err error) {
   131  	q, args := filterQuery("SELECT name, sum FROM refs", prefixes, notPrefixes)
   132  	rows, err := s.db.Query(q, args...)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	defer rows.Close()
   137  	m = map[string][]byte{}
   138  	for rows.Next() {
   139  		var name string
   140  		var sum = make([]byte, 16)
   141  		if err = rows.Scan(&name, &sum); err != nil {
   142  			return nil, err
   143  		}
   144  		m[name] = sum
   145  	}
   146  	if err = rows.Err(); err != nil {
   147  		return nil, err
   148  	}
   149  	return m, nil
   150  }
   151  
   152  func (s *Store) FilterKey(prefixes []string, notPrefixes []string) (keys []string, err error) {
   153  	q, args := filterQuery("SELECT name FROM refs", prefixes, notPrefixes)
   154  	rows, err := s.db.Query(fmt.Sprintf(`%s ORDER BY name`, q), args...)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	defer rows.Close()
   159  	for rows.Next() {
   160  		var name string
   161  		if err = rows.Scan(&name); err != nil {
   162  			return nil, err
   163  		}
   164  		keys = append(keys, name)
   165  	}
   166  	if err = rows.Err(); err != nil {
   167  		return nil, err
   168  	}
   169  	return keys, nil
   170  }
   171  
   172  func (s *Store) Rename(oldKey, newKey string) (err error) {
   173  	return sqlutil.RunInTx(s.db, func(tx *sql.Tx) error {
   174  		row := tx.QueryRow(`SELECT sum FROM refs WHERE name = ?`, oldKey)
   175  		sum := make([]byte, 16)
   176  		if err := row.Scan(&sum); err != nil {
   177  			return err
   178  		}
   179  		if _, err := tx.Exec(`INSERT INTO refs (name, sum) VALUES (?, ?)`, newKey, sum); err != nil {
   180  			return err
   181  		}
   182  		if _, err := tx.Exec(`UPDATE reflogs SET ref = ? WHERE ref = ?`, newKey, oldKey); err != nil {
   183  			return err
   184  		}
   185  		if _, err := tx.Exec(`DELETE FROM refs WHERE name = ?`, oldKey); err != nil {
   186  			return err
   187  		}
   188  		return nil
   189  	})
   190  }
   191  
   192  func (s *Store) Copy(srcKey, dstKey string) (err error) {
   193  	return sqlutil.RunInTx(s.db, func(tx *sql.Tx) error {
   194  		if _, err := tx.Exec(
   195  			`INSERT INTO refs (name, sum) VALUES (?, (SELECT sum FROM refs WHERE name = ?))`,
   196  			dstKey, srcKey,
   197  		); err != nil {
   198  			return err
   199  		}
   200  		if _, err := tx.Exec(
   201  			`INSERT INTO reflogs
   202  			SELECT ? AS ref, ordinal, oldoid, newoid, authorname, authoremail, time, action, message, txid
   203  			FROM reflogs WHERE ref = ?`,
   204  			dstKey, srcKey,
   205  		); err != nil {
   206  			return err
   207  		}
   208  		return nil
   209  	})
   210  }
   211  
   212  func (s *Store) LogReader(key string) (ref.ReflogReader, error) {
   213  	row := s.db.QueryRow(`SELECT COUNT(*) FROM reflogs WHERE ref = ?`, key)
   214  	var c int
   215  	if err := row.Scan(&c); err != nil {
   216  		return &ReflogReader{}, nil
   217  	}
   218  	if c == 0 {
   219  		return nil, ref.ErrKeyNotFound
   220  	}
   221  	return &ReflogReader{db: s.db, ref: key, ordinal: c}, nil
   222  }
   223  
   224  func (s *Store) NewTransaction(tx *ref.Transaction) (*uuid.UUID, error) {
   225  	if tx == nil {
   226  		id := uuid.New()
   227  		tx = &ref.Transaction{
   228  			ID:     id,
   229  			Status: ref.TSInProgress,
   230  			Begin:  time.Now(),
   231  		}
   232  	}
   233  	var err error
   234  	if tx.End.IsZero() {
   235  		_, err = s.db.Exec(
   236  			`INSERT INTO transactions (id, status, begin) VALUES (?, ?, ?)`,
   237  			tx.ID[:], tx.Status, tx.Begin,
   238  		)
   239  
   240  	} else {
   241  		_, err = s.db.Exec(
   242  			`INSERT INTO transactions (id, status, begin, end) VALUES (?, ?, ?, ?)`,
   243  			tx.ID[:], tx.Status, tx.Begin, tx.End,
   244  		)
   245  	}
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	return &tx.ID, nil
   250  }
   251  
   252  func (s *Store) GetTransaction(id uuid.UUID) (*ref.Transaction, error) {
   253  	row := s.db.QueryRow(`SELECT status, begin, end FROM transactions WHERE id = ?`, id[:])
   254  	tx := &ref.Transaction{
   255  		ID: id,
   256  	}
   257  	end := sql.NullTime{}
   258  	if err := row.Scan(&tx.Status, &tx.Begin, &end); err != nil {
   259  		return nil, err
   260  	}
   261  	if end.Valid {
   262  		tx.End = end.Time
   263  	}
   264  	return tx, nil
   265  }
   266  
   267  func (s *Store) UpdateTransaction(tx *ref.Transaction) error {
   268  	_, err := s.db.Exec(
   269  		`UPDATE transactions SET status = ?, begin = ?, end = ? WHERE id = ?`,
   270  		tx.Status, tx.Begin, tx.End, tx.ID[:],
   271  	)
   272  	return err
   273  }
   274  
   275  func (s *Store) DeleteTransaction(id uuid.UUID) error {
   276  	return sqlutil.RunInTx(s.db, func(tx *sql.Tx) error {
   277  		row := tx.QueryRow(`SELECT status FROM transactions WHERE id = ?`, id[:])
   278  		var status ref.TransactionStatus
   279  		if err := row.Scan(&status); err != nil {
   280  			return err
   281  		}
   282  		if status == ref.TSCommitted {
   283  			return fmt.Errorf("cannot discard committed transaction")
   284  		}
   285  		_, err := tx.Exec(`DELETE FROM transactions WHERE id = ?`, id[:])
   286  		return err
   287  	})
   288  }
   289  
   290  func (s *Store) GCTransactions(txTTL time.Duration) (ids []uuid.UUID, err error) {
   291  	cutOffTime := time.Now().Add(-txTTL)
   292  	rows, err := s.db.Query(
   293  		`DELETE FROM transactions WHERE status = ? AND begin <= ? RETURNING id`,
   294  		ref.TSInProgress, cutOffTime,
   295  	)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	defer rows.Close()
   300  	for rows.Next() {
   301  		var id uuid.UUID
   302  		if err = rows.Scan(&id); err != nil {
   303  			return nil, err
   304  		}
   305  		ids = append(ids, id)
   306  	}
   307  	if err = rows.Err(); err != nil {
   308  		return nil, err
   309  	}
   310  	return ids, nil
   311  }
   312  
   313  func (s *Store) GetTransactionLogs(txid uuid.UUID) (logs map[string]*ref.Reflog, err error) {
   314  	rows, err := s.db.Query(
   315  		`SELECT ref, oldoid, newoid, authorname, authoremail, time, action, message
   316  		FROM reflogs WHERE txid = ? ORDER BY ref ASC`,
   317  		txid[:],
   318  	)
   319  	if err != nil {
   320  		return
   321  	}
   322  	defer rows.Close()
   323  	var (
   324  		name                    string
   325  		oldOID, newOID          sqlutil.NullBlob
   326  		authorName, authorEmail string
   327  		ts                      time.Time
   328  		action, message         string
   329  	)
   330  	logs = map[string]*ref.Reflog{}
   331  	if err = sqlutil.QueryRows(s.db,
   332  		`SELECT ref, oldoid, newoid, authorname, authoremail, time, action, message
   333  		FROM reflogs WHERE txid = ?`,
   334  		[]interface{}{txid[:]},
   335  		[]interface{}{&name, &oldOID, &newOID, &authorName, &authorEmail, &ts, &action, &message},
   336  		func() error {
   337  			rl := &ref.Reflog{
   338  				Txid:        &txid,
   339  				NewOID:      make([]byte, 16),
   340  				AuthorName:  authorName,
   341  				AuthorEmail: authorEmail,
   342  				Time:        ts,
   343  				Action:      action,
   344  				Message:     message,
   345  			}
   346  			copy(rl.NewOID, newOID.Blob)
   347  			if oldOID.Valid {
   348  				rl.OldOID = make([]byte, 16)
   349  				copy(rl.OldOID, oldOID.Blob)
   350  			}
   351  			logs[name] = rl
   352  			return nil
   353  		},
   354  	); err != nil {
   355  		return nil, err
   356  	}
   357  	return logs, nil
   358  }
   359  
   360  func (s *Store) CountTransactions() (int, error) {
   361  	row := s.db.QueryRow(`SELECT COUNT(*) FROM transactions`)
   362  	var c int
   363  	if err := row.Scan(&c); err != nil {
   364  		return 0, err
   365  	}
   366  	return c, nil
   367  }
   368  
   369  func (s *Store) ListTransactions(offset, limit int) (txs []*ref.Transaction, err error) {
   370  	var (
   371  		id     []byte
   372  		status string
   373  		begin  time.Time
   374  		end    sql.NullTime
   375  	)
   376  	if err = sqlutil.QueryRows(s.db,
   377  		`SELECT id, status, begin, end FROM transactions
   378  		ORDER BY begin DESC LIMIT ? OFFSET ?`,
   379  		[]interface{}{limit, offset},
   380  		[]interface{}{&id, &status, &begin, &end},
   381  		func() error {
   382  			tx := &ref.Transaction{
   383  				Status: ref.TransactionStatus(status),
   384  				Begin:  begin,
   385  			}
   386  			copy(tx.ID[:], id)
   387  			if end.Valid {
   388  				tx.End = end.Time
   389  			}
   390  			txs = append(txs, tx)
   391  			return nil
   392  		},
   393  	); err != nil {
   394  		return nil, err
   395  	}
   396  	return txs, nil
   397  }