github.com/hdt3213/godis@v1.2.9/database/set.go (about)

     1  package database
     2  
     3  import (
     4  	HashSet "github.com/hdt3213/godis/datastruct/set"
     5  	"github.com/hdt3213/godis/interface/database"
     6  	"github.com/hdt3213/godis/interface/redis"
     7  	"github.com/hdt3213/godis/lib/utils"
     8  	"github.com/hdt3213/godis/redis/protocol"
     9  	"strconv"
    10  )
    11  
    12  func (db *DB) getAsSet(key string) (*HashSet.Set, protocol.ErrorReply) {
    13  	entity, exists := db.GetEntity(key)
    14  	if !exists {
    15  		return nil, nil
    16  	}
    17  	set, ok := entity.Data.(*HashSet.Set)
    18  	if !ok {
    19  		return nil, &protocol.WrongTypeErrReply{}
    20  	}
    21  	return set, nil
    22  }
    23  
    24  func (db *DB) getOrInitSet(key string) (set *HashSet.Set, inited bool, errReply protocol.ErrorReply) {
    25  	set, errReply = db.getAsSet(key)
    26  	if errReply != nil {
    27  		return nil, false, errReply
    28  	}
    29  	inited = false
    30  	if set == nil {
    31  		set = HashSet.Make()
    32  		db.PutEntity(key, &database.DataEntity{
    33  			Data: set,
    34  		})
    35  		inited = true
    36  	}
    37  	return set, inited, nil
    38  }
    39  
    40  // execSAdd adds members into set
    41  func execSAdd(db *DB, args [][]byte) redis.Reply {
    42  	key := string(args[0])
    43  	members := args[1:]
    44  
    45  	// get or init entity
    46  	set, _, errReply := db.getOrInitSet(key)
    47  	if errReply != nil {
    48  		return errReply
    49  	}
    50  	counter := 0
    51  	for _, member := range members {
    52  		counter += set.Add(string(member))
    53  	}
    54  	db.addAof(utils.ToCmdLine3("sadd", args...))
    55  	return protocol.MakeIntReply(int64(counter))
    56  }
    57  
    58  // execSIsMember checks if the given value is member of set
    59  func execSIsMember(db *DB, args [][]byte) redis.Reply {
    60  	key := string(args[0])
    61  	member := string(args[1])
    62  
    63  	// get set
    64  	set, errReply := db.getAsSet(key)
    65  	if errReply != nil {
    66  		return errReply
    67  	}
    68  	if set == nil {
    69  		return protocol.MakeIntReply(0)
    70  	}
    71  
    72  	has := set.Has(member)
    73  	if has {
    74  		return protocol.MakeIntReply(1)
    75  	}
    76  	return protocol.MakeIntReply(0)
    77  }
    78  
    79  // execSRem removes a member from set
    80  func execSRem(db *DB, args [][]byte) redis.Reply {
    81  	key := string(args[0])
    82  	members := args[1:]
    83  
    84  	set, errReply := db.getAsSet(key)
    85  	if errReply != nil {
    86  		return errReply
    87  	}
    88  	if set == nil {
    89  		return protocol.MakeIntReply(0)
    90  	}
    91  	counter := 0
    92  	for _, member := range members {
    93  		counter += set.Remove(string(member))
    94  	}
    95  	if set.Len() == 0 {
    96  		db.Remove(key)
    97  	}
    98  	if counter > 0 {
    99  		db.addAof(utils.ToCmdLine3("srem", args...))
   100  	}
   101  	return protocol.MakeIntReply(int64(counter))
   102  }
   103  
   104  // execSPop removes one or more random members from set
   105  func execSPop(db *DB, args [][]byte) redis.Reply {
   106  	if len(args) != 1 && len(args) != 2 {
   107  		return protocol.MakeErrReply("ERR wrong number of arguments for 'spop' command")
   108  	}
   109  	key := string(args[0])
   110  
   111  	set, errReply := db.getAsSet(key)
   112  	if errReply != nil {
   113  		return errReply
   114  	}
   115  	if set == nil {
   116  		return &protocol.NullBulkReply{}
   117  	}
   118  
   119  	count := 1
   120  	if len(args) == 2 {
   121  		count64, err := strconv.ParseInt(string(args[1]), 10, 64)
   122  		if err != nil || count64 <= 0 {
   123  			return protocol.MakeErrReply("ERR value is out of range, must be positive")
   124  		}
   125  		count = int(count64)
   126  	}
   127  	if count > set.Len() {
   128  		count = set.Len()
   129  	}
   130  
   131  	members := set.RandomDistinctMembers(count)
   132  	result := make([][]byte, len(members))
   133  	for i, v := range members {
   134  		set.Remove(v)
   135  		result[i] = []byte(v)
   136  	}
   137  
   138  	if count > 0 {
   139  		db.addAof(utils.ToCmdLine3("spop", args...))
   140  	}
   141  	return protocol.MakeMultiBulkReply(result)
   142  }
   143  
   144  // execSCard gets the number of members in a set
   145  func execSCard(db *DB, args [][]byte) redis.Reply {
   146  	key := string(args[0])
   147  
   148  	// get or init entity
   149  	set, errReply := db.getAsSet(key)
   150  	if errReply != nil {
   151  		return errReply
   152  	}
   153  	if set == nil {
   154  		return protocol.MakeIntReply(0)
   155  	}
   156  	return protocol.MakeIntReply(int64(set.Len()))
   157  }
   158  
   159  // execSMembers gets all members in a set
   160  func execSMembers(db *DB, args [][]byte) redis.Reply {
   161  	key := string(args[0])
   162  
   163  	// get or init entity
   164  	set, errReply := db.getAsSet(key)
   165  	if errReply != nil {
   166  		return errReply
   167  	}
   168  	if set == nil {
   169  		return &protocol.EmptyMultiBulkReply{}
   170  	}
   171  
   172  	arr := make([][]byte, set.Len())
   173  	i := 0
   174  	set.ForEach(func(member string) bool {
   175  		arr[i] = []byte(member)
   176  		i++
   177  		return true
   178  	})
   179  	return protocol.MakeMultiBulkReply(arr)
   180  }
   181  
   182  func set2reply(set *HashSet.Set) redis.Reply {
   183  	arr := make([][]byte, set.Len())
   184  	i := 0
   185  	set.ForEach(func(member string) bool {
   186  		arr[i] = []byte(member)
   187  		i++
   188  		return true
   189  	})
   190  	return protocol.MakeMultiBulkReply(arr)
   191  }
   192  
   193  // execSInter intersect multiple sets
   194  func execSInter(db *DB, args [][]byte) redis.Reply {
   195  	sets := make([]*HashSet.Set, 0, len(args))
   196  	for _, arg := range args {
   197  		key := string(arg)
   198  		set, errReply := db.getAsSet(key)
   199  		if errReply != nil {
   200  			return errReply
   201  		}
   202  		if set.Len() == 0 {
   203  			return &protocol.EmptyMultiBulkReply{}
   204  		}
   205  		sets = append(sets, set)
   206  	}
   207  	result := HashSet.Intersect(sets...)
   208  	return set2reply(result)
   209  }
   210  
   211  // execSInterStore intersects multiple sets and store the result in a key
   212  func execSInterStore(db *DB, args [][]byte) redis.Reply {
   213  	dest := string(args[0])
   214  	sets := make([]*HashSet.Set, 0, len(args)-1)
   215  	for i := 1; i < len(args); i++ {
   216  		key := string(args[i])
   217  		set, errReply := db.getAsSet(key)
   218  		if errReply != nil {
   219  			return errReply
   220  		}
   221  		if set.Len() == 0 {
   222  			return protocol.MakeIntReply(0)
   223  		}
   224  		sets = append(sets, set)
   225  	}
   226  	result := HashSet.Intersect(sets...)
   227  
   228  	db.PutEntity(dest, &database.DataEntity{
   229  		Data: result,
   230  	})
   231  	db.addAof(utils.ToCmdLine3("sinterstore", args...))
   232  	return protocol.MakeIntReply(int64(result.Len()))
   233  }
   234  
   235  // execSUnion adds multiple sets
   236  func execSUnion(db *DB, args [][]byte) redis.Reply {
   237  	sets := make([]*HashSet.Set, 0, len(args))
   238  	for _, arg := range args {
   239  		key := string(arg)
   240  		set, errReply := db.getAsSet(key)
   241  		if errReply != nil {
   242  			return errReply
   243  		}
   244  		sets = append(sets, set)
   245  	}
   246  	result := HashSet.Union(sets...)
   247  	return set2reply(result)
   248  }
   249  
   250  // execSUnionStore adds multiple sets and store the result in a key
   251  func execSUnionStore(db *DB, args [][]byte) redis.Reply {
   252  	dest := string(args[0])
   253  	sets := make([]*HashSet.Set, 0, len(args)-1)
   254  	for i := 1; i < len(args); i++ {
   255  		key := string(args[i])
   256  		set, errReply := db.getAsSet(key)
   257  		if errReply != nil {
   258  			return errReply
   259  		}
   260  		sets = append(sets, set)
   261  	}
   262  	result := HashSet.Union(sets...)
   263  	db.Remove(dest) // clean ttl
   264  	if result.Len() == 0 {
   265  		return protocol.MakeIntReply(0)
   266  	}
   267  
   268  	db.PutEntity(dest, &database.DataEntity{
   269  		Data: result,
   270  	})
   271  	db.addAof(utils.ToCmdLine3("sunionstore", args...))
   272  	return protocol.MakeIntReply(int64(result.Len()))
   273  }
   274  
   275  // execSDiff subtracts multiple sets
   276  func execSDiff(db *DB, args [][]byte) redis.Reply {
   277  	sets := make([]*HashSet.Set, 0, len(args))
   278  	for _, arg := range args {
   279  		key := string(arg)
   280  		set, errReply := db.getAsSet(key)
   281  		if errReply != nil {
   282  			return errReply
   283  		}
   284  		sets = append(sets, set)
   285  	}
   286  	result := HashSet.Diff(sets...)
   287  	return set2reply(result)
   288  }
   289  
   290  // execSDiffStore subtracts multiple sets and store the result in a key
   291  func execSDiffStore(db *DB, args [][]byte) redis.Reply {
   292  	dest := string(args[0])
   293  	sets := make([]*HashSet.Set, 0, len(args)-1)
   294  	for i := 1; i < len(args); i++ {
   295  		key := string(args[i])
   296  		set, errReply := db.getAsSet(key)
   297  		if errReply != nil {
   298  			return errReply
   299  		}
   300  		sets = append(sets, set)
   301  	}
   302  	result := HashSet.Diff(sets...)
   303  	db.Remove(dest) // clean ttl
   304  	if result.Len() == 0 {
   305  		return protocol.MakeIntReply(0)
   306  	}
   307  	db.PutEntity(dest, &database.DataEntity{
   308  		Data: result,
   309  	})
   310  	db.addAof(utils.ToCmdLine3("sdiffstore", args...))
   311  	return protocol.MakeIntReply(int64(result.Len()))
   312  }
   313  
   314  // execSRandMember gets random members from set
   315  func execSRandMember(db *DB, args [][]byte) redis.Reply {
   316  	if len(args) != 1 && len(args) != 2 {
   317  		return protocol.MakeErrReply("ERR wrong number of arguments for 'srandmember' command")
   318  	}
   319  	key := string(args[0])
   320  
   321  	// get or init entity
   322  	set, errReply := db.getAsSet(key)
   323  	if errReply != nil {
   324  		return errReply
   325  	}
   326  	if set == nil {
   327  		return &protocol.NullBulkReply{}
   328  	}
   329  	if len(args) == 1 {
   330  		// get a random member
   331  		members := set.RandomMembers(1)
   332  		return protocol.MakeBulkReply([]byte(members[0]))
   333  	}
   334  	count64, err := strconv.ParseInt(string(args[1]), 10, 64)
   335  	if err != nil {
   336  		return protocol.MakeErrReply("ERR value is not an integer or out of range")
   337  	}
   338  	count := int(count64)
   339  	if count > 0 {
   340  		members := set.RandomDistinctMembers(count)
   341  		result := make([][]byte, len(members))
   342  		for i, v := range members {
   343  			result[i] = []byte(v)
   344  		}
   345  		return protocol.MakeMultiBulkReply(result)
   346  	} else if count < 0 {
   347  		members := set.RandomMembers(-count)
   348  		result := make([][]byte, len(members))
   349  		for i, v := range members {
   350  			result[i] = []byte(v)
   351  		}
   352  		return protocol.MakeMultiBulkReply(result)
   353  	}
   354  	return &protocol.EmptyMultiBulkReply{}
   355  }
   356  
   357  func init() {
   358  	registerCommand("SAdd", execSAdd, writeFirstKey, undoSetChange, -3, flagWrite).
   359  		attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1)
   360  	registerCommand("SIsMember", execSIsMember, readFirstKey, nil, 3, flagReadOnly).
   361  		attachCommandExtra([]string{redisFlagReadonly, redisFlagFast}, 1, 1, 1)
   362  	registerCommand("SRem", execSRem, writeFirstKey, undoSetChange, -3, flagWrite).
   363  		attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1)
   364  	registerCommand("SPop", execSPop, writeFirstKey, undoSetChange, -2, flagWrite).
   365  		attachCommandExtra([]string{redisFlagWrite, redisFlagRandom, redisFlagFast}, 1, 1, 1)
   366  	registerCommand("SCard", execSCard, readFirstKey, nil, 2, flagReadOnly).
   367  		attachCommandExtra([]string{redisFlagReadonly, redisFlagFast}, 1, 1, 1)
   368  	registerCommand("SMembers", execSMembers, readFirstKey, nil, 2, flagReadOnly).
   369  		attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
   370  	registerCommand("SInter", execSInter, prepareSetCalculate, nil, -2, flagReadOnly).
   371  		attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, -1, 1)
   372  	registerCommand("SInterStore", execSInterStore, prepareSetCalculateStore, rollbackFirstKey, -3, flagWrite).
   373  		attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, -1, 1)
   374  	registerCommand("SUnion", execSUnion, prepareSetCalculate, nil, -2, flagReadOnly).
   375  		attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, -1, 1)
   376  	registerCommand("SUnionStore", execSUnionStore, prepareSetCalculateStore, rollbackFirstKey, -3, flagWrite).
   377  		attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, -1, 1)
   378  	registerCommand("SDiff", execSDiff, prepareSetCalculate, nil, -2, flagReadOnly).
   379  		attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
   380  	registerCommand("SDiffStore", execSDiffStore, prepareSetCalculateStore, rollbackFirstKey, -3, flagWrite).
   381  		attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, 1, 1)
   382  	registerCommand("SRandMember", execSRandMember, readFirstKey, nil, -2, flagReadOnly).
   383  		attachCommandExtra([]string{redisFlagReadonly, redisFlagRandom}, 1, 1, 1)
   384  }