github.com/rosedblabs/rosedb/v2@v2.3.7-0.20240423093736-a89ea823e5b9/batch.go (about)

     1  package rosedb
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"github.com/rosedblabs/rosedb/v2/utils"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/bwmarrin/snowflake"
    11  	"github.com/valyala/bytebufferpool"
    12  )
    13  
// Batch groups a set of operations against the database.
// If the batch is read-only, you can only read data from it (for example with Get);
// Put, PutWithTTL, Delete, Expire and Persist return ErrReadOnlyBatch.
//
// If the batch is not read-only, you can use Put and Delete to stage writes.
// The staged data is written to the database when you call Commit.
//
// Batch is not a transaction, it does not guarantee isolation.
// But it can guarantee atomicity, consistency and durability (if the Sync option is true).
//
// Creating a batch acquires the database lock (shared for read-only batches,
// exclusive otherwise). You must call Commit or Rollback to release it,
// otherwise the DB stays locked.
type Batch struct {
	db               *DB                          // database this batch operates on
	pendingWrites    []*LogRecord                 // staged records, in insertion order
	pendingWritesMap map[uint64][]int             // key hash -> indexes into pendingWrites, for fast lookup
	options          BatchOptions                 // ReadOnly selects the lock mode; Sync requests a flush on Commit
	mu               sync.RWMutex                 // taken by the batch methods around access to the staged writes
	committed        bool                         // whether the batch has been committed
	rollbacked       bool                         // whether the batch has been rollbacked
	batchId          *snowflake.Node              // generates the batch id in Commit; NewBatch leaves it nil for read-only batches
	buffers          []*bytebufferpool.ByteBuffer // encoding buffers taken in Commit, returned to the pool later
}
    36  
    37  // NewBatch creates a new Batch instance.
    38  func (db *DB) NewBatch(options BatchOptions) *Batch {
    39  	batch := &Batch{
    40  		db:         db,
    41  		options:    options,
    42  		committed:  false,
    43  		rollbacked: false,
    44  	}
    45  	if !options.ReadOnly {
    46  		node, err := snowflake.NewNode(1)
    47  		if err != nil {
    48  			panic(fmt.Sprintf("snowflake.NewNode(1) failed: %v", err))
    49  		}
    50  		batch.batchId = node
    51  	}
    52  	batch.lock()
    53  	return batch
    54  }
    55  
    56  func newBatch() interface{} {
    57  	node, err := snowflake.NewNode(1)
    58  	if err != nil {
    59  		panic(fmt.Sprintf("snowflake.NewNode(1) failed: %v", err))
    60  	}
    61  	return &Batch{
    62  		options: DefaultBatchOptions,
    63  		batchId: node,
    64  	}
    65  }
    66  
    67  func newRecord() interface{} {
    68  	return &LogRecord{}
    69  }
    70  
    71  func (b *Batch) init(rdonly, sync bool, db *DB) *Batch {
    72  	b.options.ReadOnly = rdonly
    73  	b.options.Sync = sync
    74  	b.db = db
    75  	b.lock()
    76  	return b
    77  }
    78  
    79  func (b *Batch) reset() {
    80  	b.db = nil
    81  	b.pendingWrites = b.pendingWrites[:0]
    82  	b.pendingWritesMap = nil
    83  	b.committed = false
    84  	b.rollbacked = false
    85  	// put all buffers back to the pool
    86  	for _, buf := range b.buffers {
    87  		bytebufferpool.Put(buf)
    88  	}
    89  	b.buffers = b.buffers[:0]
    90  }
    91  
    92  func (b *Batch) lock() {
    93  	if b.options.ReadOnly {
    94  		b.db.mu.RLock()
    95  	} else {
    96  		b.db.mu.Lock()
    97  	}
    98  }
    99  
   100  func (b *Batch) unlock() {
   101  	if b.options.ReadOnly {
   102  		b.db.mu.RUnlock()
   103  	} else {
   104  		b.db.mu.Unlock()
   105  	}
   106  }
   107  
   108  // Put adds a key-value pair to the batch for writing.
   109  func (b *Batch) Put(key []byte, value []byte) error {
   110  	if len(key) == 0 {
   111  		return ErrKeyIsEmpty
   112  	}
   113  	if b.db.closed {
   114  		return ErrDBClosed
   115  	}
   116  	if b.options.ReadOnly {
   117  		return ErrReadOnlyBatch
   118  	}
   119  
   120  	b.mu.Lock()
   121  	// write to pendingWrites
   122  	var record = b.lookupPendingWrites(key)
   123  	if record == nil {
   124  		// if the key does not exist in pendingWrites, write a new record
   125  		// the record will be put back to the pool when the batch is committed or rollbacked
   126  		record = b.db.recordPool.Get().(*LogRecord)
   127  		b.appendPendingWrites(key, record)
   128  	}
   129  
   130  	record.Key, record.Value = key, value
   131  	record.Type, record.Expire = LogRecordNormal, 0
   132  	b.mu.Unlock()
   133  
   134  	return nil
   135  }
   136  
   137  // PutWithTTL adds a key-value pair with ttl to the batch for writing.
   138  func (b *Batch) PutWithTTL(key []byte, value []byte, ttl time.Duration) error {
   139  	if len(key) == 0 {
   140  		return ErrKeyIsEmpty
   141  	}
   142  	if b.db.closed {
   143  		return ErrDBClosed
   144  	}
   145  	if b.options.ReadOnly {
   146  		return ErrReadOnlyBatch
   147  	}
   148  
   149  	b.mu.Lock()
   150  	// write to pendingWrites
   151  	var record = b.lookupPendingWrites(key)
   152  	if record == nil {
   153  		// if the key does not exist in pendingWrites, write a new record
   154  		// the record will be put back to the pool when the batch is committed or rollbacked
   155  		record = b.db.recordPool.Get().(*LogRecord)
   156  		b.appendPendingWrites(key, record)
   157  	}
   158  
   159  	record.Key, record.Value = key, value
   160  	record.Type, record.Expire = LogRecordNormal, time.Now().Add(ttl).UnixNano()
   161  	b.mu.Unlock()
   162  
   163  	return nil
   164  }
   165  
   166  // Get retrieves the value associated with a given key from the batch.
   167  func (b *Batch) Get(key []byte) ([]byte, error) {
   168  	if len(key) == 0 {
   169  		return nil, ErrKeyIsEmpty
   170  	}
   171  	if b.db.closed {
   172  		return nil, ErrDBClosed
   173  	}
   174  
   175  	now := time.Now().UnixNano()
   176  	// get from pendingWrites
   177  	b.mu.RLock()
   178  	var record = b.lookupPendingWrites(key)
   179  	b.mu.RUnlock()
   180  
   181  	// if the record is in pendingWrites, return the value directly
   182  	if record != nil {
   183  		if record.Type == LogRecordDeleted || record.IsExpired(now) {
   184  			return nil, ErrKeyNotFound
   185  		}
   186  		return record.Value, nil
   187  	}
   188  
   189  	// get key/value from data file
   190  	chunkPosition := b.db.index.Get(key)
   191  	if chunkPosition == nil {
   192  		return nil, ErrKeyNotFound
   193  	}
   194  	chunk, err := b.db.dataFiles.Read(chunkPosition)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	// check if the record is deleted or expired
   200  	record = decodeLogRecord(chunk)
   201  	if record.Type == LogRecordDeleted {
   202  		panic("Deleted data cannot exist in the index")
   203  	}
   204  	if record.IsExpired(now) {
   205  		b.db.index.Delete(record.Key)
   206  		return nil, ErrKeyNotFound
   207  	}
   208  	return record.Value, nil
   209  }
   210  
   211  // Delete marks a key for deletion in the batch.
   212  func (b *Batch) Delete(key []byte) error {
   213  	if len(key) == 0 {
   214  		return ErrKeyIsEmpty
   215  	}
   216  	if b.db.closed {
   217  		return ErrDBClosed
   218  	}
   219  	if b.options.ReadOnly {
   220  		return ErrReadOnlyBatch
   221  	}
   222  
   223  	b.mu.Lock()
   224  	// only need key and type when deleting a value.
   225  	var exist bool
   226  	var record = b.lookupPendingWrites(key)
   227  	if record != nil {
   228  		record.Type = LogRecordDeleted
   229  		record.Value = nil
   230  		record.Expire = 0
   231  		exist = true
   232  	}
   233  	if !exist {
   234  		record = &LogRecord{
   235  			Key:  key,
   236  			Type: LogRecordDeleted,
   237  		}
   238  		b.appendPendingWrites(key, record)
   239  	}
   240  	b.mu.Unlock()
   241  
   242  	return nil
   243  }
   244  
   245  // Exist checks if the key exists in the database.
   246  func (b *Batch) Exist(key []byte) (bool, error) {
   247  	if len(key) == 0 {
   248  		return false, ErrKeyIsEmpty
   249  	}
   250  	if b.db.closed {
   251  		return false, ErrDBClosed
   252  	}
   253  
   254  	now := time.Now().UnixNano()
   255  	// check if the key exists in pendingWrites
   256  	b.mu.RLock()
   257  	var record = b.lookupPendingWrites(key)
   258  	b.mu.RUnlock()
   259  
   260  	if record != nil {
   261  		return record.Type != LogRecordDeleted && !record.IsExpired(now), nil
   262  	}
   263  
   264  	// check if the key exists in index
   265  	position := b.db.index.Get(key)
   266  	if position == nil {
   267  		return false, nil
   268  	}
   269  
   270  	// check if the record is deleted or expired
   271  	chunk, err := b.db.dataFiles.Read(position)
   272  	if err != nil {
   273  		return false, err
   274  	}
   275  
   276  	record = decodeLogRecord(chunk)
   277  	if record.Type == LogRecordDeleted || record.IsExpired(now) {
   278  		b.db.index.Delete(record.Key)
   279  		return false, nil
   280  	}
   281  	return true, nil
   282  }
   283  
   284  // Expire sets the ttl of the key.
   285  func (b *Batch) Expire(key []byte, ttl time.Duration) error {
   286  	if len(key) == 0 {
   287  		return ErrKeyIsEmpty
   288  	}
   289  	if b.db.closed {
   290  		return ErrDBClosed
   291  	}
   292  	if b.options.ReadOnly {
   293  		return ErrReadOnlyBatch
   294  	}
   295  
   296  	b.mu.Lock()
   297  	defer b.mu.Unlock()
   298  
   299  	var record = b.lookupPendingWrites(key)
   300  
   301  	// if the key exists in pendingWrites, update the expiry time directly
   302  	if record != nil {
   303  		// return key not found if the record is deleted or expired
   304  		if record.Type == LogRecordDeleted || record.IsExpired(time.Now().UnixNano()) {
   305  			return ErrKeyNotFound
   306  		}
   307  		record.Expire = time.Now().Add(ttl).UnixNano()
   308  		return nil
   309  	}
   310  	// if the key does not exist in pendingWrites, get the value from wal
   311  	position := b.db.index.Get(key)
   312  	if position == nil {
   313  		return ErrKeyNotFound
   314  	}
   315  	chunk, err := b.db.dataFiles.Read(position)
   316  	if err != nil {
   317  		return err
   318  	}
   319  
   320  	now := time.Now()
   321  	record = decodeLogRecord(chunk)
   322  	// if the record is deleted or expired, we can assume that the key does not exist,
   323  	// and delete the key from the index
   324  	if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) {
   325  		b.db.index.Delete(key)
   326  		return ErrKeyNotFound
   327  	}
   328  	// now we get the value from wal, update the expiry time
   329  	// and rewrite the record to pendingWrites
   330  	record.Expire = now.Add(ttl).UnixNano()
   331  	b.appendPendingWrites(key, record)
   332  
   333  	return nil
   334  }
   335  
   336  // TTL returns the ttl of the key.
   337  func (b *Batch) TTL(key []byte) (time.Duration, error) {
   338  	if len(key) == 0 {
   339  		return -1, ErrKeyIsEmpty
   340  	}
   341  	if b.db.closed {
   342  		return -1, ErrDBClosed
   343  	}
   344  
   345  	now := time.Now()
   346  	b.mu.Lock()
   347  	defer b.mu.Unlock()
   348  
   349  	var record = b.lookupPendingWrites(key)
   350  	if record != nil {
   351  		if record.Expire == 0 {
   352  			return -1, nil
   353  		}
   354  		// return key not found if the record is deleted or expired
   355  		if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) {
   356  			return -1, ErrKeyNotFound
   357  		}
   358  		// now we get the valid expiry time, we can calculate the ttl
   359  		return time.Duration(record.Expire - now.UnixNano()), nil
   360  	}
   361  
   362  	// if the key does not exist in pendingWrites, get the value from wal
   363  	position := b.db.index.Get(key)
   364  	if position == nil {
   365  		return -1, ErrKeyNotFound
   366  	}
   367  	chunk, err := b.db.dataFiles.Read(position)
   368  	if err != nil {
   369  		return -1, err
   370  	}
   371  
   372  	// return key not found if the record is deleted or expired
   373  	record = decodeLogRecord(chunk)
   374  	if record.Type == LogRecordDeleted {
   375  		return -1, ErrKeyNotFound
   376  	}
   377  	if record.IsExpired(now.UnixNano()) {
   378  		b.db.index.Delete(key)
   379  		return -1, ErrKeyNotFound
   380  	}
   381  
   382  	// now we get the valid expiry time, we can calculate the ttl
   383  	if record.Expire > 0 {
   384  		return time.Duration(record.Expire - now.UnixNano()), nil
   385  	}
   386  
   387  	return -1, nil
   388  }
   389  
   390  // Persist removes the ttl of the key.
   391  func (b *Batch) Persist(key []byte) error {
   392  	if len(key) == 0 {
   393  		return ErrKeyIsEmpty
   394  	}
   395  	if b.db.closed {
   396  		return ErrDBClosed
   397  	}
   398  	if b.options.ReadOnly {
   399  		return ErrReadOnlyBatch
   400  	}
   401  
   402  	b.mu.Lock()
   403  	defer b.mu.Unlock()
   404  
   405  	// if the key exists in pendingWrites, update the expiry time directly
   406  	var record = b.lookupPendingWrites(key)
   407  	if record != nil {
   408  		if record.Type == LogRecordDeleted && record.IsExpired(time.Now().UnixNano()) {
   409  			return ErrKeyNotFound
   410  		}
   411  		record.Expire = 0
   412  		return nil
   413  	}
   414  
   415  	// check if the key exists in index
   416  	position := b.db.index.Get(key)
   417  	if position == nil {
   418  		return ErrKeyNotFound
   419  	}
   420  	chunk, err := b.db.dataFiles.Read(position)
   421  	if err != nil {
   422  		return err
   423  	}
   424  
   425  	record = decodeLogRecord(chunk)
   426  	now := time.Now().UnixNano()
   427  	// check if the record is deleted or expired
   428  	if record.Type == LogRecordDeleted || record.IsExpired(now) {
   429  		b.db.index.Delete(record.Key)
   430  		return ErrKeyNotFound
   431  	}
   432  	// if the expiration time is 0, it means that the key has no expiration time,
   433  	// so we can return directly
   434  	if record.Expire == 0 {
   435  		return nil
   436  	}
   437  
   438  	// set the expiration time to 0, and rewrite the record to wal
   439  	record.Expire = 0
   440  	b.appendPendingWrites(key, record)
   441  
   442  	return nil
   443  }
   444  
   445  // Commit commits the batch, if the batch is readonly or empty, it will return directly.
   446  //
   447  // It will iterate the pendingWrites and write the data to the database,
   448  // then write a record to indicate the end of the batch to guarantee atomicity.
   449  // Finally, it will write the index.
   450  func (b *Batch) Commit() error {
   451  	defer b.unlock()
   452  	if b.db.closed {
   453  		return ErrDBClosed
   454  	}
   455  
   456  	if b.options.ReadOnly || len(b.pendingWrites) == 0 {
   457  		return nil
   458  	}
   459  
   460  	b.mu.Lock()
   461  	defer b.mu.Unlock()
   462  
   463  	// check if committed or rollbacked
   464  	if b.committed {
   465  		return ErrBatchCommitted
   466  	}
   467  	if b.rollbacked {
   468  		return ErrBatchRollbacked
   469  	}
   470  
   471  	batchId := b.batchId.Generate()
   472  	now := time.Now().UnixNano()
   473  	// write to wal buffer
   474  	for _, record := range b.pendingWrites {
   475  		buf := bytebufferpool.Get()
   476  		b.buffers = append(b.buffers, buf)
   477  		record.BatchId = uint64(batchId)
   478  		encRecord := encodeLogRecord(record, b.db.encodeHeader, buf)
   479  		b.db.dataFiles.PendingWrites(encRecord)
   480  	}
   481  
   482  	// write a record to indicate the end of the batch
   483  	buf := bytebufferpool.Get()
   484  	b.buffers = append(b.buffers, buf)
   485  	endRecord := encodeLogRecord(&LogRecord{
   486  		Key:  batchId.Bytes(),
   487  		Type: LogRecordBatchFinished,
   488  	}, b.db.encodeHeader, buf)
   489  	b.db.dataFiles.PendingWrites(endRecord)
   490  
   491  	// write to wal file
   492  	chunkPositions, err := b.db.dataFiles.WriteAll()
   493  	if err != nil {
   494  		b.db.dataFiles.ClearPendingWrites()
   495  		return err
   496  	}
   497  	if len(chunkPositions) != len(b.pendingWrites)+1 {
   498  		panic("chunk positions length is not equal to pending writes length")
   499  	}
   500  
   501  	// flush wal if necessary
   502  	if b.options.Sync && !b.db.options.Sync {
   503  		if err := b.db.dataFiles.Sync(); err != nil {
   504  			return err
   505  		}
   506  	}
   507  
   508  	// write to index
   509  	for i, record := range b.pendingWrites {
   510  		if record.Type == LogRecordDeleted || record.IsExpired(now) {
   511  			b.db.index.Delete(record.Key)
   512  		} else {
   513  			b.db.index.Put(record.Key, chunkPositions[i])
   514  		}
   515  
   516  		if b.db.options.WatchQueueSize > 0 {
   517  			e := &Event{Key: record.Key, Value: record.Value, BatchId: record.BatchId}
   518  			if record.Type == LogRecordDeleted {
   519  				e.Action = WatchActionDelete
   520  			} else {
   521  				e.Action = WatchActionPut
   522  			}
   523  			b.db.watcher.putEvent(e)
   524  		}
   525  		// put the record back to the pool
   526  		b.db.recordPool.Put(record)
   527  	}
   528  
   529  	b.committed = true
   530  	return nil
   531  }
   532  
   533  // Rollback discards an uncommitted batch instance.
   534  // the discard operation will clear the buffered data and release the lock.
   535  func (b *Batch) Rollback() error {
   536  	defer b.unlock()
   537  
   538  	if b.db.closed {
   539  		return ErrDBClosed
   540  	}
   541  
   542  	if b.committed {
   543  		return ErrBatchCommitted
   544  	}
   545  	if b.rollbacked {
   546  		return ErrBatchRollbacked
   547  	}
   548  
   549  	for _, buf := range b.buffers {
   550  		bytebufferpool.Put(buf)
   551  	}
   552  
   553  	if !b.options.ReadOnly {
   554  		// clear pendingWrites
   555  		for _, record := range b.pendingWrites {
   556  			b.db.recordPool.Put(record)
   557  		}
   558  		b.pendingWrites = b.pendingWrites[:0]
   559  		for key := range b.pendingWritesMap {
   560  			delete(b.pendingWritesMap, key)
   561  		}
   562  	}
   563  
   564  	b.rollbacked = true
   565  	return nil
   566  }
   567  
   568  // lookupPendingWrites if the key exists in pendingWrites, update the value directly
   569  func (b *Batch) lookupPendingWrites(key []byte) *LogRecord {
   570  	if len(b.pendingWritesMap) == 0 {
   571  		return nil
   572  	}
   573  
   574  	hashKey := utils.MemHash(key)
   575  	for _, entry := range b.pendingWritesMap[hashKey] {
   576  		if bytes.Compare(b.pendingWrites[entry].Key, key) == 0 {
   577  			return b.pendingWrites[entry]
   578  		}
   579  	}
   580  	return nil
   581  }
   582  
   583  // add new record to pendingWrites and pendingWritesMap.
   584  func (b *Batch) appendPendingWrites(key []byte, record *LogRecord) {
   585  	b.pendingWrites = append(b.pendingWrites, record)
   586  	if b.pendingWritesMap == nil {
   587  		b.pendingWritesMap = make(map[uint64][]int)
   588  	}
   589  	hashKey := utils.MemHash(key)
   590  	b.pendingWritesMap[hashKey] = append(b.pendingWritesMap[hashKey], len(b.pendingWrites)-1)
   591  }