github.com/zuoyebang/bitalosdb@v1.1.1-0.20240516111551-79a8c4d8ce20/commit.go

// Copyright 2021 The Bitalosdb author(hustxrb@163.com) and other contributors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bitalosdb

import (
	"runtime"
	"sync"
	"sync/atomic"
	"unsafe"

	"github.com/zuoyebang/bitalosdb/internal/record"
)

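// commitQueue is a lock-free, fixed-size, single-producer multi-consumer
// queue of batches waiting to be published. The producer side is serialized
// by the pipeline's mutex; the consumers are the committers racing in
// publish. headTail packs a 32-bit head index (where batches are enqueued)
// and a 32-bit tail index (where batches are dequeued) into a single
// atomically updated word.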
type commitQueue struct {
	headTail uint64
	slots    [record.SyncConcurrency]unsafe.Pointer
}

const dequeueBits = 32

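// unpack splits a headTail word into its head and tail halves.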
func (q *commitQueue) unpack(ptrs uint64) (head, tail uint32) {
	const mask = 1<<dequeueBits - 1
	head = uint32((ptrs >> dequeueBits) & mask)
	tail = uint32(ptrs & mask)
	return
}

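// pack combines head and tail into a single headTail word: head occupies
// the upper 32 bits and tail the lower 32 bits.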
func (q *commitQueue) pack(head, tail uint32) uint64 {
	const mask = 1<<dequeueBits - 1
	return (uint64(head) << dequeueBits) | uint64(tail&mask)
}

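// examplePackUnpack is an illustrative sketch (not part of the original
// file) showing the headTail encoding round-trip.
func examplePackUnpack() {
	var q commitQueue
	ptrs := q.pack(3, 1)         // 3<<32 | 1 == 0x0000000300000001
	head, tail := q.unpack(ptrs) // head == 3, tail == 1
	_, _ = head, tail
}

// enqueue adds b to the head of the queue. Callers are serialized by the
// pipeline's mutex, and the pipeline's semaphore guarantees the queue is
// never full when enqueue is called.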
func (q *commitQueue) enqueue(b *BatchBitower) {
	ptrs := atomic.LoadUint64(&q.headTail)
	head, tail := q.unpack(ptrs)
	if (tail+uint32(len(q.slots)))&(1<<dequeueBits-1) == head {
		// The queue is full. This should never be reached because the
		// pipeline's semaphore bounds the number of concurrent commits.
		panic("bitalosdb: not reached")
	}
	slot := &q.slots[head&uint32(len(q.slots)-1)]

	// Wait for the slot to be vacated by a slow dequeue of a prior batch.
	for atomic.LoadPointer(slot) != nil {
		runtime.Gosched()
	}

	atomic.StorePointer(slot, unsafe.Pointer(b))

	// Increment head, making the slot available for dequeue.
	atomic.AddUint64(&q.headTail, 1<<dequeueBits)
}

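// dequeue removes and returns the batch at the tail of the queue, or nil if
// the queue is empty or if the tail batch has not been applied yet. Batches
// are only dequeued once applied, which preserves publication in commit
// order. dequeue is safe for concurrent use by multiple committers.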
func (q *commitQueue) dequeue() *BatchBitower {
	for {
		ptrs := atomic.LoadUint64(&q.headTail)
		head, tail := q.unpack(ptrs)
		if tail == head {
			// The queue is empty.
			return nil
		}

		slot := &q.slots[tail&uint32(len(q.slots)-1)]
		b := (*BatchBitower)(atomic.LoadPointer(slot))
		if b == nil || atomic.LoadUint32(&b.applied) == 0 {
			// The batch at the tail has not been applied yet; batches must
			// be published in commit order, so nothing can be dequeued.
			return nil
		}

		// Advance tail; on success, clear the slot for reuse by enqueue.
		ptrs2 := q.pack(head, tail+1)
		if atomic.CompareAndSwapUint64(&q.headTail, ptrs, ptrs2) {
			atomic.StorePointer(slot, nil)
			return b
		}
	}
}

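// commitEnv supplies the pipeline's dependencies, decoupling it from the
// rest of the db.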
type commitEnv struct {
	// logSeqNum points at the next sequence number to assign.
	logSeqNum *uint64
	// visibleSeqNum points at the highest sequence number visible to readers.
	visibleSeqNum *uint64
	// apply applies a batch to the given memtable.
	apply func(b *BatchBitower, mem *memTable) error
	// write writes a batch to the WAL and returns the memtable it should be
	// applied to.
	write func(b *BatchBitower, wg *sync.WaitGroup, err *error) (*memTable, error)
	// useQueue indicates whether commits flow through the pending queue.
	useQueue bool
}

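// commitPipeline manages the stages of committing a batch: assigning its
// sequence numbers, writing it to the WAL, applying it to the memtable, and
// publishing its sequence numbers to readers.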
type commitPipeline struct {
	pending commitQueue
	env     commitEnv
	sem     chan struct{}
	mu      sync.Mutex
}

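// newCommitPipeline returns a commitPipeline for the given environment. The
// semaphore holds one fewer permit than the pending queue has slots
// (record.SyncConcurrency), so the queue can never fill up.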
func newCommitPipeline(env commitEnv) *commitPipeline {
	p := &commitPipeline{
		env: env,
		sem: make(chan struct{}, record.SyncConcurrency-1),
	}
	return p
}

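// Commit runs b through the pipeline: prepare (sequence number assignment
// and WAL write), apply (memtable insertion), and publish (making the
// batch's mutations visible). On error, b is detached from the db.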
func (p *commitPipeline) Commit(b *BatchBitower, syncWAL bool) error {
	if b.Empty() {
		return nil
	}

	if p.env.useQueue {
		// Acquire a semaphore slot, bounding concurrent commits so the
		// pending queue can never overflow.
		p.sem <- struct{}{}
		defer func() {
			<-p.sem
		}()
	}

	// Prepare: assign sequence numbers and write the batch to the WAL.
	mem, err := p.prepare(b, syncWAL)
	if err != nil {
		b.db = nil
		return err
	}

	// Apply the batch to the memtable.
	if err = p.env.apply(b, mem); err != nil {
		b.db = nil
		return err
	}

	// Publish the batch's sequence numbers, in commit order.
	p.publish(b)

	if b.commitErr != nil {
		b.db = nil
	}
	return b.commitErr
}

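// prepare assigns b a contiguous range of sequence numbers, enqueues it for
// ordered publication (when the queue is in use), and writes it to the WAL.
// It returns the memtable the batch should be applied to.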
func (p *commitPipeline) prepare(b *BatchBitower, syncWAL bool) (*memTable, error) {
	n := uint64(b.Count())
	if n == invalidBatchCount {
		return nil, ErrInvalidBatch
	}

	if p.env.useQueue {
		// One count is released when the batch is published; a second is
		// released when the WAL sync completes.
		count := 1
		if syncWAL {
			count++
		}
		b.commit.Add(count)
	}

	var syncWG *sync.WaitGroup
	var syncErr *error
	if syncWAL {
		syncWG, syncErr = &b.commit, &b.commitErr
	}

	p.mu.Lock()

	if p.env.useQueue {
		// Enqueue the batch for ordered publication.
		p.pending.enqueue(b)
	}

	// Assign the batch a contiguous range of n sequence numbers.
	b.setSeqNum(atomic.AddUint64(p.env.logSeqNum, n) - n)

	mem, err := p.env.write(b, syncWG, syncErr)

	p.mu.Unlock()

	return mem, err
}

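// publish marks b as applied and makes sequence numbers visible in commit
// order. When the queue is in use, each committer publishes any applied
// batches ahead of its own; if its own batch is published by another
// committer, it waits on b.commit for that to happen.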
func (p *commitPipeline) publish(b *BatchBitower) {
	// Mark the batch as applied.
	atomic.StoreUint32(&b.applied, 1)

	if p.env.useQueue {
		// Help publish applied batches in commit order until our own batch
		// has been published (possibly by another committer).
		for {
			t := p.pending.dequeue()
			if t == nil {
				// Wait for another committer to publish our batch.
				b.commit.Wait()
				break
			}

			if atomic.LoadUint32(&t.applied) != 1 {
				b.db.opts.Logger.Errorf("panic: commitPipeline publish: dequeued batch is not applied")
			}

			// Ratchet visibleSeqNum forward to the end of t's sequence
			// number range, retrying on concurrent updates.
			for {
				curSeqNum := atomic.LoadUint64(p.env.visibleSeqNum)
				newSeqNum := t.SeqNum() + uint64(t.Count())
				if newSeqNum <= curSeqNum {
					// t's sequence numbers are already visible.
					break
				}
				if atomic.CompareAndSwapUint64(p.env.visibleSeqNum, curSeqNum, newSeqNum) {
					break
				}
			}

			// Signal t's committer that its batch has been published.
			t.commit.Done()
		}
	} else {
		// Without the queue, publish b's sequence numbers directly.
		for {
			curSeqNum := atomic.LoadUint64(p.env.visibleSeqNum)
			newSeqNum := b.SeqNum() + uint64(b.Count())
			if newSeqNum <= curSeqNum {
				break
			}
			if atomic.CompareAndSwapUint64(p.env.visibleSeqNum, curSeqNum, newSeqNum) {
				break
			}
		}
	}
}

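// ratchetSeqNum advances the log and visible sequence numbers to nextSeqNum
// when the log sequence number is behind it.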
func (p *commitPipeline) ratchetSeqNum(nextSeqNum uint64) {
	p.mu.Lock()
	defer p.mu.Unlock()

	logSeqNum := atomic.LoadUint64(p.env.logSeqNum)
	if logSeqNum >= nextSeqNum {
		return
	}
	count := nextSeqNum - logSeqNum
	atomic.AddUint64(p.env.logSeqNum, count)
	atomic.StoreUint64(p.env.visibleSeqNum, nextSeqNum)
}
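
// The sketch below (not part of the original file; the function and stub
// names are hypothetical) shows how a commitPipeline might be wired up with
// stub callbacks, e.g. for a test harness.
func exampleCommitPipelineWiring() *commitPipeline {
	var logSeqNum, visibleSeqNum uint64
	env := commitEnv{
		logSeqNum:     &logSeqNum,
		visibleSeqNum: &visibleSeqNum,
		useQueue:      false,
		apply: func(b *BatchBitower, mem *memTable) error {
			// Stub: a real implementation applies the batch to mem.
			return nil
		},
		write: func(b *BatchBitower, wg *sync.WaitGroup, err *error) (*memTable, error) {
			// Stub: a real implementation writes the batch to the WAL and
			// returns the memtable it should be applied to.
			return nil, nil
		},
	}
	return newCommitPipeline(env)
}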