github.com/aergoio/aergo@v1.3.1/contract/statesql.go (about)

     1  package contract
     2  
     3  /*
     4  #include "sqlite3-binding.h"
     5  */
     6  import "C"
     7  import (
     8  	"context"
     9  	"database/sql"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"github.com/aergoio/aergo/internal/enc"
    14  	"os"
    15  	"path/filepath"
    16  	"sync"
    17  
    18  	"github.com/aergoio/aergo-lib/log"
    19  	"github.com/aergoio/aergo/state"
    20  	"github.com/aergoio/aergo/types"
    21  )
    22  
var (
	// ErrDBOpen is returned when a statesql or query database cannot be
	// opened, pinged, or initialized.
	ErrDBOpen = errors.New("failed to open the sql database")
	// ErrUndo is returned when the requested recovery point is newer than the
	// database's last recorded commit, so the database cannot be brought to it.
	ErrUndo   = errors.New("failed to undo the sql database")
	// ErrFindRp is returned when a database's current recovery point cannot
	// be determined (recoveryPoint reported 0).
	ErrFindRp = errors.New("cannot find a recover point")

	// database is the process-wide registry of statesql databases; entries
	// are created by the statesql driver's ConnectHook (see init).
	database = &Database{}
	// load makes the initialization in LoadDatabase run only once.
	load     sync.Once

	logger = log.NewLogger("statesql")

	// queryConn is the driver-level connection most recently opened through
	// the query driver (assigned by that driver's ConnectHook). readOnlyConn
	// reads it while holding queryConnLock.
	queryConn     *SQLiteConn
	queryConnLock sync.Mutex
)

// Driver names registered with database/sql in init. statesqlDriver is also
// used as the name of the subdirectory of the data directory that holds the
// database files (see LoadDatabase).
const (
	statesqlDriver = "statesql"
	queryDriver    = "query"
)
    41  
// Database is the registry of statesql databases known to this process.
type Database struct {
	// DBs maps a database name to its handle.
	DBs        map[string]*DB
	// OpenDbName is the name of the database openDB is currently opening;
	// the statesql ConnectHook uses it to name the entry it registers.
	OpenDbName string
	// DataDir is the directory holding the database files
	// (<dataDir>/statesql, set by LoadDatabase or LoadTestDatabase).
	DataDir    string
}
    47  
    48  func init() {
    49  	sql.Register(statesqlDriver, &SQLiteDriver{
    50  		ConnectHook: func(conn *SQLiteConn) error {
    51  			if _, ok := database.DBs[database.OpenDbName]; !ok {
    52  				b, err := enc.ToBytes(database.OpenDbName)
    53  				if err != nil {
    54  					logger.Error().Err(err).Msg("Open SQL Connection")
    55  					return nil
    56  				}
    57  				database.DBs[database.OpenDbName] = &DB{
    58  					Conn:      nil,
    59  					db:        nil,
    60  					tx:        nil,
    61  					conn:      conn,
    62  					name:      database.OpenDbName,
    63  					accountID: types.AccountID(types.ToHashID(b)),
    64  				}
    65  			} else {
    66  				logger.Warn().Err(errors.New("duplicated connection")).Msg("Open SQL Connection")
    67  			}
    68  			return nil
    69  		},
    70  	})
    71  	sql.Register(queryDriver, &SQLiteDriver{
    72  		ConnectHook: func(conn *SQLiteConn) error {
    73  			queryConn = conn
    74  			return nil
    75  		},
    76  	})
    77  }
    78  
    79  func checkPath(path string) error {
    80  	_, err := os.Stat(path)
    81  	if os.IsNotExist(err) {
    82  		err = os.Mkdir(path, 0755)
    83  	}
    84  	return err
    85  }
    86  
    87  func LoadDatabase(dataDir string) error {
    88  	var err error
    89  	load.Do(func() {
    90  		path := filepath.Join(dataDir, statesqlDriver)
    91  		logger.Debug().Str("path", path).Msg("loading statesql")
    92  		if err = checkPath(path); err == nil {
    93  			database.DBs = make(map[string]*DB)
    94  			database.DataDir = path
    95  		}
    96  	})
    97  	return err
    98  }
    99  
   100  func LoadTestDatabase(dataDir string) error {
   101  	var err error
   102  	path := filepath.Join(dataDir, statesqlDriver)
   103  	logger.Debug().Str("path", path).Msg("loading statesql")
   104  	if err = checkPath(path); err == nil {
   105  		database.DBs = make(map[string]*DB)
   106  		database.DataDir = path
   107  	}
   108  	return err
   109  }
   110  
   111  func CloseDatabase() {
   112  	for name, db := range database.DBs {
   113  		if db.tx != nil {
   114  			db.tx.Rollback()
   115  			db.tx = nil
   116  		}
   117  		_ = db.close()
   118  		delete(database.DBs, name)
   119  	}
   120  }
   121  
   122  func SaveRecoveryPoint(bs *state.BlockState) error {
   123  	defer CloseDatabase()
   124  
   125  	for id, db := range database.DBs {
   126  		if db.tx != nil {
   127  			err := db.tx.Commit()
   128  			db.tx = nil
   129  			if err != nil {
   130  				continue
   131  			}
   132  			rp := db.recoveryPoint()
   133  			if rp == 0 {
   134  				return ErrFindRp
   135  			}
   136  			if rp > 0 {
   137  				if logger.IsDebugEnabled() {
   138  					logger.Debug().Str("db_name", id).Uint64("commit_id", rp).Msg("save recovery point")
   139  				}
   140  				receiverState, err := bs.GetAccountState(db.accountID)
   141  				if err != nil {
   142  					return err
   143  				}
   144  				receiverChange := types.State(*receiverState)
   145  				receiverChange.SqlRecoveryPoint = uint64(rp)
   146  				err = bs.PutState(db.accountID, &receiverChange)
   147  				if err != nil {
   148  					return err
   149  				}
   150  			}
   151  		}
   152  	}
   153  	return nil
   154  }
   155  
   156  func BeginTx(dbName string, rp uint64) (Tx, error) {
   157  	db, err := conn(dbName)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	return db.beginTx(rp)
   162  }
   163  
   164  func BeginReadOnly(dbName string, rp uint64) (Tx, error) {
   165  	db, err := readOnlyConn(dbName)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	return newReadOnlyTx(db, rp)
   170  }
   171  
   172  func conn(dbName string) (*DB, error) {
   173  	if db, ok := database.DBs[dbName]; ok {
   174  		return db, nil
   175  	}
   176  	return openDB(dbName)
   177  }
   178  
   179  func dataSrc(dbName string) string {
   180  	return fmt.Sprintf(
   181  		"file:%s/%s.db?branches=on&max_db_size=%d",
   182  		database.DataDir,
   183  		dbName,
   184  		maxSQLDBSize*1024*1024)
   185  }
   186  
   187  func readOnlyConn(dbName string) (*DB, error) {
   188  	queryConnLock.Lock()
   189  	defer queryConnLock.Unlock()
   190  
   191  	db, err := sql.Open(queryDriver, dataSrc(dbName)+"&_query_only=true")
   192  	if err != nil {
   193  		return nil, ErrDBOpen
   194  	}
   195  	err = db.Ping()
   196  	if err != nil {
   197  		logger.Fatal().Err(err)
   198  		_ = db.Close()
   199  		return nil, ErrDBOpen
   200  	}
   201  	c, err := db.Conn(context.Background())
   202  	if err != nil {
   203  		logger.Fatal().Err(err)
   204  		_ = db.Close()
   205  		return nil, ErrDBOpen
   206  	}
   207  	return &DB{
   208  		Conn: c,
   209  		db:   db,
   210  		tx:   nil,
   211  		conn: queryConn,
   212  		name: dbName,
   213  	}, nil
   214  }
   215  
   216  func openDB(dbName string) (*DB, error) {
   217  	database.OpenDbName = dbName
   218  	db, err := sql.Open(statesqlDriver, dataSrc(dbName))
   219  	if err != nil {
   220  		return nil, ErrDBOpen
   221  	}
   222  	c, err := db.Conn(context.Background())
   223  	if err != nil {
   224  		logger.Fatal().Err(err)
   225  		_ = db.Close()
   226  		return nil, ErrDBOpen
   227  	}
   228  	err = c.PingContext(context.Background())
   229  	if err != nil {
   230  		logger.Fatal().Err(err)
   231  		_ = c.Close()
   232  		_ = db.Close()
   233  		return nil, ErrDBOpen
   234  	}
   235  	_, err = c.ExecContext(context.Background(), "create table if not exists _dummy(_dummy)")
   236  	if err != nil {
   237  		logger.Fatal().Err(err)
   238  		_ = c.Close()
   239  		_ = db.Close()
   240  		return nil, ErrDBOpen
   241  	}
   242  	database.DBs[dbName].Conn = c
   243  	database.DBs[dbName].db = db
   244  	return database.DBs[dbName], nil
   245  }
   246  
// DB is a handle to one statesql database (or to a read-only query
// connection, see readOnlyConn).
type DB struct {
	// Conn is the dedicated database/sql connection all statements run on.
	*sql.Conn
	// db is the pool Conn was taken from; close() closes it after Conn.
	db        *sql.DB
	// tx is the writable transaction started by beginTx; nil when none is
	// in progress.
	tx        Tx
	// conn is the driver-level connection captured by a ConnectHook; it
	// provides the raw sqlite3 handle (GetHandle) and DBCacheFlush.
	conn      *SQLiteConn
	// name is the database name used for the file name and savepoint names.
	name      string
	// accountID is derived from name in the statesql ConnectHook;
	// SaveRecoveryPoint stores this database's recovery point in that
	// account's state.
	accountID types.AccountID
}
   255  
   256  func (db *DB) beginTx(rp uint64) (Tx, error) {
   257  	if db.tx == nil {
   258  		err := db.restoreRecoveryPoint(rp)
   259  		if err != nil {
   260  			return nil, err
   261  		}
   262  		if logger.IsDebugEnabled() {
   263  			logger.Debug().Str("db_name", db.name).Msg("begin transaction")
   264  		}
   265  		tx, err := db.BeginTx(context.Background(), nil)
   266  		if err != nil {
   267  			return nil, err
   268  		}
   269  		db.tx = &WritableTx{
   270  			TxCommon: TxCommon{db: db},
   271  			Tx:       tx,
   272  		}
   273  	}
   274  	return db.tx, nil
   275  }
   276  
// branchInfo is the subset of the JSON returned by
// "pragma branch_info(master)" that recoveryPoint decodes.
type branchInfo struct {
	// TotalCommits is the number of commits on the branch.
	TotalCommits uint64 `json:"total_commits"`
}
   280  
   281  func (db *DB) recoveryPoint() uint64 {
   282  	row := db.QueryRowContext(context.Background(), "pragma branch_info(master)")
   283  	var rv string
   284  	err := row.Scan(&rv)
   285  	if err != nil {
   286  		return uint64(0)
   287  	}
   288  	var bi branchInfo
   289  	err = json.Unmarshal([]byte(rv), &bi)
   290  	if err != nil {
   291  		return uint64(0)
   292  	}
   293  	return bi.TotalCommits
   294  }
   295  
   296  func (db *DB) restoreRecoveryPoint(stateRp uint64) error {
   297  	lastRp := db.recoveryPoint()
   298  	if logger.IsDebugEnabled() {
   299  		logger.Debug().Str("db_name", db.name).
   300  			Uint64("state_rp", stateRp).
   301  			Uint64("last_rp", lastRp).Msgf("restore recovery point")
   302  	}
   303  	if lastRp == 0 {
   304  		return ErrFindRp
   305  	}
   306  	if stateRp == lastRp {
   307  		return nil
   308  	}
   309  	if stateRp > lastRp {
   310  		return ErrUndo
   311  	}
   312  	if err := db.rollbackToRecoveryPoint(stateRp); err != nil {
   313  		return err
   314  	}
   315  	if logger.IsDebugEnabled() {
   316  		logger.Debug().Str("db_name", db.name).Uint64("commit_id", stateRp).
   317  			Msg("restore recovery point")
   318  	}
   319  	return nil
   320  }
   321  
   322  func (db *DB) rollbackToRecoveryPoint(rp uint64) error {
   323  	_, err := db.ExecContext(
   324  		context.Background(),
   325  		fmt.Sprintf("pragma branch_truncate(master.%d)", rp),
   326  	)
   327  	return err
   328  }
   329  
   330  func (db *DB) snapshotView(rp uint64) error {
   331  	if logger.IsDebugEnabled() {
   332  		logger.Debug().Uint64("rp", rp).Msgf("snapshot view, %p", db.Conn)
   333  	}
   334  	_, err := db.ExecContext(
   335  		context.Background(),
   336  		fmt.Sprintf("pragma branch=master.%d", rp),
   337  	)
   338  	return err
   339  }
   340  
   341  func (db *DB) close() error {
   342  	err := db.Conn.Close()
   343  	if err != nil {
   344  		_ = db.db.Close()
   345  		return err
   346  	}
   347  	return db.db.Close()
   348  }
   349  
// Tx is the transaction abstraction shared by writable (WritableTx) and
// read-only (ReadOnlyTx) statesql transactions.
type Tx interface {
	// Commit commits the transaction (always fails for ReadOnlyTx).
	Commit() error
	// Rollback ends the transaction without committing; for ReadOnlyTx it
	// closes the read-only connection.
	Rollback() error
	// Savepoint opens a savepoint named after the database.
	Savepoint() error
	// Release releases the database-level savepoint.
	Release() error
	// RollbackToSavepoint rolls back to the database-level savepoint.
	RollbackToSavepoint() error
	// SubSavepoint opens a nested savepoint with the given name.
	SubSavepoint(string) error
	// SubRelease releases the named nested savepoint.
	SubRelease(string) error
	// RollbackToSubSavepoint rolls back to the named nested savepoint.
	RollbackToSubSavepoint(string) error
	// GetHandle returns the raw sqlite3 handle of the driver connection.
	GetHandle() *C.sqlite3
}
   361  
// TxCommon holds what WritableTx and ReadOnlyTx share: the DB the
// transaction runs on.
type TxCommon struct {
	db *DB
}

// GetHandle returns the raw sqlite3 handle of the driver-level connection
// (db.conn) backing this transaction.
func (tx *TxCommon) GetHandle() *C.sqlite3 {
	return tx.db.conn.db
}
   369  
// WritableTx is the read-write transaction created by DB.beginTx. Statements
// and savepoints run inside the embedded database/sql transaction.
type WritableTx struct {
	TxCommon
	*sql.Tx
}
   374  
   375  func (tx *WritableTx) Commit() error {
   376  	if logger.IsDebugEnabled() {
   377  		logger.Debug().Str("db_name", tx.db.name).Msg("commit")
   378  	}
   379  	return tx.Tx.Commit()
   380  }
   381  
   382  func (tx *WritableTx) Rollback() error {
   383  	if logger.IsDebugEnabled() {
   384  		logger.Debug().Str("db_name", tx.db.name).Msg("rollback")
   385  	}
   386  	return tx.Tx.Rollback()
   387  }
   388  
   389  func (tx *WritableTx) Savepoint() error {
   390  	if logger.IsDebugEnabled() {
   391  		logger.Debug().Str("db_name", tx.db.name).Msg("savepoint")
   392  	}
   393  	_, err := tx.Tx.Exec("SAVEPOINT \"" + tx.db.name + "\"")
   394  	return err
   395  }
   396  
   397  func (tx *WritableTx) SubSavepoint(name string) error {
   398  	if logger.IsDebugEnabled() {
   399  		logger.Debug().Str("db_name", name).Msg("savepoint")
   400  	}
   401  	_, err := tx.Tx.Exec("SAVEPOINT \"" + name + "\"")
   402  	return err
   403  }
   404  
   405  func (tx *WritableTx) Release() error {
   406  	if logger.IsDebugEnabled() {
   407  		logger.Debug().Str("db_name", tx.db.name).Msg("release")
   408  	}
   409  	err := tx.db.conn.DBCacheFlush()
   410  	if err != nil {
   411  		return err
   412  	}
   413  	_, err = tx.Tx.Exec("RELEASE SAVEPOINT \"" + tx.db.name + "\"")
   414  	return err
   415  }
   416  
   417  func (tx *WritableTx) SubRelease(name string) error {
   418  	if logger.IsDebugEnabled() {
   419  		logger.Debug().Str("name", name).Msg("release")
   420  	}
   421  	_, err := tx.Tx.Exec("RELEASE SAVEPOINT \"" + name + "\"")
   422  	return err
   423  }
   424  
   425  func (tx *WritableTx) RollbackToSavepoint() error {
   426  	if logger.IsDebugEnabled() {
   427  		logger.Debug().Str("db_name", tx.db.name).Msg("rollback to savepoint")
   428  	}
   429  	_, err := tx.Tx.Exec("ROLLBACK TO SAVEPOINT \"" + tx.db.name + "\"")
   430  	return err
   431  }
   432  
   433  func (tx *WritableTx) RollbackToSubSavepoint(name string) error {
   434  	if logger.IsDebugEnabled() {
   435  		logger.Debug().Str("db_name", name).Msg("rollback to savepoint")
   436  	}
   437  	_, err := tx.Tx.Exec("ROLLBACK TO SAVEPOINT \"" + name + "\"")
   438  	return err
   439  }
   440  
// ReadOnlyTx is a transaction over a query-only connection opened by
// BeginReadOnly. Commit, Savepoint and Release fail, nested savepoint
// operations are no-ops, and Rollback closes the connection.
type ReadOnlyTx struct {
	TxCommon
}
   444  
   445  func newReadOnlyTx(db *DB, rp uint64) (Tx, error) {
   446  	if err := db.snapshotView(rp); err != nil {
   447  		return nil, err
   448  	}
   449  	tx := &ReadOnlyTx{
   450  		TxCommon: TxCommon{db: db},
   451  	}
   452  	return tx, nil
   453  }
   454  
   455  func (tx *ReadOnlyTx) Commit() error {
   456  	return errors.New("only select queries allowed")
   457  }
   458  
   459  func (tx *ReadOnlyTx) Rollback() error {
   460  	if logger.IsDebugEnabled() {
   461  		logger.Debug().Str("db_name", tx.db.name).Msg("read-only tx is closed")
   462  	}
   463  	return tx.db.close()
   464  }
   465  
   466  func (tx *ReadOnlyTx) Savepoint() error {
   467  	return errors.New("only select queries allowed")
   468  }
   469  
   470  func (tx *ReadOnlyTx) Release() error {
   471  	return errors.New("only select queries allowed")
   472  }
   473  
   474  func (tx *ReadOnlyTx) RollbackToSavepoint() error {
   475  	return tx.Rollback()
   476  }
   477  
   478  func (tx *ReadOnlyTx) SubSavepoint(name string) error {
   479  	return nil
   480  }
   481  
   482  func (tx *ReadOnlyTx) SubRelease(name string) error {
   483  	return nil
   484  }
   485  
   486  func (tx *ReadOnlyTx) RollbackToSubSavepoint(name string) error {
   487  	return nil
   488  }