github.com/coyove/sdss@v0.0.0-20231129015646-c2ec58cca6a2/contrib/bitmap/bitmap.go (about)

     1  package bitmap
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash/crc32"
     8  	"io"
     9  	"math"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/coyove/sdss/contrib/roaring"
    14  	"github.com/coyove/sdss/contrib/simple"
    15  	"github.com/pierrec/lz4/v4"
    16  )
    17  
    18  const (
    19  	slotSize     = 1 << 14
    20  	slotNum      = 1 << 6
    21  	fastSlotNum  = 1 << 12
    22  	fastSlotSize = 1 << 8
    23  	fastSlotMask = 0xfffff000
    24  	bfHash       = 3
    25  
    26  	Capcity = slotSize * slotNum
    27  )
    28  
    29  type Range struct {
    30  	mu         sync.RWMutex
    31  	mfmu       sync.Mutex
    32  	start, end int64
    33  	fastTable  *roaring.Bitmap
    34  	slots      [slotNum]*subMap
    35  }
    36  
    37  func New(start int64) *Range {
    38  	d := &Range{
    39  		start:     start,
    40  		end:       -1,
    41  		fastTable: roaring.New(),
    42  	}
    43  	for i := range d.slots {
    44  		d.slots[i] = &subMap{}
    45  	}
    46  	return d
    47  }
    48  
    49  func (b *Range) Start() int64 {
    50  	return b.start
    51  }
    52  
    53  func (b *Range) End() int64 {
    54  	return b.start + b.end
    55  }
    56  
    57  func (b *Range) Len() int64 {
    58  	return b.end + 1
    59  }
    60  
    61  func (b *Range) FirstKey() Key {
    62  	if b.end < 0 {
    63  		return Key{}
    64  	}
    65  	m := b.slots[0]
    66  	m.mu.RLock()
    67  	defer m.mu.RUnlock()
    68  	return m.keys[0]
    69  }
    70  
    71  func (b *Range) LastKey() Key {
    72  	if b.end < 0 {
    73  		return Key{}
    74  	}
    75  	m := b.slots[b.end/slotSize]
    76  	m.mu.RLock()
    77  	defer m.mu.RUnlock()
    78  	return m.keys[len(m.keys)-1]
    79  }
    80  
    81  type subMap struct {
    82  	mu    sync.RWMutex
    83  	keys  []Key
    84  	spans []uint32
    85  	xfs   []byte
    86  }
    87  
    88  func (b *Range) Add(key Key, values []uint64) bool {
    89  	values = simple.Uint64.Dedup(values)
    90  	if len(values) == 0 {
    91  		panic("empty values")
    92  	}
    93  
    94  	b.mu.Lock()
    95  	defer b.mu.Unlock()
    96  
    97  	if b.end == slotSize*slotNum-1 {
    98  		return false
    99  	}
   100  
   101  	b.end++
   102  	offset := uint32(b.end / fastSlotSize)
   103  	for _, v := range values {
   104  		h := h16(uint32(v), b.start)
   105  		for i := 0; i < bfHash; i++ {
   106  			b.fastTable.Add(h[i]&fastSlotMask | offset)
   107  		}
   108  	}
   109  
   110  	m := b.slots[b.end/slotSize]
   111  	m.mu.Lock()
   112  	defer m.mu.Unlock()
   113  
   114  	xf := xfNew(values)
   115  
   116  	m.keys = append(m.keys, key)
   117  	m.xfs = append(m.xfs, xf...)
   118  	if len(m.spans) == 0 {
   119  		m.spans = append(m.spans, uint32(len(xf)))
   120  	} else {
   121  		m.spans = append(m.spans, m.spans[len(m.spans)-1]+uint32(len(xf)))
   122  	}
   123  	return true
   124  }
   125  
   126  func (b *Range) Join(vs Values, start int64, desc bool, f func(KeyIdScore) bool) (jm JoinMetrics) {
   127  	vs.Clean()
   128  	fastStart := time.Now()
   129  	fast := b.joinFast(&vs)
   130  	jm.FastElapsed = time.Since(fastStart)
   131  	jm.BaseStart = b.start
   132  	jm.Start = start
   133  	jm.Values = vs
   134  	jm.Desc = desc
   135  
   136  	if start == -1 {
   137  		start = b.end
   138  	} else {
   139  		start -= b.start
   140  		if start < 0 || start >= slotNum*slotSize {
   141  			return
   142  		}
   143  	}
   144  
   145  	startSlot := int(start / slotSize)
   146  
   147  	endSlot, endCmp, step := -1, 1, -1
   148  	if !desc {
   149  		endSlot, endCmp, step = slotNum, -1, 1
   150  	}
   151  
   152  	for i := startSlot; icmp(int64(i), int64(endSlot)) == endCmp; i += step {
   153  		if fast[i] == 0 {
   154  			continue
   155  		}
   156  
   157  		m := b.slots[i]
   158  		startOffset := start - int64(i*slotSize)
   159  		if startOffset >= int64(len(m.keys)) {
   160  			startOffset = int64(len(m.keys)) - 1
   161  		}
   162  		if startOffset < 0 {
   163  			startOffset = 0
   164  		}
   165  		if exit := m.join(vs, i, &fast, startOffset, desc, b.start, &jm, f); exit {
   166  			break
   167  		}
   168  	}
   169  
   170  	jm.Elapsed = time.Since(fastStart)
   171  	return
   172  }
   173  
   174  func (b *subMap) prevSpan(i int64) uint32 {
   175  	if i == 0 {
   176  		return 0
   177  	}
   178  	return b.spans[i-1]
   179  }
   180  
   181  func (b *subMap) join(v Values, hr int, fast *bitmap1024, end1 int64, desc bool,
   182  	baseStart int64, jm *JoinMetrics, f func(KeyIdScore) bool) bool {
   183  	b.mu.RLock()
   184  	defer b.mu.RUnlock()
   185  
   186  	start := time.Now()
   187  	exit := false
   188  	ms := v.majorScore()
   189  
   190  	iend, cmp, step := int64(-1), 1, int64(-1)
   191  	if !desc {
   192  		iend, cmp, step = int64(len(b.keys)), -1, 1
   193  	}
   194  
   195  NEXT:
   196  	for i := end1; icmp(i, iend) == cmp; i += step {
   197  		if !fast.contains(uint16((hr*slotSize + int(i)) / fastSlotSize)) {
   198  			continue
   199  		}
   200  		jm.Slots[hr].Scans++
   201  		xf, vs := xfBuild(b.xfs[b.prevSpan(i):b.spans[i]])
   202  
   203  		oneof := true
   204  		for _, hs := range v.Oneof {
   205  			if oneof = xfContains(xf, vs, hs); oneof {
   206  				break
   207  			}
   208  		}
   209  		if !oneof {
   210  			continue
   211  		}
   212  
   213  		s := 0
   214  		for _, hs := range v.Major {
   215  			if xfContains(xf, vs, hs) {
   216  				s++
   217  			}
   218  		}
   219  		if s < ms {
   220  			continue
   221  		}
   222  
   223  		for _, hs := range v.Exact {
   224  			if !xfContains(xf, vs, hs) {
   225  				continue NEXT
   226  			}
   227  		}
   228  
   229  		jm.Slots[hr].Hits++
   230  		if !f(KeyIdScore{
   231  			Key:   b.keys[i],
   232  			Id:    int64(hr*slotSize) + int64(i) + baseStart,
   233  			Score: s,
   234  		}) {
   235  			exit = true
   236  			break
   237  		}
   238  	}
   239  
   240  	jm.Slots[hr].Elapsed = time.Since(start)
   241  	return exit
   242  }
   243  
   244  func (b *Range) Clone() *Range {
   245  	b.mu.RLock()
   246  	defer b.mu.RUnlock()
   247  
   248  	b2 := &Range{}
   249  	b2.start = b.start
   250  	b2.end = b.end
   251  	b2.fastTable = b.fastTable.Clone()
   252  	for i := range b2.slots {
   253  		b2.slots[i] = b.slots[i].clone()
   254  	}
   255  	return b2
   256  }
   257  
   258  func (b *subMap) clone() *subMap {
   259  	b.mu.RLock()
   260  	defer b.mu.RUnlock()
   261  	return &subMap{
   262  		keys:  b.keys,
   263  		spans: b.spans,
   264  		xfs:   b.xfs,
   265  	}
   266  }
   267  
   268  func Unmarshal(rd io.Reader) (*Range, error) {
   269  	var err error
   270  	var ver byte
   271  	if err := binary.Read(rd, binary.BigEndian, &ver); err != nil {
   272  		return nil, fmt.Errorf("read version: %v", err)
   273  	}
   274  	if ver == 4 {
   275  		rd = lz4.NewReader(rd)
   276  	}
   277  
   278  	b := &Range{}
   279  	h := crc32.NewIEEE()
   280  	rd = io.TeeReader(rd, h)
   281  
   282  	if err := binary.Read(rd, binary.BigEndian, &b.start); err != nil {
   283  		return nil, fmt.Errorf("read start: %v", err)
   284  	}
   285  
   286  	if err := binary.Read(rd, binary.BigEndian, &b.end); err != nil {
   287  		return nil, fmt.Errorf("read end: %v", err)
   288  	}
   289  
   290  	var z byte = bfHash
   291  	if err := binary.Read(rd, binary.BigEndian, &z); err != nil {
   292  		return nil, fmt.Errorf("read hashNum: %v", err)
   293  	}
   294  
   295  	var topSize uint64
   296  	if err := binary.Read(rd, binary.BigEndian, &topSize); err != nil {
   297  		return nil, fmt.Errorf("read fast table bitmap size: %v", err)
   298  	}
   299  
   300  	b.fastTable = roaring.New()
   301  	if _, err := b.fastTable.ReadFrom(io.LimitReader(rd, int64(topSize))); err != nil {
   302  		return nil, fmt.Errorf("read fast table bitmap: %v", err)
   303  	}
   304  
   305  	for i := range b.slots {
   306  		b.slots[i], err = readSubMap(rd)
   307  		if err != nil {
   308  			return nil, err
   309  		}
   310  	}
   311  
   312  	verify := h.Sum32()
   313  	var checksum uint32
   314  	if err := binary.Read(rd, binary.BigEndian, &checksum); err != nil {
   315  		return nil, fmt.Errorf("read checksum: %v", err)
   316  	}
   317  	if checksum != verify {
   318  		return nil, fmt.Errorf("invalid header checksum %x and %x", verify, checksum)
   319  	}
   320  	if err != nil {
   321  		return nil, fmt.Errorf("read header: %v", err)
   322  	}
   323  
   324  	return b, nil
   325  }
   326  
   327  func readSubMap(rd io.Reader) (*subMap, error) {
   328  	b := &subMap{}
   329  
   330  	var keysLen uint32
   331  	if err := binary.Read(rd, binary.BigEndian, &keysLen); err != nil {
   332  		return nil, fmt.Errorf("read keys length: %v", err)
   333  	}
   334  
   335  	tmp := make([]byte, keysLen*uint32(KeySize))
   336  	if err := binary.Read(rd, binary.BigEndian, tmp); err != nil {
   337  		return nil, fmt.Errorf("read keys: %v", err)
   338  	}
   339  	b.keys = bytesKeys(tmp)
   340  
   341  	b.spans = make([]uint32, keysLen)
   342  	if err := binary.Read(rd, binary.BigEndian, b.spans); err != nil {
   343  		return nil, fmt.Errorf("read spans: %v", err)
   344  	}
   345  
   346  	if len(b.spans) > 0 {
   347  		b.xfs = make([]byte, b.spans[len(b.spans)-1])
   348  		if err := binary.Read(rd, binary.BigEndian, b.xfs); err != nil {
   349  			return nil, fmt.Errorf("read xfs: %v", err)
   350  		}
   351  	}
   352  
   353  	return b, nil
   354  }
   355  
   356  func (b *Range) MarshalBinary(compress bool) []byte {
   357  	p := &bytes.Buffer{}
   358  	b.Marshal(p, compress)
   359  	return p.Bytes()
   360  }
   361  
   362  func (b *Range) Marshal(w io.Writer, compress bool) (int, error) {
   363  	b.mu.RLock()
   364  	defer b.mu.RUnlock()
   365  
   366  	mw := &meterWriter{Writer: w}
   367  
   368  	var zw io.WriteCloser
   369  	if compress {
   370  		mw.Write([]byte{4})
   371  		zw = lz4.NewWriter(mw)
   372  	} else {
   373  		mw.Write([]byte{1})
   374  		zw = mw
   375  	}
   376  
   377  	h := crc32.NewIEEE()
   378  	w = io.MultiWriter(zw, h)
   379  
   380  	if err := binary.Write(w, binary.BigEndian, b.start); err != nil {
   381  		return 0, err
   382  	}
   383  	if err := binary.Write(w, binary.BigEndian, b.end); err != nil {
   384  		return 0, err
   385  	}
   386  	if err := binary.Write(w, binary.BigEndian, byte(bfHash)); err != nil {
   387  		return 0, err
   388  	}
   389  	if err := binary.Write(w, binary.BigEndian, b.fastTable.GetSerializedSizeInBytes()); err != nil {
   390  		return 0, err
   391  	}
   392  	if _, err := b.fastTable.WriteTo(w); err != nil {
   393  		return 0, err
   394  	}
   395  	for _, h := range b.slots {
   396  		if err := h.writeTo(w); err != nil {
   397  			return 0, err
   398  		}
   399  	}
   400  	// Write CRC32 checksum to the end of stream.
   401  	if err := binary.Write(w, binary.BigEndian, h.Sum32()); err != nil {
   402  		return 0, err
   403  	}
   404  	if err := zw.Close(); err != nil {
   405  		return 0, err
   406  	}
   407  	return mw.size, nil
   408  }
   409  
   410  func (b *subMap) writeTo(w io.Writer) error {
   411  	b.mu.RLock()
   412  	defer b.mu.RUnlock()
   413  
   414  	if err := binary.Write(w, binary.BigEndian, uint32(len(b.keys))); err != nil {
   415  		return err
   416  	}
   417  	if _, err := w.Write(keysBytes(b.keys)); err != nil {
   418  		return err
   419  	}
   420  	if err := binary.Write(w, binary.BigEndian, b.spans); err != nil {
   421  		return err
   422  	}
   423  	if err := binary.Write(w, binary.BigEndian, b.xfs); err != nil {
   424  		return err
   425  	}
   426  	return nil
   427  }
   428  
   429  func (b *Range) RoughSizeBytes() (sz int64) {
   430  	b.mu.RLock()
   431  	defer b.mu.RUnlock()
   432  	sz += int64(b.fastTable.GetSizeInBytes())
   433  	for i := range b.slots {
   434  		sz += int64(len(b.slots[i].xfs))
   435  		sz += int64(len(b.slots[i].keys)) * (int64(KeySize) + 4)
   436  	}
   437  	return
   438  }
   439  
   440  func (b *Range) String() string {
   441  	b.mu.RLock()
   442  	defer b.mu.RUnlock()
   443  
   444  	start := time.Now()
   445  	buf := &bytes.Buffer{}
   446  
   447  	m := roaring.New()
   448  	b.fastTable.Iterate(func(x uint32) bool {
   449  		m.Add(x & fastSlotMask)
   450  		return x < math.MaxUint32/32
   451  	})
   452  
   453  	fmt.Fprintf(buf, "range: %d-%d, len: %d, rough size: %db\n", b.Start(), b.End(), b.Len(), b.RoughSizeBytes())
   454  	fmt.Fprintf(buf, "fast table len: %d, approx hash num: %d, size: %db\n",
   455  		b.fastTable.GetCardinality(), m.GetCardinality()*32, b.fastTable.GetSerializedSizeInBytes())
   456  	for i, h := range b.slots {
   457  		h.debug(i, buf)
   458  	}
   459  
   460  	fmt.Fprintf(buf, "collected in %v", time.Since(start))
   461  	return buf.String()
   462  }
   463  
   464  func (b *subMap) debug(i int, buf io.Writer) {
   465  	b.mu.RLock()
   466  	defer b.mu.RUnlock()
   467  	if len(b.keys) > 0 {
   468  		fmt.Fprintf(buf, "[%02d;0x%05x] ", i, i*slotSize)
   469  		fmt.Fprintf(buf, "keys: %5d/%2d, last key: %v, filter size: %db\n",
   470  			len(b.keys), len(b.keys)/fastSlotSize, b.keys[len(b.keys)-1], len(b.xfs))
   471  	}
   472  }
   473  
   474  func (b *Range) joinFast(vs *Values) (res bitmap1024) {
   475  	b.mu.RLock()
   476  	defer b.mu.RUnlock()
   477  
   478  	type hashState struct {
   479  		h uint32
   480  		bitmap1024
   481  	}
   482  
   483  	hashStates := map[uint32]*hashState{}
   484  	fill := func(hashes []uint64) [][4]uint32 {
   485  		var out [][4]uint32
   486  		for _, v := range hashes {
   487  			h := h16(uint32(v), b.start)
   488  			for i := 0; i < bfHash; i++ {
   489  				hashStates[h[i]] = &hashState{h: h[i] & fastSlotMask}
   490  			}
   491  			out = append(out, h)
   492  		}
   493  		return out
   494  	}
   495  	oneofHashes := fill(vs.Oneof)
   496  	majorHashes := fill(vs.Major)
   497  	exactHashes := fill(vs.Exact)
   498  
   499  	iter := b.fastTable.Iterator().(*roaring.IntIterator)
   500  	for _, hs := range hashStates {
   501  		iter.Seek(hs.h)
   502  		for iter.HasNext() {
   503  			h2 := iter.Next()
   504  			if h2&fastSlotMask == hs.h {
   505  				hs.add(uint16(h2 - hs.h))
   506  			} else {
   507  				break
   508  			}
   509  		}
   510  	}
   511  
   512  	// z := time.Now()
   513  	var final *bitmap1024
   514  	for _, raw := range oneofHashes {
   515  		m := hashStates[raw[0]].bitmap1024
   516  		for i := 1; i < bfHash; i++ {
   517  			m.and(&hashStates[raw[i]].bitmap1024)
   518  		}
   519  		if final == nil {
   520  			final = &m
   521  		} else {
   522  			final.or(&m)
   523  		}
   524  	}
   525  
   526  	var major *bitmap1024
   527  	var scores map[uint16]int
   528  	for _, raw := range majorHashes {
   529  		m := hashStates[raw[0]].bitmap1024
   530  		for i := 1; i < bfHash; i++ {
   531  			m.and(&hashStates[raw[i]].bitmap1024)
   532  		}
   533  		if major == nil {
   534  			major = &m
   535  			scores = map[uint16]int{}
   536  		} else {
   537  			major.or(&m)
   538  		}
   539  		m.iterate(func(x uint16) bool { scores[x]++; return true })
   540  	}
   541  	if major != nil {
   542  		ms := vs.majorScore()
   543  		var res bitmap1024
   544  		major.iterate(func(offset uint16) bool {
   545  			if scores[offset] >= ms {
   546  				res.add(offset)
   547  			}
   548  			return true
   549  		})
   550  		if final == nil {
   551  			final = &res
   552  		} else {
   553  			final.and(&res)
   554  		}
   555  	}
   556  
   557  	for _, raw := range exactHashes {
   558  		m := hashStates[raw[0]].bitmap1024
   559  		for i := 1; i < bfHash; i++ {
   560  			m.and(&hashStates[raw[i]].bitmap1024)
   561  		}
   562  		if final == nil {
   563  			final = &m
   564  		} else {
   565  			final.and(&m)
   566  		}
   567  	}
   568  
   569  	if final == nil {
   570  		return bitmap1024{}
   571  	}
   572  	return *final
   573  }
   574  
   575  func (b *Range) Find(key Key) (int64, func(uint64) bool) {
   576  	for hr, m := range b.slots {
   577  		m.mu.Lock()
   578  		for i, k := range m.keys {
   579  			if k == key {
   580  				x, vs := xfBuild(m.xfs[m.prevSpan(int64(i)):m.spans[i]])
   581  				m.mu.Unlock()
   582  				return int64(hr)*slotSize + int64(i), func(k uint64) bool { return xfContains(x, vs, k) }
   583  			}
   584  		}
   585  		m.mu.Unlock()
   586  	}
   587  	return 0, nil
   588  }