github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/clone/logdb/database.go (about)

     1  // Copyright 2021 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package logdb contains read/write access to the locally cloned data.
    16  package logdb
    17  
    18  import (
    19  	"bytes"
    20  	"context"
    21  	"database/sql"
    22  	"encoding/gob"
    23  	"errors"
    24  	"fmt"
    25  
    26  	"github.com/golang/glog"
    27  )
    28  
    29  // ErrNoDataFound is returned when the DB appears valid but has no data in it.
    30  var ErrNoDataFound = errors.New("no data found")
    31  
    32  // Database provides read/write access to the mirrored log.
    33  type Database struct {
    34  	db *sql.DB
    35  }
    36  
    37  // NewDatabase creates a Database using the given database connection string.
    38  // This has been tested with sqlite and MariaDB.
    39  func NewDatabase(connString string) (*Database, error) {
    40  	dbConn, err := sql.Open("mysql", connString)
    41  	if err != nil {
    42  		return nil, fmt.Errorf("sql.Open: %w", err)
    43  	}
    44  	db := &Database{
    45  		db: dbConn,
    46  	}
    47  	return db, db.init()
    48  }
    49  
    50  // NewDatabaseDirect creates a Database using the given database connection.
    51  func NewDatabaseDirect(db *sql.DB) (*Database, error) {
    52  	ret := &Database{
    53  		db: db,
    54  	}
    55  	return ret, ret.init()
    56  }
    57  
    58  func (d *Database) init() error {
    59  	if _, err := d.db.Exec("CREATE TABLE IF NOT EXISTS leaves (id INTEGER PRIMARY KEY, data BLOB)"); err != nil {
    60  		return err
    61  	}
    62  	if _, err := d.db.Exec("CREATE TABLE IF NOT EXISTS checkpoints (size INTEGER PRIMARY KEY, data BLOB, compactRange BLOB)"); err != nil {
    63  		return err
    64  	}
    65  	return nil
    66  }
    67  
    68  // WriteCheckpoint writes the checkpoint for the given tree size.
    69  // This should have been verified before writing.
    70  func (d *Database) WriteCheckpoint(ctx context.Context, size uint64, checkpoint []byte, compactRange [][]byte) error {
    71  	tx, err := d.db.BeginTx(ctx, nil)
    72  	if err != nil {
    73  		return fmt.Errorf("BeginTx(): %v", err)
    74  	}
    75  
    76  	row := tx.QueryRowContext(ctx, "SELECT size FROM checkpoints ORDER BY size DESC LIMIT 1")
    77  	var max uint64
    78  	if err := row.Scan(&max); err != nil {
    79  		if err != sql.ErrNoRows {
    80  			if err := tx.Rollback(); err != nil {
    81  				glog.Errorf("tx.Rollback(): %v", err)
    82  			}
    83  			return fmt.Errorf("Scan(): %v", err)
    84  		}
    85  	}
    86  
    87  	if size <= max {
    88  		if err := tx.Rollback(); err != nil {
    89  			glog.Errorf("tx.Rollback(): %v", err)
    90  		}
    91  		return nil
    92  	}
    93  
    94  	var srs bytes.Buffer
    95  	enc := gob.NewEncoder(&srs)
    96  	if err := enc.Encode(compactRange); err != nil {
    97  		if err := tx.Rollback(); err != nil {
    98  			glog.Errorf("tx.Rollback(): %v", err)
    99  		}
   100  		return fmt.Errorf("Encode(): %v", err)
   101  	}
   102  	if _, err := tx.ExecContext(ctx, "INSERT INTO checkpoints (size, data, compactRange) VALUES (?, ?, ?)", size, checkpoint, srs.Bytes()); err != nil {
   103  		glog.Errorf("tx.ExecContext(): %v", err)
   104  	}
   105  	return tx.Commit()
   106  }
   107  
   108  // GetLatestCheckpoint gets the details of the latest checkpoint.
   109  func (d *Database) GetLatestCheckpoint(ctx context.Context) (size uint64, checkpoint []byte, compactRange [][]byte, err error) {
   110  	row := d.db.QueryRowContext(ctx, "SELECT size, data, compactRange FROM checkpoints ORDER BY size DESC LIMIT 1")
   111  	srs := make([]byte, 0)
   112  	if err := row.Scan(&size, &checkpoint, &srs); err != nil {
   113  		if err == sql.ErrNoRows {
   114  			return 0, nil, nil, ErrNoDataFound
   115  		}
   116  		return 0, nil, nil, fmt.Errorf("Scan(): %v", err)
   117  	}
   118  	dec := gob.NewDecoder(bytes.NewReader(srs))
   119  	if err := dec.Decode(&compactRange); err != nil {
   120  		return 0, nil, nil, fmt.Errorf("Decode(): %v", err)
   121  	}
   122  	return
   123  }
   124  
   125  // WriteLeaves writes the contiguous chunk of leaves, starting at the stated index.
   126  // This is an atomic operation, and will fail if any leaf cannot be inserted.
   127  func (d *Database) WriteLeaves(ctx context.Context, start uint64, leaves [][]byte) error {
   128  	tx, err := d.db.BeginTx(ctx, nil)
   129  	if err != nil {
   130  		return fmt.Errorf("BeginTx: %w", err)
   131  	}
   132  	for li, l := range leaves {
   133  		lidx := uint64(li) + start
   134  		if _, err := tx.Exec("INSERT INTO leaves (id, data) VALUES (?, ?)", lidx, l); err != nil {
   135  			glog.Errorf("tx.Exec(): %v", err)
   136  		}
   137  	}
   138  	return tx.Commit()
   139  }
   140  
   141  // StreamLeaves streams leaves in order starting at the given index, putting the leaf preimage
   142  // values on the `out` channel. This takes ownership of the out channel and closes it when no
   143  // more data will be returned.
   144  func (d *Database) StreamLeaves(ctx context.Context, start, end uint64, out chan<- StreamResult) {
   145  	defer close(out)
   146  	rows, err := d.db.QueryContext(ctx, "SELECT data FROM leaves WHERE id>=? AND id < ? ORDER BY id", start, end)
   147  	if err != nil {
   148  		out <- StreamResult{Err: err}
   149  		return
   150  	}
   151  	defer func() {
   152  		if err := rows.Close(); err != nil {
   153  			glog.Errorf("rows.Close(): %v", err)
   154  		}
   155  	}()
   156  	for rows.Next() {
   157  		var data []byte
   158  		if err := rows.Scan(&data); err != nil {
   159  			out <- StreamResult{Err: err}
   160  			return
   161  		}
   162  		out <- StreamResult{Leaf: data}
   163  	}
   164  	if err := rows.Err(); err != nil {
   165  		out <- StreamResult{Err: err}
   166  	}
   167  }
   168  
   169  // Head returns the largest leaf index written.
   170  func (d *Database) Head() (int64, error) {
   171  	var head sql.NullInt64
   172  	if err := d.db.QueryRow("SELECT MAX(id) AS head FROM leaves").Scan(&head); err != nil {
   173  		if err == sql.ErrNoRows {
   174  			return 0, ErrNoDataFound
   175  		}
   176  		return 0, fmt.Errorf("failed to get max revision: %w", err)
   177  	}
   178  	if head.Valid {
   179  		return head.Int64, nil
   180  	}
   181  	return 0, ErrNoDataFound
   182  }
   183  
   184  // StreamResult is the return type for StreamLeaves. It allows the leaves to
   185  // be returned in the same channel as any errors. Only one of Leaf or Err will
   186  // be populated in any StreamResult.
   187  type StreamResult struct {
   188  	Leaf []byte
   189  	Err  error
   190  }