github.com/hdt3213/godis@v1.2.9/database/database.go (about) 1 // Package database is a memory database with redis compatible interface 2 package database 3 4 import ( 5 "strings" 6 "time" 7 8 "github.com/hdt3213/godis/datastruct/dict" 9 "github.com/hdt3213/godis/interface/database" 10 "github.com/hdt3213/godis/interface/redis" 11 "github.com/hdt3213/godis/lib/logger" 12 "github.com/hdt3213/godis/lib/timewheel" 13 "github.com/hdt3213/godis/redis/protocol" 14 ) 15 16 const ( 17 dataDictSize = 1 << 16 18 ttlDictSize = 1 << 10 19 ) 20 21 // DB stores data and execute user's commands 22 type DB struct { 23 index int 24 // key -> DataEntity 25 data *dict.ConcurrentDict 26 // key -> expireTime (time.Time) 27 ttlMap *dict.ConcurrentDict 28 // key -> version(uint32) 29 versionMap *dict.ConcurrentDict 30 31 // addaof is used to add command to aof 32 addAof func(CmdLine) 33 } 34 35 // ExecFunc is interface for command executor 36 // args don't include cmd line 37 type ExecFunc func(db *DB, args [][]byte) redis.Reply 38 39 // PreFunc analyses command line when queued command to `multi` 40 // returns related write keys and read keys 41 type PreFunc func(args [][]byte) ([]string, []string) 42 43 // CmdLine is alias for [][]byte, represents a command line 44 type CmdLine = [][]byte 45 46 // UndoFunc returns undo logs for the given command line 47 // execute from head to tail when undo 48 type UndoFunc func(db *DB, args [][]byte) []CmdLine 49 50 // makeDB create DB instance 51 func makeDB() *DB { 52 db := &DB{ 53 data: dict.MakeConcurrent(dataDictSize), 54 ttlMap: dict.MakeConcurrent(ttlDictSize), 55 versionMap: dict.MakeConcurrent(dataDictSize), 56 addAof: func(line CmdLine) {}, 57 } 58 return db 59 } 60 61 // makeBasicDB create DB instance only with basic abilities. 62 func makeBasicDB() *DB { 63 db := &DB{ 64 data: dict.MakeConcurrent(dataDictSize), 65 ttlMap: dict.MakeConcurrent(ttlDictSize), 66 versionMap: dict.MakeConcurrent(dataDictSize), 67 addAof: func(line CmdLine) {}, 68 } 69 return db 70 } 71 72 // Exec executes command within one database 73 func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply { 74 // transaction control commands and other commands which cannot execute within transaction 75 cmdName := strings.ToLower(string(cmdLine[0])) 76 if cmdName == "multi" { 77 if len(cmdLine) != 1 { 78 return protocol.MakeArgNumErrReply(cmdName) 79 } 80 return StartMulti(c) 81 } else if cmdName == "discard" { 82 if len(cmdLine) != 1 { 83 return protocol.MakeArgNumErrReply(cmdName) 84 } 85 return DiscardMulti(c) 86 } else if cmdName == "exec" { 87 if len(cmdLine) != 1 { 88 return protocol.MakeArgNumErrReply(cmdName) 89 } 90 return execMulti(db, c) 91 } else if cmdName == "watch" { 92 if !validateArity(-2, cmdLine) { 93 return protocol.MakeArgNumErrReply(cmdName) 94 } 95 return Watch(db, c, cmdLine[1:]) 96 } 97 if c != nil && c.InMultiState() { 98 return EnqueueCmd(c, cmdLine) 99 } 100 101 return db.execNormalCommand(cmdLine) 102 } 103 104 func (db *DB) execNormalCommand(cmdLine [][]byte) redis.Reply { 105 cmdName := strings.ToLower(string(cmdLine[0])) 106 cmd, ok := cmdTable[cmdName] 107 if !ok { 108 return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") 109 } 110 if !validateArity(cmd.arity, cmdLine) { 111 return protocol.MakeArgNumErrReply(cmdName) 112 } 113 114 prepare := cmd.prepare 115 write, read := prepare(cmdLine[1:]) 116 db.addVersion(write...) 117 db.RWLocks(write, read) 118 defer db.RWUnLocks(write, read) 119 fun := cmd.executor 120 return fun(db, cmdLine[1:]) 121 } 122 123 // execWithLock executes normal commands, invoker should provide locks 124 func (db *DB) execWithLock(cmdLine [][]byte) redis.Reply { 125 cmdName := strings.ToLower(string(cmdLine[0])) 126 cmd, ok := cmdTable[cmdName] 127 if !ok { 128 return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") 129 } 130 if !validateArity(cmd.arity, cmdLine) { 131 return protocol.MakeArgNumErrReply(cmdName) 132 } 133 fun := cmd.executor 134 return fun(db, cmdLine[1:]) 135 } 136 137 func validateArity(arity int, cmdArgs [][]byte) bool { 138 argNum := len(cmdArgs) 139 if arity >= 0 { 140 return argNum == arity 141 } 142 return argNum >= -arity 143 } 144 145 /* ---- Data Access ----- */ 146 147 // GetEntity returns DataEntity bind to given key 148 func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { 149 raw, ok := db.data.GetWithLock(key) 150 if !ok { 151 return nil, false 152 } 153 if db.IsExpired(key) { 154 return nil, false 155 } 156 entity, _ := raw.(*database.DataEntity) 157 return entity, true 158 } 159 160 // PutEntity a DataEntity into DB 161 func (db *DB) PutEntity(key string, entity *database.DataEntity) int { 162 return db.data.PutWithLock(key, entity) 163 } 164 165 // PutIfExists edit an existing DataEntity 166 func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { 167 return db.data.PutIfExistsWithLock(key, entity) 168 } 169 170 // PutIfAbsent insert an DataEntity only if the key not exists 171 func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { 172 return db.data.PutIfAbsentWithLock(key, entity) 173 } 174 175 // Remove the given key from db 176 func (db *DB) Remove(key string) { 177 db.data.RemoveWithLock(key) 178 db.ttlMap.Remove(key) 179 taskKey := genExpireTask(key) 180 timewheel.Cancel(taskKey) 181 } 182 183 // Removes the given keys from db 184 func (db *DB) Removes(keys ...string) (deleted int) { 185 deleted = 0 186 for _, key := range keys { 187 _, exists := db.data.GetWithLock(key) 188 if exists { 189 db.Remove(key) 190 deleted++ 191 } 192 } 193 return deleted 194 } 195 196 // Flush clean database 197 // deprecated 198 // for test only 199 func (db *DB) Flush() { 200 db.data.Clear() 201 db.ttlMap.Clear() 202 } 203 204 /* ---- Lock Function ----- */ 205 206 // RWLocks lock keys for writing and reading 207 func (db *DB) RWLocks(writeKeys []string, readKeys []string) { 208 db.data.RWLocks(writeKeys, readKeys) 209 } 210 211 // RWUnLocks unlock keys for writing and reading 212 func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) { 213 db.data.RWUnLocks(writeKeys, readKeys) 214 } 215 216 /* ---- TTL Functions ---- */ 217 218 func genExpireTask(key string) string { 219 return "expire:" + key 220 } 221 222 // Expire sets ttlCmd of key 223 func (db *DB) Expire(key string, expireTime time.Time) { 224 db.ttlMap.Put(key, expireTime) 225 taskKey := genExpireTask(key) 226 timewheel.At(expireTime, taskKey, func() { 227 keys := []string{key} 228 db.RWLocks(keys, nil) 229 defer db.RWUnLocks(keys, nil) 230 // check-lock-check, ttl may be updated during waiting lock 231 logger.Info("expire " + key) 232 rawExpireTime, ok := db.ttlMap.Get(key) 233 if !ok { 234 return 235 } 236 expireTime, _ := rawExpireTime.(time.Time) 237 expired := time.Now().After(expireTime) 238 if expired { 239 db.Remove(key) 240 } 241 }) 242 } 243 244 // Persist cancel ttlCmd of key 245 func (db *DB) Persist(key string) { 246 db.ttlMap.Remove(key) 247 taskKey := genExpireTask(key) 248 timewheel.Cancel(taskKey) 249 } 250 251 // IsExpired check whether a key is expired 252 func (db *DB) IsExpired(key string) bool { 253 rawExpireTime, ok := db.ttlMap.Get(key) 254 if !ok { 255 return false 256 } 257 expireTime, _ := rawExpireTime.(time.Time) 258 expired := time.Now().After(expireTime) 259 if expired { 260 db.Remove(key) 261 } 262 return expired 263 } 264 265 /* --- add version --- */ 266 267 func (db *DB) addVersion(keys ...string) { 268 for _, key := range keys { 269 versionCode := db.GetVersion(key) 270 db.versionMap.Put(key, versionCode+1) 271 } 272 } 273 274 // GetVersion returns version code for given key 275 func (db *DB) GetVersion(key string) uint32 { 276 entity, ok := db.versionMap.Get(key) 277 if !ok { 278 return 0 279 } 280 return entity.(uint32) 281 } 282 283 // ForEach traverses all the keys in the database 284 func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { 285 db.data.ForEach(func(key string, raw interface{}) bool { 286 entity, _ := raw.(*database.DataEntity) 287 var expiration *time.Time 288 rawExpireTime, ok := db.ttlMap.Get(key) 289 if ok { 290 expireTime, _ := rawExpireTime.(time.Time) 291 expiration = &expireTime 292 } 293 294 return cb(key, entity, expiration) 295 }) 296 }