github.com/matrixorigin/matrixone@v0.7.0/pkg/lockservice/deadlock.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lockservice
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"sync"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    23  	"github.com/matrixorigin/matrixone/pkg/common/stopper"
    24  )
    25  
var (
	// ErrDeadlockDetectorClosed is returned by detector.check when the
	// deadlock detector has already been closed.
	ErrDeadlockDetectorClosed = moerr.NewInvalidStateNoCtx("deadlock detector is closed")
)
    30  
// detector is an asynchronous deadlock detector. Txn ids handed to check are
// queued on c and consumed by the doCheck task, which uses waitTxnsFetchFunc
// to walk the txns involved and waitTxnAbortFunc to report a txn to abort.
type detector struct {
	// c queues the ids of the txns that still have to be checked; check is
	// the sender and doCheck the receiver.
	c                 chan []byte
	// waitTxnsFetchFunc is called with the txn currently being examined and
	// the waiters collecting the txns to examine next. A false result makes
	// checkDeadlock report a deadlock.
	waitTxnsFetchFunc func([]byte, *waiters) bool
	// waitTxnAbortFunc is called by doCheck with the id of the checked txn
	// when a deadlock has been found for it.
	waitTxnAbortFunc  func([]byte)
	// ignoreTxns holds the txns for which a deadlock was already reported.
	// Entries are added by doCheck, skipped by waiters.add and removed by
	// txnClosed.
	ignoreTxns        sync.Map // txnID -> any
	// stopper runs the doCheck task and is stopped by close.
	stopper           *stopper.Stopper
	// mu guards closed, which is set by close and read by check.
	mu                struct {
		sync.RWMutex
		closed bool
	}
}
    42  
    43  // newDeadlockDetector create a deadlock detector, waitTxnsFetchFun is used to get the waiting txns
    44  // for the given txn. Then the detector will recursively check all txns's waiting txns until deadlock
    45  // is found. When a deadlock is found, waitTxnAbortFunc is used to notify the external abort to drop a
    46  // txn.
    47  func newDeadlockDetector(waitTxnsFetchFunc func([]byte, *waiters) bool,
    48  	waitTxnAbortFunc func([]byte)) *detector {
    49  	d := &detector{
    50  		c:                 make(chan []byte, 1024),
    51  		waitTxnsFetchFunc: waitTxnsFetchFunc,
    52  		waitTxnAbortFunc:  waitTxnAbortFunc,
    53  		stopper:           stopper.NewStopper("deadlock-detector"),
    54  	}
    55  	err := d.stopper.RunTask(d.doCheck)
    56  	if err != nil {
    57  		panic("impossible")
    58  	}
    59  	return d
    60  }
    61  
    62  func (d *detector) close() {
    63  	d.mu.Lock()
    64  	d.mu.closed = true
    65  	d.mu.Unlock()
    66  	d.stopper.Stop()
    67  	close(d.c)
    68  }
    69  
    70  func (d *detector) txnClosed(txnID []byte) {
    71  	v := unsafeByteSliceToString(txnID)
    72  	d.ignoreTxns.Delete(v)
    73  }
    74  
    75  func (d *detector) check(txnID []byte) error {
    76  	d.mu.RLock()
    77  	defer d.mu.RUnlock()
    78  	if d.mu.closed {
    79  		return ErrDeadlockDetectorClosed
    80  	}
    81  
    82  	d.c <- txnID
    83  	return nil
    84  }
    85  
    86  func (d *detector) doCheck(ctx context.Context) {
    87  	w := &waiters{ignoreTxns: &d.ignoreTxns}
    88  	for {
    89  		select {
    90  		case <-ctx.Done():
    91  			return
    92  		case txnID := <-d.c:
    93  			w.reset(txnID)
    94  			v := string(txnID)
    95  			if !d.checkDeadlock(w) {
    96  				d.ignoreTxns.Store(v, struct{}{})
    97  				d.waitTxnAbortFunc(txnID)
    98  			}
    99  		}
   100  	}
   101  }
   102  
   103  func (d *detector) checkDeadlock(w *waiters) bool {
   104  	for {
   105  		if w.completed() {
   106  			return true
   107  		}
   108  
   109  		// find deadlock
   110  		txnID := w.getCheckTargetTxn()
   111  		if !d.waitTxnsFetchFunc(txnID, w) {
   112  			return false
   113  		}
   114  		w.next()
   115  	}
   116  }
   117  
// waiters is the working state of a single deadlock check: the list of txns
// to examine and a cursor into that list.
type waiters struct {
	// ignoreTxns points at the set of txns that add must not queue; doCheck
	// sets it to the detector's ignoreTxns.
	ignoreTxns *sync.Map
	// waitTxns is the list of txns to examine. Element 0 is the txn the
	// check was started for (see reset); add appends further txns.
	waitTxns   [][]byte
	// pos is the index in waitTxns of the txn to examine next.
	pos        int
}
   123  
// getCheckTargetTxn returns the txn at the current position, i.e. the next
// txn to examine. It must not be called once completed reports true, because
// the index would then be out of range.
func (w *waiters) getCheckTargetTxn() []byte {
	return w.waitTxns[w.pos]
}
   127  
// next advances the cursor to the following txn in waitTxns.
func (w *waiters) next() {
	w.pos++
}
   131  
   132  func (w *waiters) add(txnID []byte) bool {
   133  	if bytes.Equal(w.waitTxns[0], txnID) {
   134  		return false
   135  	}
   136  	v := unsafeByteSliceToString(txnID)
   137  	if _, ok := w.ignoreTxns.Load(v); ok {
   138  		return true
   139  	}
   140  	w.waitTxns = append(w.waitTxns, txnID)
   141  	return true
   142  }
   143  
   144  func (w *waiters) reset(txnID []byte) {
   145  	w.pos = 0
   146  	w.waitTxns = w.waitTxns[:0]
   147  	w.waitTxns = append(w.waitTxns, txnID)
   148  }
   149  
   150  func (w *waiters) completed() bool {
   151  	return w.pos == len(w.waitTxns)
   152  }