github.com/matrixorigin/matrixone@v0.7.0/pkg/lockservice/waiter.go (about)

     1  // Copyright 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lockservice
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"runtime"
    21  	"sync"
    22  	"sync/atomic"
    23  )
    24  
    25  var (
    26  	waiterPool = sync.Pool{
    27  		New: func() any {
    28  			return newWaiter()
    29  		},
    30  	}
    31  )
    32  
    33  func acquireWaiter(txnID []byte) *waiter {
    34  	w := waiterPool.Get().(*waiter)
    35  	w.txnID = txnID
    36  	if w.ref() != 1 {
    37  		panic("BUG: invalid ref count")
    38  	}
    39  	w.beforeSwapStatusAdjustFunc = func() {}
    40  	return w
    41  }
    42  
    43  func newWaiter() *waiter {
    44  	w := &waiter{
    45  		c:       make(chan error, 1),
    46  		waiters: newWaiterQueue(),
    47  	}
    48  	w.setFinalizer()
    49  	w.setStatus(waiting)
    50  	return w
    51  }
    52  
    53  type waiterStatus int32
    54  
    55  const (
    56  	waiting waiterStatus = iota
    57  	notified
    58  	completed
    59  )
    60  
// waiter lets a locking operation block until a previously taken,
// conflicting lock is released.
// Each Lock holds one waiter instance, and that waiter holds every other
// waiter queued behind it. Suppose we have two transactions A and B and a
// record k1; the pseudocode for using waiter is as follows:
//  1. A gets LockStorage s1
//  2. s1.Lock()
//  3. use s1.Seek(k1) to check for conflicts, then s1.add(Lock(k1, waiter-k1-A))
//  4. s1.Unlock()
//  5. B gets LockStorage s1
//  6. s1.Lock()
//  7. use s1.Seek(k1) to check for conflicts, and find Lock(k1, waiter-k1-A)
//  8. so waiter-k1-A.add(waiter-k1-B)
//  9. s1.Unlock()
//  10. waiter-k1-B.wait()
//  11. A completes
//  12. s1.Lock()
//  13. replace Lock(k1, waiter-k1-A) with Lock(k1, waiter-k1-B)
//  14. waiter-k1-A.close(), which moves all other waiters into waiter-k1-B
//  15. s1.Unlock()
//  16. waiter-k1-B.wait() returns and B gets the lock
//
// Fields:
//   - txnID: assigned by acquireWaiter when the waiter is taken from the pool.
//   - status: the current waiterStatus, accessed atomically through
//     getStatus, setStatus and casStatus.
//   - c: channel created with capacity 1 in newWaiter; notify sends the
//     notification value on it and wait receives it.
//   - waiters: the waiters queued behind this one (see add).
//   - refCount: reference count; when unref drops it to zero the waiter is
//     reset and put back into waiterPool.
type waiter struct {
	txnID    []byte
	status   atomic.Int32
	c        chan error
	waiters  waiterQueue
	refCount atomic.Int32

	// Test-only hook. wait and notify call it right before they try to change
	// the status; acquireWaiter installs a no-op.
	beforeSwapStatusAdjustFunc func()
}
    92  
    93  func (w *waiter) setFinalizer() {
    94  	// close the channel if gc
    95  	runtime.SetFinalizer(w, func(w *waiter) {
    96  		close(w.c)
    97  	})
    98  }
    99  
   100  func (w *waiter) ref() int32 {
   101  	return w.refCount.Add(1)
   102  }
   103  
   104  func (w *waiter) unref() {
   105  	n := w.refCount.Add(-1)
   106  	if n < 0 {
   107  		panic("BUG: invalid ref count")
   108  	}
   109  	if n == 0 {
   110  		w.reset()
   111  	}
   112  }
   113  
   114  func (w *waiter) add(waiter ...*waiter) {
   115  	if len(waiter) == 0 {
   116  		return
   117  	}
   118  	w.waiters.put(waiter...)
   119  	for i := range waiter {
   120  		waiter[i].ref()
   121  	}
   122  }
   123  
   124  func (w *waiter) getStatus() waiterStatus {
   125  	return waiterStatus(w.status.Load())
   126  }
   127  
   128  func (w *waiter) setStatus(status waiterStatus) {
   129  	w.status.Store(int32(status))
   130  }
   131  
   132  func (w *waiter) casStatus(old, new waiterStatus) bool {
   133  	return w.status.CompareAndSwap(int32(old), int32(new))
   134  }
   135  
   136  func (w *waiter) mustRecvNotification() error {
   137  	select {
   138  	case err := <-w.c:
   139  		return err
   140  	default:
   141  	}
   142  	panic("BUG: must recv result from channel")
   143  }
   144  
   145  func (w *waiter) mustSendNotification(value error) {
   146  	select {
   147  	case w.c <- value:
   148  		return
   149  	default:
   150  	}
   151  	panic("BUG: must send value to channel")
   152  }
   153  
   154  func (w *waiter) resetWait() {
   155  	if w.casStatus(completed, waiting) {
   156  		return
   157  	}
   158  	panic("invalid reset wait")
   159  }
   160  
   161  func (w *waiter) wait(ctx context.Context) error {
   162  	status := w.getStatus()
   163  	if status == notified {
   164  		w.setStatus(completed)
   165  		return w.mustRecvNotification()
   166  	}
   167  	if status != waiting {
   168  		panic(fmt.Sprintf("BUG: waiter's status cannot be %d", status))
   169  	}
   170  
   171  	w.beforeSwapStatusAdjustFunc()
   172  
   173  	select {
   174  	case err := <-w.c:
   175  		w.setStatus(completed)
   176  		return err
   177  	case <-ctx.Done():
   178  	}
   179  
   180  	w.beforeSwapStatusAdjustFunc()
   181  
   182  	// context is timeout, and status not changed, no concurrent happen
   183  	if w.casStatus(status, completed) {
   184  		return ctx.Err()
   185  	}
   186  
   187  	// notify and timeout are concurrently issued, we use real result to replace
   188  	// timeout error
   189  	w.setStatus(completed)
   190  	return w.mustRecvNotification()
   191  }
   192  
   193  // notify return false means this waiter is completed, cannot be used to notify
   194  func (w *waiter) notify(value error) bool {
   195  	for {
   196  		status := w.getStatus()
   197  		if status == notified {
   198  			panic("already notified")
   199  		}
   200  		if status == completed {
   201  			// wait already completed, wait timeout or wait a result.
   202  			return false
   203  		}
   204  
   205  		w.beforeSwapStatusAdjustFunc()
   206  		// if status changed, notify and timeout are concurrently issued, need
   207  		// retry.
   208  		if w.casStatus(status, notified) {
   209  			w.mustSendNotification(value)
   210  			return true
   211  		}
   212  	}
   213  }
   214  
   215  // close returns the next waiter to hold the lock, and others waiters will move
   216  // into the next waiter.
   217  func (w *waiter) close() *waiter {
   218  	nextWaiter := w.fetchNextWaiter()
   219  	w.unref()
   220  	return nextWaiter
   221  }
   222  
   223  func (w *waiter) fetchNextWaiter() *waiter {
   224  	if w.waiters.len() == 0 {
   225  		return nil
   226  	}
   227  	next := w.awakeNextWaiter()
   228  	for {
   229  		if next.notify(nil) {
   230  			next.unref()
   231  			return next
   232  		}
   233  		if next.waiters.len() == 0 {
   234  			return nil
   235  		}
   236  		next = next.awakeNextWaiter()
   237  	}
   238  }
   239  
   240  func (w *waiter) awakeNextWaiter() *waiter {
   241  	next, remains := w.waiters.pop()
   242  	next.add(remains...)
   243  	w.waiters.reset()
   244  	return next
   245  }
   246  
   247  func (w *waiter) reset() {
   248  	if w.waiters.len() > 0 || len(w.c) > 0 {
   249  		panic("BUG: waiter should be empty.")
   250  	}
   251  	w.setStatus(waiting)
   252  	w.waiters.reset()
   253  	waiterPool.Put(w)
   254  }