github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/domain/set.go (about)

     1  package domain
     2  
     3  import (
     4  	"math/bits"
     5  )
     6  
     7  const prefixLabel = '\r'
     8  
     9  // mod from https://github.com/openacid/succinct
    10  
    11  type succinctSet struct {
    12  	leaves, labelBitmap []uint64
    13  	labels              []byte
    14  	ranks, selects      []int32
    15  }
    16  
    17  func newSuccinctSet(keys []string) *succinctSet {
    18  	ss := &succinctSet{}
    19  	lIdx := 0
    20  	type qElt struct{ s, e, col int }
    21  	queue := []qElt{{0, len(keys), 0}}
    22  	for i := 0; i < len(queue); i++ {
    23  		elt := queue[i]
    24  		if elt.col == len(keys[elt.s]) {
    25  			// a leaf node
    26  			elt.s++
    27  			setBit(&ss.leaves, i, 1)
    28  		}
    29  		for j := elt.s; j < elt.e; {
    30  			frm := j
    31  			for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ {
    32  			}
    33  			queue = append(queue, qElt{frm, j, elt.col + 1})
    34  			ss.labels = append(ss.labels, keys[frm][elt.col])
    35  			setBit(&ss.labelBitmap, lIdx, 0)
    36  			lIdx++
    37  		}
    38  		setBit(&ss.labelBitmap, lIdx, 1)
    39  		lIdx++
    40  	}
    41  	ss.init()
    42  	return ss
    43  }
    44  
    45  func (ss *succinctSet) Has(key string) bool {
    46  	var nodeId, bmIdx int
    47  	for i := 0; i < len(key); i++ {
    48  		currentChar := key[i]
    49  		for ; ; bmIdx++ {
    50  			if getBit(ss.labelBitmap, bmIdx) != 0 {
    51  				return false
    52  			}
    53  			nextLabel := ss.labels[bmIdx-nodeId]
    54  			if nextLabel == prefixLabel {
    55  				return true
    56  			}
    57  			if nextLabel == currentChar {
    58  				break
    59  			}
    60  		}
    61  		nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
    62  		bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
    63  	}
    64  	if getBit(ss.leaves, nodeId) != 0 {
    65  		return true
    66  	}
    67  	for ; ; bmIdx++ {
    68  		if getBit(ss.labelBitmap, bmIdx) != 0 {
    69  			return false
    70  		}
    71  		if ss.labels[bmIdx-nodeId] == prefixLabel {
    72  			return true
    73  		}
    74  	}
    75  }
    76  
    77  func setBit(bm *[]uint64, i int, v int) {
    78  	for i>>6 >= len(*bm) {
    79  		*bm = append(*bm, 0)
    80  	}
    81  	(*bm)[i>>6] |= uint64(v) << uint(i&63)
    82  }
    83  
    84  func getBit(bm []uint64, i int) uint64 {
    85  	return bm[i>>6] & (1 << uint(i&63))
    86  }
    87  
    88  func (ss *succinctSet) init() {
    89  	ss.selects, ss.ranks = indexSelect32R64(ss.labelBitmap)
    90  }
    91  
    92  func countZeros(bm []uint64, ranks []int32, i int) int {
    93  	a, _ := rank64(bm, ranks, int32(i))
    94  	return i - int(a)
    95  }
    96  
    97  func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
    98  	a, _ := select32R64(bm, selects, ranks, int32(i))
    99  	return int(a)
   100  }
   101  
   102  func rank64(words []uint64, rindex []int32, i int32) (int32, int32) {
   103  	wordI := i >> 6
   104  	j := uint32(i & 63)
   105  	n := rindex[wordI]
   106  	w := words[wordI]
   107  	c1 := n + int32(bits.OnesCount64(w&mask[j]))
   108  	return c1, int32(w>>uint(j)) & 1
   109  }
   110  
   111  func indexRank64(words []uint64, opts ...bool) []int32 {
   112  	trailing := false
   113  	if len(opts) > 0 {
   114  		trailing = opts[0]
   115  	}
   116  	l := len(words)
   117  	if trailing {
   118  		l++
   119  	}
   120  	idx := make([]int32, l)
   121  	n := int32(0)
   122  	for i := 0; i < len(words); i++ {
   123  		idx[i] = n
   124  		n += int32(bits.OnesCount64(words[i]))
   125  	}
   126  	if trailing {
   127  		idx[len(words)] = n
   128  	}
   129  	return idx
   130  }
   131  
// select32R64 locates the i-th (0-based) set bit of words.
//
// selectIndex and rankIndex are expected to be the two slices returned by
// indexSelect32R64(words):
//   - selectIndex holds the position of every 32nd set bit (0th, 32nd, ...);
//   - rankIndex[k] is the number of set bits in words[:k], with a trailing
//     entry holding the total.
//
// i must be less than the total number of set bits; otherwise the forward
// scan runs off rankIndex and panics.
//
// It returns the position of the i-th set bit and the position of the next
// set bit after it, or len(words)*64 when there is none.
func select32R64(words []uint64, selectIndex, rankIndex []int32, i int32) (int32, int32) {
	a := int32(0)
	l := int32(len(words))
	// Start at the word holding the sampled set bit at or before i, then walk
	// forward to the word that actually contains the i-th set bit.
	wordI := selectIndex[i>>5] >> 6
	for ; rankIndex[wordI+1] <= i; wordI++ {
	}
	w := words[wordI]
	ww := w
	base := wordI << 6
	// findIth is the target's index among the set bits of this word.
	findIth := int(i - rankIndex[wordI])
	offset := int32(0)
	// Narrow down to the byte holding the target. If the low half has too few
	// ones, move to the high half (32 bits, then 16) and record the offset.
	ones := bits.OnesCount32(uint32(ww))
	if ones <= findIth {
		findIth -= ones
		offset |= 32
		ww >>= 32
	}
	ones = bits.OnesCount16(uint16(ww))
	if ones <= findIth {
		findIth -= ones
		offset |= 16
		ww >>= 16
	}
	ones = bits.OnesCount8(uint8(ww))
	// select8Lookup[b<<3|j] is the position of the j-th set bit in byte b.
	if ones <= findIth {
		// Target is in the second byte: (ww>>5)&0x7f8 == ((ww>>8)&0xff)<<3.
		a = int32(select8Lookup[(ww>>5)&(0x7f8)|uint64(findIth-ones)]) + offset + 8
	} else {
		a = int32(select8Lookup[(ww&0xff)<<3|uint64(findIth)]) + offset
	}
	a += base
	// Find the next set bit: first among the bits above a in the same word...
	w &= rMaskUpto[a&63]
	if w != 0 {
		return a, base + int32(bits.TrailingZeros64(w))
	}
	// ...then in the following words.
	wordI++
	for ; wordI < l; wordI++ {
		w = words[wordI]
		if w != 0 {
			return a, wordI<<6 + int32(bits.TrailingZeros64(w))
		}
	}
	return a, l << 6
}
   175  
   176  func indexSelect32R64(words []uint64) ([]int32, []int32) {
   177  	l := len(words) << 6
   178  	sidx := make([]int32, 0, len(words))
   179  
   180  	ith := -1
   181  	for i := 0; i < l; i++ {
   182  		if words[i>>6]&(1<<uint(i&63)) != 0 {
   183  			ith++
   184  			if ith&31 == 0 {
   185  				sidx = append(sidx, int32(i))
   186  			}
   187  		}
   188  	}
   189  
   190  	// clone to reduce cap to len
   191  	sidx = append(sidx[:0:0], sidx...)
   192  	return sidx, indexRank64(words, true)
   193  }
   194  
   195  func init() {
   196  	initMasks()
   197  	initSelectLookup()
   198  }
   199  
   200  var (
   201  	mask      [65]uint64
   202  	rMaskUpto [64]uint64
   203  )
   204  
   205  func initMasks() {
   206  	for i := 0; i < 65; i++ {
   207  		mask[i] = (1 << uint(i)) - 1
   208  	}
   209  
   210  	var maskUpto [64]uint64
   211  	for i := 0; i < 64; i++ {
   212  		maskUpto[i] = (1 << uint(i+1)) - 1
   213  		rMaskUpto[i] = ^maskUpto[i]
   214  	}
   215  }
   216  
   217  var select8Lookup [256 * 8]uint8
   218  
   219  func initSelectLookup() {
   220  	for i := 0; i < 256; i++ {
   221  		w := uint8(i)
   222  		for j := 0; j < 8; j++ {
   223  			// x-th 1 in w
   224  			// if x-th 1 is not found, it is 8
   225  			x := bits.TrailingZeros8(w)
   226  			w &= w - 1
   227  
   228  			select8Lookup[i*8+j] = uint8(x)
   229  		}
   230  	}
   231  }