github.com/aergoio/aergo@v1.3.1/contract/statesql.go (about) 1 package contract 2 3 /* 4 #include "sqlite3-binding.h" 5 */ 6 import "C" 7 import ( 8 "context" 9 "database/sql" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "github.com/aergoio/aergo/internal/enc" 14 "os" 15 "path/filepath" 16 "sync" 17 18 "github.com/aergoio/aergo-lib/log" 19 "github.com/aergoio/aergo/state" 20 "github.com/aergoio/aergo/types" 21 ) 22 23 var ( 24 ErrDBOpen = errors.New("failed to open the sql database") 25 ErrUndo = errors.New("failed to undo the sql database") 26 ErrFindRp = errors.New("cannot find a recover point") 27 28 database = &Database{} 29 load sync.Once 30 31 logger = log.NewLogger("statesql") 32 33 queryConn *SQLiteConn 34 queryConnLock sync.Mutex 35 ) 36 37 const ( 38 statesqlDriver = "statesql" 39 queryDriver = "query" 40 ) 41 42 type Database struct { 43 DBs map[string]*DB 44 OpenDbName string 45 DataDir string 46 } 47 48 func init() { 49 sql.Register(statesqlDriver, &SQLiteDriver{ 50 ConnectHook: func(conn *SQLiteConn) error { 51 if _, ok := database.DBs[database.OpenDbName]; !ok { 52 b, err := enc.ToBytes(database.OpenDbName) 53 if err != nil { 54 logger.Error().Err(err).Msg("Open SQL Connection") 55 return nil 56 } 57 database.DBs[database.OpenDbName] = &DB{ 58 Conn: nil, 59 db: nil, 60 tx: nil, 61 conn: conn, 62 name: database.OpenDbName, 63 accountID: types.AccountID(types.ToHashID(b)), 64 } 65 } else { 66 logger.Warn().Err(errors.New("duplicated connection")).Msg("Open SQL Connection") 67 } 68 return nil 69 }, 70 }) 71 sql.Register(queryDriver, &SQLiteDriver{ 72 ConnectHook: func(conn *SQLiteConn) error { 73 queryConn = conn 74 return nil 75 }, 76 }) 77 } 78 79 func checkPath(path string) error { 80 _, err := os.Stat(path) 81 if os.IsNotExist(err) { 82 err = os.Mkdir(path, 0755) 83 } 84 return err 85 } 86 87 func LoadDatabase(dataDir string) error { 88 var err error 89 load.Do(func() { 90 path := filepath.Join(dataDir, statesqlDriver) 91 logger.Debug().Str("path", path).Msg("loading statesql") 92 if err = checkPath(path); err == nil { 93 database.DBs = make(map[string]*DB) 94 database.DataDir = path 95 } 96 }) 97 return err 98 } 99 100 func LoadTestDatabase(dataDir string) error { 101 var err error 102 path := filepath.Join(dataDir, statesqlDriver) 103 logger.Debug().Str("path", path).Msg("loading statesql") 104 if err = checkPath(path); err == nil { 105 database.DBs = make(map[string]*DB) 106 database.DataDir = path 107 } 108 return err 109 } 110 111 func CloseDatabase() { 112 for name, db := range database.DBs { 113 if db.tx != nil { 114 db.tx.Rollback() 115 db.tx = nil 116 } 117 _ = db.close() 118 delete(database.DBs, name) 119 } 120 } 121 122 func SaveRecoveryPoint(bs *state.BlockState) error { 123 defer CloseDatabase() 124 125 for id, db := range database.DBs { 126 if db.tx != nil { 127 err := db.tx.Commit() 128 db.tx = nil 129 if err != nil { 130 continue 131 } 132 rp := db.recoveryPoint() 133 if rp == 0 { 134 return ErrFindRp 135 } 136 if rp > 0 { 137 if logger.IsDebugEnabled() { 138 logger.Debug().Str("db_name", id).Uint64("commit_id", rp).Msg("save recovery point") 139 } 140 receiverState, err := bs.GetAccountState(db.accountID) 141 if err != nil { 142 return err 143 } 144 receiverChange := types.State(*receiverState) 145 receiverChange.SqlRecoveryPoint = uint64(rp) 146 err = bs.PutState(db.accountID, &receiverChange) 147 if err != nil { 148 return err 149 } 150 } 151 } 152 } 153 return nil 154 } 155 156 func BeginTx(dbName string, rp uint64) (Tx, error) { 157 db, err := conn(dbName) 158 if err != nil { 159 return nil, err 160 } 161 return db.beginTx(rp) 162 } 163 164 func BeginReadOnly(dbName string, rp uint64) (Tx, error) { 165 db, err := readOnlyConn(dbName) 166 if err != nil { 167 return nil, err 168 } 169 return newReadOnlyTx(db, rp) 170 } 171 172 func conn(dbName string) (*DB, error) { 173 if db, ok := database.DBs[dbName]; ok { 174 return db, nil 175 } 176 return openDB(dbName) 177 } 178 179 func dataSrc(dbName string) string { 180 return fmt.Sprintf( 181 "file:%s/%s.db?branches=on&max_db_size=%d", 182 database.DataDir, 183 dbName, 184 maxSQLDBSize*1024*1024) 185 } 186 187 func readOnlyConn(dbName string) (*DB, error) { 188 queryConnLock.Lock() 189 defer queryConnLock.Unlock() 190 191 db, err := sql.Open(queryDriver, dataSrc(dbName)+"&_query_only=true") 192 if err != nil { 193 return nil, ErrDBOpen 194 } 195 err = db.Ping() 196 if err != nil { 197 logger.Fatal().Err(err) 198 _ = db.Close() 199 return nil, ErrDBOpen 200 } 201 c, err := db.Conn(context.Background()) 202 if err != nil { 203 logger.Fatal().Err(err) 204 _ = db.Close() 205 return nil, ErrDBOpen 206 } 207 return &DB{ 208 Conn: c, 209 db: db, 210 tx: nil, 211 conn: queryConn, 212 name: dbName, 213 }, nil 214 } 215 216 func openDB(dbName string) (*DB, error) { 217 database.OpenDbName = dbName 218 db, err := sql.Open(statesqlDriver, dataSrc(dbName)) 219 if err != nil { 220 return nil, ErrDBOpen 221 } 222 c, err := db.Conn(context.Background()) 223 if err != nil { 224 logger.Fatal().Err(err) 225 _ = db.Close() 226 return nil, ErrDBOpen 227 } 228 err = c.PingContext(context.Background()) 229 if err != nil { 230 logger.Fatal().Err(err) 231 _ = c.Close() 232 _ = db.Close() 233 return nil, ErrDBOpen 234 } 235 _, err = c.ExecContext(context.Background(), "create table if not exists _dummy(_dummy)") 236 if err != nil { 237 logger.Fatal().Err(err) 238 _ = c.Close() 239 _ = db.Close() 240 return nil, ErrDBOpen 241 } 242 database.DBs[dbName].Conn = c 243 database.DBs[dbName].db = db 244 return database.DBs[dbName], nil 245 } 246 247 type DB struct { 248 *sql.Conn 249 db *sql.DB 250 tx Tx 251 conn *SQLiteConn 252 name string 253 accountID types.AccountID 254 } 255 256 func (db *DB) beginTx(rp uint64) (Tx, error) { 257 if db.tx == nil { 258 err := db.restoreRecoveryPoint(rp) 259 if err != nil { 260 return nil, err 261 } 262 if logger.IsDebugEnabled() { 263 logger.Debug().Str("db_name", db.name).Msg("begin transaction") 264 } 265 tx, err := db.BeginTx(context.Background(), nil) 266 if err != nil { 267 return nil, err 268 } 269 db.tx = &WritableTx{ 270 TxCommon: TxCommon{db: db}, 271 Tx: tx, 272 } 273 } 274 return db.tx, nil 275 } 276 277 type branchInfo struct { 278 TotalCommits uint64 `json:"total_commits"` 279 } 280 281 func (db *DB) recoveryPoint() uint64 { 282 row := db.QueryRowContext(context.Background(), "pragma branch_info(master)") 283 var rv string 284 err := row.Scan(&rv) 285 if err != nil { 286 return uint64(0) 287 } 288 var bi branchInfo 289 err = json.Unmarshal([]byte(rv), &bi) 290 if err != nil { 291 return uint64(0) 292 } 293 return bi.TotalCommits 294 } 295 296 func (db *DB) restoreRecoveryPoint(stateRp uint64) error { 297 lastRp := db.recoveryPoint() 298 if logger.IsDebugEnabled() { 299 logger.Debug().Str("db_name", db.name). 300 Uint64("state_rp", stateRp). 301 Uint64("last_rp", lastRp).Msgf("restore recovery point") 302 } 303 if lastRp == 0 { 304 return ErrFindRp 305 } 306 if stateRp == lastRp { 307 return nil 308 } 309 if stateRp > lastRp { 310 return ErrUndo 311 } 312 if err := db.rollbackToRecoveryPoint(stateRp); err != nil { 313 return err 314 } 315 if logger.IsDebugEnabled() { 316 logger.Debug().Str("db_name", db.name).Uint64("commit_id", stateRp). 317 Msg("restore recovery point") 318 } 319 return nil 320 } 321 322 func (db *DB) rollbackToRecoveryPoint(rp uint64) error { 323 _, err := db.ExecContext( 324 context.Background(), 325 fmt.Sprintf("pragma branch_truncate(master.%d)", rp), 326 ) 327 return err 328 } 329 330 func (db *DB) snapshotView(rp uint64) error { 331 if logger.IsDebugEnabled() { 332 logger.Debug().Uint64("rp", rp).Msgf("snapshot view, %p", db.Conn) 333 } 334 _, err := db.ExecContext( 335 context.Background(), 336 fmt.Sprintf("pragma branch=master.%d", rp), 337 ) 338 return err 339 } 340 341 func (db *DB) close() error { 342 err := db.Conn.Close() 343 if err != nil { 344 _ = db.db.Close() 345 return err 346 } 347 return db.db.Close() 348 } 349 350 type Tx interface { 351 Commit() error 352 Rollback() error 353 Savepoint() error 354 Release() error 355 RollbackToSavepoint() error 356 SubSavepoint(string) error 357 SubRelease(string) error 358 RollbackToSubSavepoint(string) error 359 GetHandle() *C.sqlite3 360 } 361 362 type TxCommon struct { 363 db *DB 364 } 365 366 func (tx *TxCommon) GetHandle() *C.sqlite3 { 367 return tx.db.conn.db 368 } 369 370 type WritableTx struct { 371 TxCommon 372 *sql.Tx 373 } 374 375 func (tx *WritableTx) Commit() error { 376 if logger.IsDebugEnabled() { 377 logger.Debug().Str("db_name", tx.db.name).Msg("commit") 378 } 379 return tx.Tx.Commit() 380 } 381 382 func (tx *WritableTx) Rollback() error { 383 if logger.IsDebugEnabled() { 384 logger.Debug().Str("db_name", tx.db.name).Msg("rollback") 385 } 386 return tx.Tx.Rollback() 387 } 388 389 func (tx *WritableTx) Savepoint() error { 390 if logger.IsDebugEnabled() { 391 logger.Debug().Str("db_name", tx.db.name).Msg("savepoint") 392 } 393 _, err := tx.Tx.Exec("SAVEPOINT \"" + tx.db.name + "\"") 394 return err 395 } 396 397 func (tx *WritableTx) SubSavepoint(name string) error { 398 if logger.IsDebugEnabled() { 399 logger.Debug().Str("db_name", name).Msg("savepoint") 400 } 401 _, err := tx.Tx.Exec("SAVEPOINT \"" + name + "\"") 402 return err 403 } 404 405 func (tx *WritableTx) Release() error { 406 if logger.IsDebugEnabled() { 407 logger.Debug().Str("db_name", tx.db.name).Msg("release") 408 } 409 err := tx.db.conn.DBCacheFlush() 410 if err != nil { 411 return err 412 } 413 _, err = tx.Tx.Exec("RELEASE SAVEPOINT \"" + tx.db.name + "\"") 414 return err 415 } 416 417 func (tx *WritableTx) SubRelease(name string) error { 418 if logger.IsDebugEnabled() { 419 logger.Debug().Str("name", name).Msg("release") 420 } 421 _, err := tx.Tx.Exec("RELEASE SAVEPOINT \"" + name + "\"") 422 return err 423 } 424 425 func (tx *WritableTx) RollbackToSavepoint() error { 426 if logger.IsDebugEnabled() { 427 logger.Debug().Str("db_name", tx.db.name).Msg("rollback to savepoint") 428 } 429 _, err := tx.Tx.Exec("ROLLBACK TO SAVEPOINT \"" + tx.db.name + "\"") 430 return err 431 } 432 433 func (tx *WritableTx) RollbackToSubSavepoint(name string) error { 434 if logger.IsDebugEnabled() { 435 logger.Debug().Str("db_name", name).Msg("rollback to savepoint") 436 } 437 _, err := tx.Tx.Exec("ROLLBACK TO SAVEPOINT \"" + name + "\"") 438 return err 439 } 440 441 type ReadOnlyTx struct { 442 TxCommon 443 } 444 445 func newReadOnlyTx(db *DB, rp uint64) (Tx, error) { 446 if err := db.snapshotView(rp); err != nil { 447 return nil, err 448 } 449 tx := &ReadOnlyTx{ 450 TxCommon: TxCommon{db: db}, 451 } 452 return tx, nil 453 } 454 455 func (tx *ReadOnlyTx) Commit() error { 456 return errors.New("only select queries allowed") 457 } 458 459 func (tx *ReadOnlyTx) Rollback() error { 460 if logger.IsDebugEnabled() { 461 logger.Debug().Str("db_name", tx.db.name).Msg("read-only tx is closed") 462 } 463 return tx.db.close() 464 } 465 466 func (tx *ReadOnlyTx) Savepoint() error { 467 return errors.New("only select queries allowed") 468 } 469 470 func (tx *ReadOnlyTx) Release() error { 471 return errors.New("only select queries allowed") 472 } 473 474 func (tx *ReadOnlyTx) RollbackToSavepoint() error { 475 return tx.Rollback() 476 } 477 478 func (tx *ReadOnlyTx) SubSavepoint(name string) error { 479 return nil 480 } 481 482 func (tx *ReadOnlyTx) SubRelease(name string) error { 483 return nil 484 } 485 486 func (tx *ReadOnlyTx) RollbackToSubSavepoint(name string) error { 487 return nil 488 }