github.com/flower-corp/rosedb@v1.1.2-0.20230117132829-21dc4f7b319a/sets.go (about)

     1  package rosedb
     2  
     3  import (
     4  	"github.com/flower-corp/rosedb/ds/art"
     5  	"github.com/flower-corp/rosedb/logfile"
     6  	"github.com/flower-corp/rosedb/logger"
     7  	"github.com/flower-corp/rosedb/util"
     8  )
     9  
    10  // SAdd add the specified members to the set stored at key.
    11  // Specified members that are already a member of this set are ignored.
    12  // If key does not exist, a new set is created before adding the specified members.
    13  func (db *RoseDB) SAdd(key []byte, members ...[]byte) error {
    14  	db.setIndex.mu.Lock()
    15  	defer db.setIndex.mu.Unlock()
    16  
    17  	if db.setIndex.trees[string(key)] == nil {
    18  		db.setIndex.trees[string(key)] = art.NewART()
    19  	}
    20  	idxTree := db.setIndex.trees[string(key)]
    21  	for _, mem := range members {
    22  		if len(mem) == 0 {
    23  			continue
    24  		}
    25  		if err := db.setIndex.murhash.Write(mem); err != nil {
    26  			return err
    27  		}
    28  		sum := db.setIndex.murhash.EncodeSum128()
    29  		db.setIndex.murhash.Reset()
    30  
    31  		ent := &logfile.LogEntry{Key: key, Value: mem}
    32  		valuePos, err := db.writeLogEntry(ent, Set)
    33  		if err != nil {
    34  			return err
    35  		}
    36  		entry := &logfile.LogEntry{Key: sum, Value: mem}
    37  		_, size := logfile.EncodeEntry(ent)
    38  		valuePos.entrySize = size
    39  		if err := db.updateIndexTree(idxTree, entry, valuePos, true, Set); err != nil {
    40  			return err
    41  		}
    42  	}
    43  	return nil
    44  }
    45  
    46  // SPop removes and returns one or more random members from the set value store at key.
    47  func (db *RoseDB) SPop(key []byte, count uint) ([][]byte, error) {
    48  	db.setIndex.mu.Lock()
    49  	defer db.setIndex.mu.Unlock()
    50  	if db.setIndex.trees[string(key)] == nil {
    51  		return nil, nil
    52  	}
    53  	idxTree := db.setIndex.trees[string(key)]
    54  
    55  	var values [][]byte
    56  	iter := idxTree.Iterator()
    57  	for iter.HasNext() && count > 0 {
    58  		count--
    59  		node, _ := iter.Next()
    60  		if node == nil {
    61  			continue
    62  		}
    63  		val, err := db.getVal(idxTree, node.Key(), Set)
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  		values = append(values, val)
    68  	}
    69  	for _, val := range values {
    70  		if err := db.sremInternal(key, val); err != nil {
    71  			return nil, err
    72  		}
    73  	}
    74  	return values, nil
    75  }
    76  
    77  // SRem remove the specified members from the set stored at key.
    78  // Specified members that are not a member of this set are ignored.
    79  // If key does not exist, it is treated as an empty set and this command returns 0.
    80  func (db *RoseDB) SRem(key []byte, members ...[]byte) error {
    81  	db.setIndex.mu.Lock()
    82  	defer db.setIndex.mu.Unlock()
    83  
    84  	if db.setIndex.trees[string(key)] == nil {
    85  		return nil
    86  	}
    87  	for _, mem := range members {
    88  		if err := db.sremInternal(key, mem); err != nil {
    89  			return err
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  // SIsMember returns if member is a member of the set stored at key.
    96  func (db *RoseDB) SIsMember(key, member []byte) bool {
    97  	db.setIndex.mu.RLock()
    98  	defer db.setIndex.mu.RUnlock()
    99  
   100  	if db.setIndex.trees[string(key)] == nil {
   101  		return false
   102  	}
   103  	idxTree := db.setIndex.trees[string(key)]
   104  	if err := db.setIndex.murhash.Write(member); err != nil {
   105  		return false
   106  	}
   107  	sum := db.setIndex.murhash.EncodeSum128()
   108  	db.setIndex.murhash.Reset()
   109  	node := idxTree.Get(sum)
   110  	return node != nil
   111  }
   112  
   113  // SMembers returns all the members of the set value stored at key.
   114  func (db *RoseDB) SMembers(key []byte) ([][]byte, error) {
   115  	db.setIndex.mu.RLock()
   116  	defer db.setIndex.mu.RUnlock()
   117  	return db.sMembers(key)
   118  }
   119  
   120  // SCard returns the set cardinality (number of elements) of the set stored at key.
   121  func (db *RoseDB) SCard(key []byte) int {
   122  	db.setIndex.mu.RLock()
   123  	defer db.setIndex.mu.RUnlock()
   124  	if db.setIndex.trees[string(key)] == nil {
   125  		return 0
   126  	}
   127  	return db.setIndex.trees[string(key)].Size()
   128  }
   129  
   130  // SDiff returns the members of the set difference between the first set and
   131  // all the successive sets. Returns error if no key is passed as a parameter.
   132  func (db *RoseDB) SDiff(keys ...[]byte) ([][]byte, error) {
   133  	db.setIndex.mu.RLock()
   134  	defer db.setIndex.mu.RUnlock()
   135  	if len(keys) == 0 {
   136  		return nil, ErrWrongNumberOfArgs
   137  	}
   138  	if len(keys) == 1 {
   139  		return db.sMembers(keys[0])
   140  	}
   141  
   142  	firstSet, err := db.sMembers(keys[0])
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	successiveSet := make(map[uint64]struct{})
   147  	for _, key := range keys[1:] {
   148  		members, err := db.sMembers(key)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  		for _, value := range members {
   153  			h := util.MemHash(value)
   154  			if _, ok := successiveSet[h]; !ok {
   155  				successiveSet[h] = struct{}{}
   156  			}
   157  		}
   158  	}
   159  	if len(successiveSet) == 0 {
   160  		return firstSet, nil
   161  	}
   162  	res := make([][]byte, 0)
   163  	for _, value := range firstSet {
   164  		h := util.MemHash(value)
   165  		if _, ok := successiveSet[h]; !ok {
   166  			res = append(res, value)
   167  		}
   168  	}
   169  	return res, nil
   170  }
   171  
   172  // SDiffStore is equal to SDiff, but instead of returning the resulting set, it is stored in first param.
   173  func (db *RoseDB) SDiffStore(keys ...[]byte) (int, error) {
   174  	destination := keys[0]
   175  	diff, err := db.SDiff(keys[1:]...)
   176  	if err != nil {
   177  		return -1, err
   178  	}
   179  	if err := db.sStore(destination, diff); err != nil {
   180  		return -1, err
   181  	}
   182  	return db.SCard(destination), nil
   183  }
   184  
   185  // SUnion returns the members of the set resulting from the union of all the given sets.
   186  func (db *RoseDB) SUnion(keys ...[]byte) ([][]byte, error) {
   187  	db.setIndex.mu.RLock()
   188  	defer db.setIndex.mu.RUnlock()
   189  
   190  	if len(keys) == 0 {
   191  		return nil, ErrWrongNumberOfArgs
   192  	}
   193  	if len(keys) == 1 {
   194  		return db.sMembers(keys[0])
   195  	}
   196  
   197  	set := make(map[uint64]struct{})
   198  	unionSet := make([][]byte, 0)
   199  	for _, key := range keys {
   200  		values, err := db.sMembers(key)
   201  		if err != nil {
   202  			return nil, err
   203  		}
   204  		for _, val := range values {
   205  			h := util.MemHash(val)
   206  			if _, ok := set[h]; !ok {
   207  				set[h] = struct{}{}
   208  				unionSet = append(unionSet, val)
   209  			}
   210  		}
   211  	}
   212  	return unionSet, nil
   213  }
   214  
   215  // SUnionStore Store the union result in first param
   216  func (db *RoseDB) SUnionStore(keys ...[]byte) (int, error) {
   217  	destination := keys[0]
   218  	union, err := db.SUnion(keys[1:]...)
   219  	if err != nil {
   220  		return -1, err
   221  	}
   222  	if err := db.sStore(destination, union); err != nil {
   223  		return -1, err
   224  	}
   225  	return db.SCard(destination), nil
   226  }
   227  
// sremInternal removes a single member from the set stored at key.
//
// It takes no lock itself. Both callers in this file (SPop and SRem) hold
// db.setIndex.mu for writing, which is required because this function
// mutates the shared murhash hasher and the set's index tree. Those callers
// also check that a tree exists for key before calling.
// NOTE(review): a missing tree is not guarded against here.
//
// If Delete reports the member's digest as not updated, nothing is written
// and nil is returned. Otherwise:
//   - a TypeDelete log entry (Key: set key, Value: member digest) is appended;
//   - the value returned by Delete is passed to sendDiscard;
//   - the delete record itself is offered to the Set discard channel.
func (db *RoseDB) sremInternal(key []byte, member []byte) error {
	idxTree := db.setIndex.trees[string(key)]
	// NOTE(review): on a Write error the hasher is not Reset, so partial
	// state could leak into the next digest. Confirm whether Write can fail.
	if err := db.setIndex.murhash.Write(member); err != nil {
		return err
	}
	sum := db.setIndex.murhash.EncodeSum128()
	db.setIndex.murhash.Reset()

	// When updated is false the member is treated as absent: no log write.
	val, updated := idxTree.Delete(sum)
	if !updated {
		return nil
	}
	// NOTE(review): the index entry is already removed at this point. If the
	// log write below fails, the in-memory index and the log disagree.
	entry := &logfile.LogEntry{Key: key, Value: sum, Type: logfile.TypeDelete}
	pos, err := db.writeLogEntry(entry, Set)
	if err != nil {
		return err
	}

	db.sendDiscard(val, updated, Set)
	// The deleted entry itself is also invalid.
	_, size := logfile.EncodeEntry(entry)
	node := &indexNode{fid: pos.fid, entrySize: size}
	// Non-blocking send: when the channel cannot accept the node right away,
	// it is dropped and only a warning is logged.
	select {
	case db.discards[Set].valChan <- node:
	default:
		logger.Warn("send to discard chan fail")
	}
	return nil
}
   257  
   258  // sMembers is a helper method to get all members of the given set key.
   259  func (db *RoseDB) sMembers(key []byte) ([][]byte, error) {
   260  	if db.setIndex.trees[string(key)] == nil {
   261  		return nil, nil
   262  	}
   263  
   264  	var values [][]byte
   265  	idxTree := db.setIndex.trees[string(key)]
   266  	iterator := idxTree.Iterator()
   267  	for iterator.HasNext() {
   268  		node, _ := iterator.Next()
   269  		if node == nil {
   270  			continue
   271  		}
   272  		val, err := db.getVal(idxTree, node.Key(), Set)
   273  		if err != nil {
   274  			return nil, err
   275  		}
   276  		values = append(values, val)
   277  	}
   278  	return values, nil
   279  }
   280  
   281  // SInter returns the members of the set resulting from the inter of all the given sets.
   282  func (db *RoseDB) SInter(keys ...[]byte) ([][]byte, error) {
   283  	db.setIndex.mu.RLock()
   284  	defer db.setIndex.mu.RUnlock()
   285  
   286  	if len(keys) == 0 {
   287  		return nil, ErrWrongNumberOfArgs
   288  	}
   289  	if len(keys) == 1 {
   290  		return db.sMembers(keys[0])
   291  	}
   292  	num := len(keys)
   293  	set := make(map[uint64]int)
   294  	interSet := make([][]byte, 0)
   295  	for _, key := range keys {
   296  		values, err := db.sMembers(key)
   297  		if err != nil {
   298  			return nil, err
   299  		}
   300  		for _, val := range values {
   301  			h := util.MemHash(val)
   302  			set[h]++
   303  			if set[h] == num {
   304  				interSet = append(interSet, val)
   305  			}
   306  		}
   307  	}
   308  	return interSet, nil
   309  }
   310  
   311  // SInterStore Store the inter result in first param
   312  func (db *RoseDB) SInterStore(keys ...[]byte) (int, error) {
   313  	destination := keys[0]
   314  	inter, err := db.SInter(keys[1:]...)
   315  	if err != nil {
   316  		return -1, err
   317  	}
   318  	if err := db.sStore(destination, inter); err != nil {
   319  		return -1, err
   320  	}
   321  	return db.SCard(destination), nil
   322  }
   323  
   324  // sStore store vals in the set the destination points to
   325  // sStore is called in SInterStore SUnionStore SDiffStore
   326  func (db *RoseDB) sStore(destination []byte, vals [][]byte) error {
   327  	for _, val := range vals {
   328  		if isMember := db.SIsMember(destination, val); !isMember {
   329  			if err := db.SAdd(destination, val); err != nil {
   330  				return err
   331  			}
   332  		}
   333  	}
   334  	return nil
   335  }