github.com/hdt3213/godis@v1.2.9/database/database.go (about)

     1  // Package database is a memory database with redis compatible interface
     2  package database
     3  
     4  import (
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/hdt3213/godis/datastruct/dict"
     9  	"github.com/hdt3213/godis/interface/database"
    10  	"github.com/hdt3213/godis/interface/redis"
    11  	"github.com/hdt3213/godis/lib/logger"
    12  	"github.com/hdt3213/godis/lib/timewheel"
    13  	"github.com/hdt3213/godis/redis/protocol"
    14  )
    15  
// Size hints passed to dict.MakeConcurrent when a DB is created
// (presumably shard counts — TODO confirm against datastruct/dict).
const (
	dataDictSize = 1 << 16 // used for the data and version dicts
	ttlDictSize  = 1 << 10 // used for the TTL dict
)
    20  
// DB stores data and executes user commands within one database.
type DB struct {
	// index is not read in this file; presumably the database number
	// selected by SELECT — confirm with callers.
	index int
	// key -> DataEntity
	data *dict.ConcurrentDict
	// key -> expireTime (time.Time); written by Expire, cleared by Persist/Remove
	ttlMap *dict.ConcurrentDict
	// key -> version(uint32); bumped by addVersion, read by GetVersion
	versionMap *dict.ConcurrentDict

	// addAof is used to add a command line to the AOF.
	// makeDB and makeBasicDB install a no-op here.
	addAof func(CmdLine)
}
    34  
// ExecFunc is the signature of a command executor.
// args holds only the command's arguments; the command name itself
// (cmdLine[0]) is not included.
type ExecFunc func(db *DB, args [][]byte) redis.Reply

// PreFunc analyses a command line before it runs (or when it is queued
// inside `multi`) and returns the keys it writes and the keys it reads,
// in that order. execNormalCommand locks exactly these keys.
type PreFunc func(args [][]byte) ([]string, []string)

// CmdLine is alias for [][]byte, represents a command line
type CmdLine = [][]byte

// UndoFunc returns undo logs for the given command line.
// The returned command lines are executed from head to tail when undoing.
type UndoFunc func(db *DB, args [][]byte) []CmdLine
    49  
    50  // makeDB create DB instance
    51  func makeDB() *DB {
    52  	db := &DB{
    53  		data:       dict.MakeConcurrent(dataDictSize),
    54  		ttlMap:     dict.MakeConcurrent(ttlDictSize),
    55  		versionMap: dict.MakeConcurrent(dataDictSize),
    56  		addAof:     func(line CmdLine) {},
    57  	}
    58  	return db
    59  }
    60  
    61  // makeBasicDB create DB instance only with basic abilities.
    62  func makeBasicDB() *DB {
    63  	db := &DB{
    64  		data:       dict.MakeConcurrent(dataDictSize),
    65  		ttlMap:     dict.MakeConcurrent(ttlDictSize),
    66  		versionMap: dict.MakeConcurrent(dataDictSize),
    67  		addAof:     func(line CmdLine) {},
    68  	}
    69  	return db
    70  }
    71  
    72  // Exec executes command within one database
    73  func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
    74  	// transaction control commands and other commands which cannot execute within transaction
    75  	cmdName := strings.ToLower(string(cmdLine[0]))
    76  	if cmdName == "multi" {
    77  		if len(cmdLine) != 1 {
    78  			return protocol.MakeArgNumErrReply(cmdName)
    79  		}
    80  		return StartMulti(c)
    81  	} else if cmdName == "discard" {
    82  		if len(cmdLine) != 1 {
    83  			return protocol.MakeArgNumErrReply(cmdName)
    84  		}
    85  		return DiscardMulti(c)
    86  	} else if cmdName == "exec" {
    87  		if len(cmdLine) != 1 {
    88  			return protocol.MakeArgNumErrReply(cmdName)
    89  		}
    90  		return execMulti(db, c)
    91  	} else if cmdName == "watch" {
    92  		if !validateArity(-2, cmdLine) {
    93  			return protocol.MakeArgNumErrReply(cmdName)
    94  		}
    95  		return Watch(db, c, cmdLine[1:])
    96  	}
    97  	if c != nil && c.InMultiState() {
    98  		return EnqueueCmd(c, cmdLine)
    99  	}
   100  
   101  	return db.execNormalCommand(cmdLine)
   102  }
   103  
   104  func (db *DB) execNormalCommand(cmdLine [][]byte) redis.Reply {
   105  	cmdName := strings.ToLower(string(cmdLine[0]))
   106  	cmd, ok := cmdTable[cmdName]
   107  	if !ok {
   108  		return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
   109  	}
   110  	if !validateArity(cmd.arity, cmdLine) {
   111  		return protocol.MakeArgNumErrReply(cmdName)
   112  	}
   113  
   114  	prepare := cmd.prepare
   115  	write, read := prepare(cmdLine[1:])
   116  	db.addVersion(write...)
   117  	db.RWLocks(write, read)
   118  	defer db.RWUnLocks(write, read)
   119  	fun := cmd.executor
   120  	return fun(db, cmdLine[1:])
   121  }
   122  
   123  // execWithLock executes normal commands, invoker should provide locks
   124  func (db *DB) execWithLock(cmdLine [][]byte) redis.Reply {
   125  	cmdName := strings.ToLower(string(cmdLine[0]))
   126  	cmd, ok := cmdTable[cmdName]
   127  	if !ok {
   128  		return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
   129  	}
   130  	if !validateArity(cmd.arity, cmdLine) {
   131  		return protocol.MakeArgNumErrReply(cmdName)
   132  	}
   133  	fun := cmd.executor
   134  	return fun(db, cmdLine[1:])
   135  }
   136  
   137  func validateArity(arity int, cmdArgs [][]byte) bool {
   138  	argNum := len(cmdArgs)
   139  	if arity >= 0 {
   140  		return argNum == arity
   141  	}
   142  	return argNum >= -arity
   143  }
   144  
   145  /* ---- Data Access ----- */
   146  
   147  // GetEntity returns DataEntity bind to given key
   148  func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
   149  	raw, ok := db.data.GetWithLock(key)
   150  	if !ok {
   151  		return nil, false
   152  	}
   153  	if db.IsExpired(key) {
   154  		return nil, false
   155  	}
   156  	entity, _ := raw.(*database.DataEntity)
   157  	return entity, true
   158  }
   159  
   160  // PutEntity a DataEntity into DB
   161  func (db *DB) PutEntity(key string, entity *database.DataEntity) int {
   162  	return db.data.PutWithLock(key, entity)
   163  }
   164  
   165  // PutIfExists edit an existing DataEntity
   166  func (db *DB) PutIfExists(key string, entity *database.DataEntity) int {
   167  	return db.data.PutIfExistsWithLock(key, entity)
   168  }
   169  
   170  // PutIfAbsent insert an DataEntity only if the key not exists
   171  func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int {
   172  	return db.data.PutIfAbsentWithLock(key, entity)
   173  }
   174  
   175  // Remove the given key from db
   176  func (db *DB) Remove(key string) {
   177  	db.data.RemoveWithLock(key)
   178  	db.ttlMap.Remove(key)
   179  	taskKey := genExpireTask(key)
   180  	timewheel.Cancel(taskKey)
   181  }
   182  
   183  // Removes the given keys from db
   184  func (db *DB) Removes(keys ...string) (deleted int) {
   185  	deleted = 0
   186  	for _, key := range keys {
   187  		_, exists := db.data.GetWithLock(key)
   188  		if exists {
   189  			db.Remove(key)
   190  			deleted++
   191  		}
   192  	}
   193  	return deleted
   194  }
   195  
   196  // Flush clean database
   197  // deprecated
   198  // for test only
   199  func (db *DB) Flush() {
   200  	db.data.Clear()
   201  	db.ttlMap.Clear()
   202  }
   203  
   204  /* ---- Lock Function ----- */
   205  
   206  // RWLocks lock keys for writing and reading
   207  func (db *DB) RWLocks(writeKeys []string, readKeys []string) {
   208  	db.data.RWLocks(writeKeys, readKeys)
   209  }
   210  
   211  // RWUnLocks unlock keys for writing and reading
   212  func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) {
   213  	db.data.RWUnLocks(writeKeys, readKeys)
   214  }
   215  
   216  /* ---- TTL Functions ---- */
   217  
   218  func genExpireTask(key string) string {
   219  	return "expire:" + key
   220  }
   221  
   222  // Expire sets ttlCmd of key
   223  func (db *DB) Expire(key string, expireTime time.Time) {
   224  	db.ttlMap.Put(key, expireTime)
   225  	taskKey := genExpireTask(key)
   226  	timewheel.At(expireTime, taskKey, func() {
   227  		keys := []string{key}
   228  		db.RWLocks(keys, nil)
   229  		defer db.RWUnLocks(keys, nil)
   230  		// check-lock-check, ttl may be updated during waiting lock
   231  		logger.Info("expire " + key)
   232  		rawExpireTime, ok := db.ttlMap.Get(key)
   233  		if !ok {
   234  			return
   235  		}
   236  		expireTime, _ := rawExpireTime.(time.Time)
   237  		expired := time.Now().After(expireTime)
   238  		if expired {
   239  			db.Remove(key)
   240  		}
   241  	})
   242  }
   243  
   244  // Persist cancel ttlCmd of key
   245  func (db *DB) Persist(key string) {
   246  	db.ttlMap.Remove(key)
   247  	taskKey := genExpireTask(key)
   248  	timewheel.Cancel(taskKey)
   249  }
   250  
   251  // IsExpired check whether a key is expired
   252  func (db *DB) IsExpired(key string) bool {
   253  	rawExpireTime, ok := db.ttlMap.Get(key)
   254  	if !ok {
   255  		return false
   256  	}
   257  	expireTime, _ := rawExpireTime.(time.Time)
   258  	expired := time.Now().After(expireTime)
   259  	if expired {
   260  		db.Remove(key)
   261  	}
   262  	return expired
   263  }
   264  
   265  /* --- add version --- */
   266  
   267  func (db *DB) addVersion(keys ...string) {
   268  	for _, key := range keys {
   269  		versionCode := db.GetVersion(key)
   270  		db.versionMap.Put(key, versionCode+1)
   271  	}
   272  }
   273  
   274  // GetVersion returns version code for given key
   275  func (db *DB) GetVersion(key string) uint32 {
   276  	entity, ok := db.versionMap.Get(key)
   277  	if !ok {
   278  		return 0
   279  	}
   280  	return entity.(uint32)
   281  }
   282  
   283  // ForEach traverses all the keys in the database
   284  func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
   285  	db.data.ForEach(func(key string, raw interface{}) bool {
   286  		entity, _ := raw.(*database.DataEntity)
   287  		var expiration *time.Time
   288  		rawExpireTime, ok := db.ttlMap.Get(key)
   289  		if ok {
   290  			expireTime, _ := rawExpireTime.(time.Time)
   291  			expiration = &expireTime
   292  		}
   293  
   294  		return cb(key, entity, expiration)
   295  	})
   296  }