decred.org/dcrdex@v1.0.3/client/asset/btc/redemption_finder.go (about)

     1  package btc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"sync"
    10  	"time"
    11  
    12  	"decred.org/dcrdex/dex"
    13  	dexbtc "decred.org/dcrdex/dex/networks/btc"
    14  	"github.com/btcsuite/btcd/chaincfg"
    15  	"github.com/btcsuite/btcd/chaincfg/chainhash"
    16  	"github.com/btcsuite/btcd/wire"
    17  )
    18  
// FindRedemptionReq represents a request to find a contract's redemption,
// which is submitted to the RedemptionFinder.
type FindRedemptionReq struct {
	// outPt is the swap contract output whose spending input is sought.
	outPt OutPoint

	// blockHash is the block containing the contract transaction, or nil if
	// the contract transaction was not known to be mined when the request was
	// created.
	blockHash *chainhash.Hash

	// blockHeight is the height of blockHash; zero when blockHash is nil.
	blockHeight int32

	// resultChan receives the outcome of the search. FindRedemption creates it
	// with a buffer of 1, and sendResult writes to it without blocking.
	resultChan chan *FindRedemptionResult

	// pkScript is the contract output's pubkey script.
	pkScript []byte

	// contractHash is the script hash extracted from pkScript
	// (dexbtc.ExtractScriptHash), used to locate the secret in a redeeming
	// input.
	contractHash []byte
}
    29  
    30  func (req *FindRedemptionReq) fail(s string, a ...any) {
    31  	req.sendResult(&FindRedemptionResult{err: fmt.Errorf(s, a...)})
    32  }
    33  
    34  func (req *FindRedemptionReq) success(res *FindRedemptionResult) {
    35  	req.sendResult(res)
    36  }
    37  
    38  func (req *FindRedemptionReq) sendResult(res *FindRedemptionResult) {
    39  	select {
    40  	case req.resultChan <- res:
    41  	default:
    42  		// In-case two separate threads find a result.
    43  	}
    44  }
    45  
// FindRedemptionResult models the result of a find redemption attempt.
type FindRedemptionResult struct {
	// redemptionCoinID is the coin ID (built with ToCoinID from the spending
	// tx hash and input index) of the input that spends the contract output.
	redemptionCoinID dex.Bytes

	// secret is the key push extracted from the redeeming input.
	secret dex.Bytes

	// err is set when the search failed. Results created by
	// FindRedemptionReq.fail carry only this field.
	err error
}
    52  
// RedemptionFinder searches on-chain, and in the mempool, for the redemptions
// of swap contract outputs.
type RedemptionFinder struct {
	// mtx guards the redemptions map.
	mtx sync.RWMutex
	log dex.Logger
	// redemptions holds the pending find-redemption requests, keyed by the
	// contract output each one tracks.
	redemptions map[OutPoint]*FindRedemptionReq

	// Wallet and chain backend functions supplied to NewRedemptionFinder.
	getWalletTransaction      func(txHash *chainhash.Hash) (*GetTransactionResult, error)
	getBlockHeight            func(*chainhash.Hash) (int32, error)
	getBlock                  func(h chainhash.Hash) (*wire.MsgBlock, error)
	getBlockHeader            func(blockHash *chainhash.Hash) (hdr *BlockHeader, mainchain bool, err error)
	hashTx                    func(*wire.MsgTx) *chainhash.Hash
	deserializeTx             func([]byte) (*wire.MsgTx, error)
	getBestBlockHeight        func() (int32, error)
	searchBlockForRedemptions func(ctx context.Context, reqs map[OutPoint]*FindRedemptionReq, blockHash chainhash.Hash) (discovered map[OutPoint]*FindRedemptionResult)
	getBlockHash              func(blockHeight int64) (*chainhash.Hash, error)
	findRedemptionsInMempool  func(ctx context.Context, reqs map[OutPoint]*FindRedemptionReq) (discovered map[OutPoint]*FindRedemptionResult)
}
    70  
    71  func NewRedemptionFinder(
    72  	log dex.Logger,
    73  	getWalletTransaction func(txHash *chainhash.Hash) (*GetTransactionResult, error),
    74  	getBlockHeight func(*chainhash.Hash) (int32, error),
    75  	getBlock func(h chainhash.Hash) (*wire.MsgBlock, error),
    76  	getBlockHeader func(blockHash *chainhash.Hash) (hdr *BlockHeader, mainchain bool, err error),
    77  	hashTx func(*wire.MsgTx) *chainhash.Hash,
    78  	deserializeTx func([]byte) (*wire.MsgTx, error),
    79  	getBestBlockHeight func() (int32, error),
    80  	searchBlockForRedemptions func(ctx context.Context, reqs map[OutPoint]*FindRedemptionReq, blockHash chainhash.Hash) (discovered map[OutPoint]*FindRedemptionResult),
    81  	getBlockHash func(blockHeight int64) (*chainhash.Hash, error),
    82  	findRedemptionsInMempool func(ctx context.Context, reqs map[OutPoint]*FindRedemptionReq) (discovered map[OutPoint]*FindRedemptionResult),
    83  ) *RedemptionFinder {
    84  	return &RedemptionFinder{
    85  		log:                       log,
    86  		getWalletTransaction:      getWalletTransaction,
    87  		getBlockHeight:            getBlockHeight,
    88  		getBlock:                  getBlock,
    89  		getBlockHeader:            getBlockHeader,
    90  		hashTx:                    hashTx,
    91  		deserializeTx:             deserializeTx,
    92  		getBestBlockHeight:        getBestBlockHeight,
    93  		searchBlockForRedemptions: searchBlockForRedemptions,
    94  		getBlockHash:              getBlockHash,
    95  		findRedemptionsInMempool:  findRedemptionsInMempool,
    96  		redemptions:               make(map[OutPoint]*FindRedemptionReq),
    97  	}
    98  }
    99  
   100  func (r *RedemptionFinder) FindRedemption(ctx context.Context, coinID dex.Bytes) (redemptionCoin, secret dex.Bytes, err error) {
   101  	txHash, vout, err := decodeCoinID(coinID)
   102  	if err != nil {
   103  		return nil, nil, fmt.Errorf("cannot decode contract coin id: %w", err)
   104  	}
   105  
   106  	outPt := NewOutPoint(txHash, vout)
   107  
   108  	tx, err := r.getWalletTransaction(txHash)
   109  	if err != nil {
   110  		return nil, nil, fmt.Errorf("error finding wallet transaction: %v", err)
   111  	}
   112  
   113  	txOut, err := TxOutFromTxBytes(tx.Bytes, vout, r.deserializeTx, r.hashTx)
   114  	if err != nil {
   115  		return nil, nil, err
   116  	}
   117  	pkScript := txOut.PkScript
   118  
   119  	var blockHash *chainhash.Hash
   120  	if tx.BlockHash != "" {
   121  		blockHash, err = chainhash.NewHashFromStr(tx.BlockHash)
   122  		if err != nil {
   123  			return nil, nil, fmt.Errorf("error decoding block hash from string %q: %w",
   124  				tx.BlockHash, err)
   125  		}
   126  	}
   127  
   128  	var blockHeight int32
   129  	if blockHash != nil {
   130  		r.log.Infof("FindRedemption - Checking block %v for swap %v", blockHash, outPt)
   131  		blockHeight, err = r.checkRedemptionBlockDetails(outPt, blockHash, pkScript)
   132  		if err != nil {
   133  			return nil, nil, fmt.Errorf("checkRedemptionBlockDetails: op %v / block %q: %w",
   134  				outPt, tx.BlockHash, err)
   135  		}
   136  	}
   137  
   138  	req := &FindRedemptionReq{
   139  		outPt:        outPt,
   140  		blockHash:    blockHash,
   141  		blockHeight:  blockHeight,
   142  		resultChan:   make(chan *FindRedemptionResult, 1),
   143  		pkScript:     pkScript,
   144  		contractHash: dexbtc.ExtractScriptHash(pkScript),
   145  	}
   146  
   147  	if err := r.queueFindRedemptionRequest(req); err != nil {
   148  		return nil, nil, fmt.Errorf("queueFindRedemptionRequest error for redemption %s: %w", outPt, err)
   149  	}
   150  
   151  	go r.tryRedemptionRequests(ctx, nil, []*FindRedemptionReq{req})
   152  
   153  	var result *FindRedemptionResult
   154  	select {
   155  	case result = <-req.resultChan:
   156  		if result == nil {
   157  			err = fmt.Errorf("unexpected nil result for redemption search for %s", outPt)
   158  		}
   159  	case <-ctx.Done():
   160  		err = fmt.Errorf("context cancelled during search for redemption for %s", outPt)
   161  	}
   162  
   163  	// If this contract is still tracked, remove from the queue to prevent
   164  	// further redemption search attempts for this contract.
   165  	r.mtx.Lock()
   166  	delete(r.redemptions, outPt)
   167  	r.mtx.Unlock()
   168  
   169  	// result would be nil if ctx is canceled or the result channel is closed
   170  	// without data, which would happen if the redemption search is aborted when
   171  	// this ExchangeWallet is shut down.
   172  	if result != nil {
   173  		return result.redemptionCoinID, result.secret, result.err
   174  	}
   175  	return nil, nil, err
   176  }
   177  
   178  func (r *RedemptionFinder) checkRedemptionBlockDetails(outPt OutPoint, blockHash *chainhash.Hash, pkScript []byte) (int32, error) {
   179  	blockHeight, err := r.getBlockHeight(blockHash)
   180  	if err != nil {
   181  		return 0, fmt.Errorf("GetBlockHeight for redemption block %s error: %w", blockHash, err)
   182  	}
   183  	blk, err := r.getBlock(*blockHash)
   184  	if err != nil {
   185  		return 0, fmt.Errorf("error retrieving redemption block %s: %w", blockHash, err)
   186  	}
   187  
   188  	var tx *wire.MsgTx
   189  out:
   190  	for _, iTx := range blk.Transactions {
   191  		if *r.hashTx(iTx) == outPt.TxHash {
   192  			tx = iTx
   193  			break out
   194  		}
   195  	}
   196  	if tx == nil {
   197  		return 0, fmt.Errorf("transaction %s not found in block %s", outPt.TxHash, blockHash)
   198  	}
   199  	if uint32(len(tx.TxOut)) < outPt.Vout+1 {
   200  		return 0, fmt.Errorf("no output %d in redemption transaction %s found in block %s", outPt.Vout, outPt.TxHash, blockHash)
   201  	}
   202  	if !bytes.Equal(tx.TxOut[outPt.Vout].PkScript, pkScript) {
   203  		return 0, fmt.Errorf("pubkey script mismatch for redemption at %s", outPt)
   204  	}
   205  
   206  	return blockHeight, nil
   207  }
   208  
   209  func (r *RedemptionFinder) queueFindRedemptionRequest(req *FindRedemptionReq) error {
   210  	r.mtx.Lock()
   211  	defer r.mtx.Unlock()
   212  	if _, exists := r.redemptions[req.outPt]; exists {
   213  		return fmt.Errorf("duplicate find redemption request for %s", req.outPt)
   214  	}
   215  	r.redemptions[req.outPt] = req
   216  	return nil
   217  }
   218  
   219  // tryRedemptionRequests searches all mainchain blocks with height >= startBlock
   220  // for redemptions.
   221  func (r *RedemptionFinder) tryRedemptionRequests(ctx context.Context, startBlock *chainhash.Hash, reqs []*FindRedemptionReq) {
   222  	undiscovered := make(map[OutPoint]*FindRedemptionReq, len(reqs))
   223  	mempoolReqs := make(map[OutPoint]*FindRedemptionReq)
   224  	for _, req := range reqs {
   225  		// If there is no block hash yet, this request hasn't been mined, and a
   226  		// spending tx cannot have been mined. Only check mempool.
   227  		if req.blockHash == nil {
   228  			mempoolReqs[req.outPt] = req
   229  			continue
   230  		}
   231  		undiscovered[req.outPt] = req
   232  	}
   233  
   234  	epicFail := func(s string, a ...any) {
   235  		for _, req := range reqs {
   236  			req.fail(s, a...)
   237  		}
   238  	}
   239  
   240  	// Only search up to the current tip. This does leave two unhandled
   241  	// scenarios worth mentioning.
   242  	//  1) A new block is mined during our search. In this case, we won't
   243  	//     see the new block, but tryRedemptionRequests should be called again
   244  	//     by the block monitoring loop.
   245  	//  2) A reorg happens, and this tip becomes orphaned. In this case, the
   246  	//     worst that can happen is that a shorter chain will replace a longer
   247  	//     one (extremely rare). Even in that case, we'll just log the error and
   248  	//     exit the block loop.
   249  	tipHeight, err := r.getBestBlockHeight()
   250  	if err != nil {
   251  		epicFail("tryRedemptionRequests getBestBlockHeight error: %v", err)
   252  		return
   253  	}
   254  
   255  	// If a startBlock is provided at a higher height, use that as the starting
   256  	// point.
   257  	var iHash *chainhash.Hash
   258  	var iHeight int32
   259  	if startBlock != nil {
   260  		h, err := r.getBlockHeight(startBlock)
   261  		if err != nil {
   262  			epicFail("tryRedemptionRequests startBlock getBlockHeight error: %v", err)
   263  			return
   264  		}
   265  		iHeight = h
   266  		iHash = startBlock
   267  	} else {
   268  		iHeight = math.MaxInt32
   269  		for _, req := range undiscovered {
   270  			if req.blockHash != nil && req.blockHeight < iHeight {
   271  				iHeight = req.blockHeight
   272  				iHash = req.blockHash
   273  			}
   274  		}
   275  	}
   276  
   277  	// Helper function to check that the request hasn't been located in another
   278  	// thread and removed from queue already.
   279  	reqStillQueued := func(outPt OutPoint) bool {
   280  		_, found := r.redemptions[outPt]
   281  		return found
   282  	}
   283  
   284  	for iHeight <= tipHeight {
   285  		validReqs := make(map[OutPoint]*FindRedemptionReq, len(undiscovered))
   286  		r.mtx.RLock()
   287  		for outPt, req := range undiscovered {
   288  			if iHeight >= req.blockHeight && reqStillQueued(req.outPt) {
   289  				validReqs[outPt] = req
   290  			}
   291  		}
   292  		r.mtx.RUnlock()
   293  
   294  		if len(validReqs) == 0 {
   295  			iHeight++
   296  			continue
   297  		}
   298  
   299  		r.log.Debugf("tryRedemptionRequests - Checking block %v for redemptions...", iHash)
   300  		discovered := r.searchBlockForRedemptions(ctx, validReqs, *iHash)
   301  		for outPt, res := range discovered {
   302  			req, found := undiscovered[outPt]
   303  			if !found {
   304  				r.log.Critical("Request not found in undiscovered map. This shouldn't be possible.")
   305  				continue
   306  			}
   307  			redeemTxID, redeemTxInput, _ := decodeCoinID(res.redemptionCoinID)
   308  			r.log.Debugf("Found redemption %s:%d", redeemTxID, redeemTxInput)
   309  			req.success(res)
   310  			delete(undiscovered, outPt)
   311  		}
   312  
   313  		if len(undiscovered) == 0 {
   314  			break
   315  		}
   316  
   317  		iHeight++
   318  		if iHeight <= tipHeight {
   319  			if iHash, err = r.getBlockHash(int64(iHeight)); err != nil {
   320  				// This might be due to a reorg. Don't abandon yet, since
   321  				// tryRedemptionRequests will be tried again by the block
   322  				// monitor loop.
   323  				r.log.Warn("error getting block hash for height %d: %v", iHeight, err)
   324  				return
   325  			}
   326  		}
   327  	}
   328  
   329  	// Check mempool for any remaining undiscovered requests.
   330  	for outPt, req := range undiscovered {
   331  		mempoolReqs[outPt] = req
   332  	}
   333  
   334  	if len(mempoolReqs) == 0 {
   335  		return
   336  	}
   337  
   338  	// Do we really want to do this? Mempool could be huge.
   339  	searchDur := time.Minute * 5
   340  	searchCtx, cancel := context.WithTimeout(ctx, searchDur)
   341  	defer cancel()
   342  	for outPt, res := range r.findRedemptionsInMempool(searchCtx, mempoolReqs) {
   343  		req, ok := mempoolReqs[outPt]
   344  		if !ok {
   345  			r.log.Errorf("findRedemptionsInMempool discovered outpoint not found")
   346  			continue
   347  		}
   348  		req.success(res)
   349  	}
   350  	if err := searchCtx.Err(); err != nil {
   351  		if errors.Is(err, context.DeadlineExceeded) {
   352  			r.log.Errorf("mempool search exceeded %s time limit", searchDur)
   353  		} else {
   354  			r.log.Error("mempool search was cancelled")
   355  		}
   356  	}
   357  }
   358  
   359  // prepareRedemptionRequestsForBlockCheck prepares a copy of the currently
   360  // tracked redemptions, checking for missing block data along the way.
   361  func (r *RedemptionFinder) prepareRedemptionRequestsForBlockCheck() []*FindRedemptionReq {
   362  	// Search for contract redemption in new blocks if there
   363  	// are contracts pending redemption.
   364  	r.mtx.Lock()
   365  	defer r.mtx.Unlock()
   366  	reqs := make([]*FindRedemptionReq, 0, len(r.redemptions))
   367  	for _, req := range r.redemptions {
   368  		// If the request doesn't have a block hash yet, check if we can get one
   369  		// now.
   370  		if req.blockHash == nil {
   371  			r.trySetRedemptionRequestBlock(req)
   372  		}
   373  		reqs = append(reqs, req)
   374  	}
   375  	return reqs
   376  }
   377  
   378  // ReportNewTip sets the currentTip. The tipChange callback function is invoked
   379  // and a goroutine is started to check if any contracts in the
   380  // findRedemptionQueue are redeemed in the new blocks.
   381  func (r *RedemptionFinder) ReportNewTip(ctx context.Context, prevTip, newTip *BlockVector) {
   382  	reqs := r.prepareRedemptionRequestsForBlockCheck()
   383  	// Redemption search would be compromised if the starting point cannot
   384  	// be determined, as searching just the new tip might result in blocks
   385  	// being omitted from the search operation. If that happens, cancel all
   386  	// find redemption requests in queue.
   387  	notifyFatalFindRedemptionError := func(s string, a ...any) {
   388  		for _, req := range reqs {
   389  			req.fail("tipChange handler - "+s, a...)
   390  		}
   391  	}
   392  
   393  	var startPoint *BlockVector
   394  	// Check if the previous tip is still part of the mainchain (prevTip confs >= 0).
   395  	// Redemption search would typically resume from prevTipHeight + 1 unless the
   396  	// previous tip was re-orged out of the mainchain, in which case redemption
   397  	// search will resume from the mainchain ancestor of the previous tip.
   398  	prevTipHeader, isMainchain, err := r.getBlockHeader(&prevTip.Hash)
   399  	switch {
   400  	case err != nil:
   401  		// Redemption search cannot continue reliably without knowing if there
   402  		// was a reorg, cancel all find redemption requests in queue.
   403  		notifyFatalFindRedemptionError("getBlockHeader error for prev tip hash %s: %w",
   404  			prevTip.Hash, err)
   405  		return
   406  
   407  	case !isMainchain:
   408  		// The previous tip is no longer part of the mainchain. Crawl blocks
   409  		// backwards until finding a mainchain block. Start with the block
   410  		// that is the immediate ancestor to the previous tip.
   411  		ancestorBlockHash, err := chainhash.NewHashFromStr(prevTipHeader.PreviousBlockHash)
   412  		if err != nil {
   413  			notifyFatalFindRedemptionError("hash decode error for block %s: %w", prevTipHeader.PreviousBlockHash, err)
   414  			return
   415  		}
   416  		for {
   417  			aBlock, isMainchain, err := r.getBlockHeader(ancestorBlockHash)
   418  			if err != nil {
   419  				notifyFatalFindRedemptionError("getBlockHeader error for block %s: %w", ancestorBlockHash, err)
   420  				return
   421  			}
   422  			if isMainchain {
   423  				// Found the mainchain ancestor of previous tip.
   424  				startPoint = &BlockVector{Height: aBlock.Height, Hash: *ancestorBlockHash}
   425  				r.log.Debugf("reorg detected from height %d to %d", aBlock.Height, newTip.Height)
   426  				break
   427  			}
   428  			if aBlock.Height == 0 {
   429  				// Crawled back to genesis block without finding a mainchain ancestor
   430  				// for the previous tip. Should never happen!
   431  				notifyFatalFindRedemptionError("no mainchain ancestor for orphaned block %s", prevTipHeader.Hash)
   432  				return
   433  			}
   434  			ancestorBlockHash, err = chainhash.NewHashFromStr(aBlock.PreviousBlockHash)
   435  			if err != nil {
   436  				notifyFatalFindRedemptionError("hash decode error for block %s: %w", prevTipHeader.PreviousBlockHash, err)
   437  				return
   438  			}
   439  		}
   440  
   441  	case newTip.Height-prevTipHeader.Height > 1:
   442  		// 2 or more blocks mined since last tip, start at prevTip height + 1.
   443  		afterPrivTip := prevTipHeader.Height + 1
   444  		hashAfterPrevTip, err := r.getBlockHash(afterPrivTip)
   445  		if err != nil {
   446  			notifyFatalFindRedemptionError("getBlockHash error for height %d: %w", afterPrivTip, err)
   447  			return
   448  		}
   449  		startPoint = &BlockVector{Hash: *hashAfterPrevTip, Height: afterPrivTip}
   450  
   451  	default:
   452  		// Just 1 new block since last tip report, search the lone block.
   453  		startPoint = newTip
   454  	}
   455  
   456  	if len(reqs) > 0 {
   457  		go r.tryRedemptionRequests(ctx, &startPoint.Hash, reqs)
   458  	}
   459  }
   460  
   461  // trySetRedemptionRequestBlock should be called with findRedemptionMtx Lock'ed.
   462  func (r *RedemptionFinder) trySetRedemptionRequestBlock(req *FindRedemptionReq) {
   463  	tx, err := r.getWalletTransaction(&req.outPt.TxHash)
   464  	if err != nil {
   465  		r.log.Errorf("getWalletTransaction error for FindRedemption transaction: %v", err)
   466  		return
   467  	}
   468  
   469  	if tx.BlockHash == "" {
   470  		return
   471  	}
   472  	blockHash, err := chainhash.NewHashFromStr(tx.BlockHash)
   473  	if err != nil {
   474  		r.log.Errorf("error decoding block hash %q: %v", tx.BlockHash, err)
   475  		return
   476  	}
   477  
   478  	blockHeight, err := r.checkRedemptionBlockDetails(req.outPt, blockHash, req.pkScript)
   479  	if err != nil {
   480  		r.log.Error(err)
   481  		return
   482  	}
   483  	// Don't update the FindRedemptionReq, since the findRedemptionMtx only
   484  	// protects the map.
   485  	req = &FindRedemptionReq{
   486  		outPt:        req.outPt,
   487  		blockHash:    blockHash,
   488  		blockHeight:  blockHeight,
   489  		resultChan:   req.resultChan,
   490  		pkScript:     req.pkScript,
   491  		contractHash: req.contractHash,
   492  	}
   493  	r.redemptions[req.outPt] = req
   494  }
   495  
   496  func (r *RedemptionFinder) CancelRedemptionSearches() {
   497  	// Close all open channels for contract redemption searches
   498  	// to prevent leakages and ensure goroutines that are started
   499  	// to wait on these channels end gracefully.
   500  	r.mtx.Lock()
   501  	for contractOutpoint, req := range r.redemptions {
   502  		req.fail("shutting down")
   503  		delete(r.redemptions, contractOutpoint)
   504  	}
   505  	r.mtx.Unlock()
   506  }
   507  
   508  func findRedemptionsInTxWithHasher(ctx context.Context, segwit bool, reqs map[OutPoint]*FindRedemptionReq, msgTx *wire.MsgTx,
   509  	chainParams *chaincfg.Params, hashTx func(*wire.MsgTx) *chainhash.Hash) (discovered map[OutPoint]*FindRedemptionResult) {
   510  
   511  	discovered = make(map[OutPoint]*FindRedemptionResult, len(reqs))
   512  
   513  	for vin, txIn := range msgTx.TxIn {
   514  		if ctx.Err() != nil {
   515  			return discovered
   516  		}
   517  		poHash, poVout := txIn.PreviousOutPoint.Hash, txIn.PreviousOutPoint.Index
   518  		for outPt, req := range reqs {
   519  			if discovered[outPt] != nil {
   520  				continue
   521  			}
   522  			if outPt.TxHash == poHash && outPt.Vout == poVout {
   523  				// Match!
   524  				txHash := hashTx(msgTx)
   525  				secret, err := dexbtc.FindKeyPush(txIn.Witness, txIn.SignatureScript, req.contractHash[:], segwit, chainParams)
   526  				if err != nil {
   527  					req.fail("no secret extracted from redemption input %s:%d for swap output %s: %v",
   528  						txHash, vin, outPt, err)
   529  					continue
   530  				}
   531  				discovered[outPt] = &FindRedemptionResult{
   532  					redemptionCoinID: ToCoinID(txHash, uint32(vin)),
   533  					secret:           secret,
   534  				}
   535  			}
   536  		}
   537  	}
   538  	return
   539  }