decred.org/dcrdex@v1.0.5/client/asset/dcr/native_wallet.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  package dcr
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"sort"
    14  	"sync"
    15  	"time"
    16  
    17  	"decred.org/dcrdex/client/asset"
    18  	"decred.org/dcrdex/dex"
    19  	walletjson "decred.org/dcrwallet/v5/rpc/jsonrpc/types"
    20  	"decred.org/dcrwallet/v5/wallet"
    21  	"github.com/decred/dcrd/chaincfg/chainhash"
    22  	"github.com/decred/dcrd/dcrutil/v4"
    23  )
    24  
    25  const (
    26  	csppConfigFileName = "cspp_config.json"
    27  )
    28  
    29  var nativeAccounts = []string{defaultAccountName, mixedAccountName, tradingAccountName}
    30  
    31  // mixingConfigFile is the structure for saving cspp server configuration to
    32  // file.
    33  type mixingConfigFile struct {
    34  	LegacyOn string `json:"csppserver"`
    35  	On       bool   `json:"on"`
    36  }
    37  
    38  // mixer is the settings and concurrency primitives for mixing operations.
    39  type mixer struct {
    40  	mtx    sync.RWMutex
    41  	ctx    context.Context
    42  	cancel func()
    43  	wg     sync.WaitGroup
    44  }
    45  
    46  // turnOn should be called with the mtx locked.
    47  func (m *mixer) turnOn(ctx context.Context) {
    48  	if m.cancel != nil {
    49  		m.cancel()
    50  	}
    51  	m.ctx, m.cancel = context.WithCancel(ctx)
    52  }
    53  
    54  // closeAndClear should be called with the mtx locked.
    55  func (m *mixer) closeAndClear() {
    56  	if m.cancel != nil {
    57  		m.cancel()
    58  	}
    59  	m.ctx, m.cancel = nil, nil
    60  }
    61  
    62  // NativeWallet implements optional interfaces that are only provided by the
    63  // built-in SPV wallet.
    64  type NativeWallet struct {
    65  	*ExchangeWallet
    66  	csppConfigFilePath string
    67  	spvw               *spvWallet
    68  
    69  	mixer mixer
    70  }
    71  
    72  // NativeWallet must also satisfy the following interface(s).
    73  var _ asset.FundsMixer = (*NativeWallet)(nil)
    74  var _ asset.Rescanner = (*NativeWallet)(nil)
    75  
    76  func initNativeWallet(ew *ExchangeWallet) (*NativeWallet, error) {
    77  	spvWallet, ok := ew.wallet.(*spvWallet)
    78  	if !ok {
    79  		return nil, fmt.Errorf("spvwallet is required to init NativeWallet")
    80  	}
    81  
    82  	csppConfigFilePath := filepath.Join(spvWallet.dir, csppConfigFileName)
    83  	cfgFileB, err := os.ReadFile(csppConfigFilePath)
    84  	if err != nil && !errors.Is(err, os.ErrNotExist) {
    85  		return nil, fmt.Errorf("unable to read cspp config file: %v", err)
    86  	}
    87  
    88  	if len(cfgFileB) > 0 {
    89  		var cfg mixingConfigFile
    90  		err = json.Unmarshal(cfgFileB, &cfg)
    91  		if err != nil {
    92  			return nil, fmt.Errorf("unable to unmarshal csppConfig: %v", err)
    93  		}
    94  		ew.mixing.Store(cfg.On || cfg.LegacyOn != "")
    95  	}
    96  
    97  	spvWallet.setAccounts(ew.mixing.Load())
    98  
    99  	w := &NativeWallet{
   100  		ExchangeWallet:     ew,
   101  		spvw:               spvWallet,
   102  		csppConfigFilePath: csppConfigFilePath,
   103  	}
   104  	ew.cycleMixer = func() {
   105  		w.mixer.mtx.RLock()
   106  		defer w.mixer.mtx.RUnlock()
   107  		w.mixFunds()
   108  	}
   109  
   110  	return w, nil
   111  }
   112  
   113  func (w *NativeWallet) Connect(ctx context.Context) (*sync.WaitGroup, error) {
   114  	wg, err := w.ExchangeWallet.Connect(ctx)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return wg, err
   119  }
   120  
   121  // ConfigureFundsMixer configures the wallet for funds mixing. The wallet must
   122  // be unlocked before calling. Part of the asset.FundsMixer interface.
   123  func (w *NativeWallet) ConfigureFundsMixer(enabled bool) (err error) {
   124  	csppCfgBytes, err := json.Marshal(&mixingConfigFile{
   125  		On: enabled,
   126  	})
   127  	if err != nil {
   128  		return fmt.Errorf("error marshaling cspp config file: %w", err)
   129  	}
   130  	if err := os.WriteFile(w.csppConfigFilePath, csppCfgBytes, 0644); err != nil {
   131  		return fmt.Errorf("error writing cspp config file: %w", err)
   132  	}
   133  
   134  	if !enabled {
   135  		return w.stopFundsMixer()
   136  	}
   137  	w.startFundsMixer()
   138  	w.emitBalance()
   139  	return nil
   140  }
   141  
   142  // FundsMixingStats returns the current state of the wallet's funds mixer. Part
   143  // of the asset.FundsMixer interface.
   144  func (w *NativeWallet) FundsMixingStats() (*asset.FundsMixingStats, error) {
   145  	mixedFunds, err := w.spvw.AccountBalance(w.ctx, 0, mixedAccountName)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	tradingFunds, err := w.spvw.AccountBalance(w.ctx, 0, tradingAccountName)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	return &asset.FundsMixingStats{
   154  		Enabled:                 w.mixing.Load(),
   155  		UnmixedBalanceThreshold: smalletCSPPSplitPoint,
   156  		MixedFunds:              toAtoms(mixedFunds.Total),
   157  		TradingFunds:            toAtoms(tradingFunds.Total),
   158  	}, nil
   159  }
   160  
   161  // startFundsMixer starts the funds mixer.  This will error if the wallet does
   162  // not allow starting or stopping the mixer or if the mixer was already
   163  // started. Part of the asset.FundsMixer interface.
   164  func (w *NativeWallet) startFundsMixer() {
   165  	w.mixer.mtx.Lock()
   166  	defer w.mixer.mtx.Unlock()
   167  	w.mixer.turnOn(w.ctx)
   168  	w.spvw.setAccounts(true)
   169  	w.mixing.Store(true)
   170  	w.mixFunds()
   171  }
   172  
   173  func (w *NativeWallet) stopFundsMixer() error {
   174  	w.mixer.mtx.Lock()
   175  	defer w.mixer.mtx.Unlock()
   176  	w.mixer.closeAndClear()
   177  	w.mixer.wg.Wait()
   178  	if err := w.transferAccount(w.ctx, defaultAccountName, mixedAccountName, tradingAccountName); err != nil {
   179  		return fmt.Errorf("error transferring funds while disabling mixing: %w", err)
   180  	}
   181  	w.spvw.setAccounts(false)
   182  	w.mixing.Store(false)
   183  	return nil
   184  }
   185  
   186  // Lock locks all the native wallet accounts.
   187  func (w *NativeWallet) Lock() (err error) {
   188  	if w.mixing.Load() {
   189  		return fmt.Errorf("cannot lock wallet while mixing")
   190  	}
   191  	w.mixer.mtx.Lock()
   192  	w.mixer.closeAndClear()
   193  	w.mixer.mtx.Unlock()
   194  	w.mixer.wg.Wait()
   195  	for _, acct := range nativeAccounts {
   196  		if err = w.wallet.LockAccount(w.ctx, acct); err != nil {
   197  			return fmt.Errorf("error locking native wallet account %q: %w", acct, err)
   198  		}
   199  	}
   200  	return nil
   201  }
   202  
   203  // mixFunds checks the status of mixing operations and starts a mix cycle.
   204  // mixFunds must be called with the mixer.mtx >= RLock'd.
   205  func (w *NativeWallet) mixFunds() {
   206  	ss, _ := w.SyncStatus()
   207  	if !ss.Synced {
   208  		return
   209  	}
   210  	on := w.mixer.ctx != nil
   211  	if !on || !w.mixing.Load() {
   212  		return
   213  	}
   214  	ctx := w.mixer.ctx
   215  	if w.network == dex.Simnet {
   216  		w.mixer.wg.Add(1)
   217  		go func() {
   218  			defer w.mixer.wg.Done()
   219  			w.runSimnetMixer(ctx)
   220  		}()
   221  		return
   222  	}
   223  	w.mixer.wg.Add(1)
   224  	go func() {
   225  		defer w.mixer.wg.Done()
   226  		w.spvw.mix(ctx)
   227  		w.emitBalance()
   228  	}()
   229  }
   230  
   231  // runSimnetMixer just sends all funds from the mixed account to the default
   232  // account, after a short delay.
   233  func (w *NativeWallet) runSimnetMixer(ctx context.Context) {
   234  	if err := w.transferAccount(ctx, mixedAccountName, defaultAccountName); err != nil {
   235  		w.log.Errorf("error transferring funds while disabling mixing: %w", err)
   236  	}
   237  }
   238  
   239  // transferAccount sends all funds from the fromAccts to the toAcct.
   240  func (w *NativeWallet) transferAccount(ctx context.Context, toAcct string, fromAccts ...string) error {
   241  	// Move funds from mixed and trading account to default account.
   242  	var unspents []*walletjson.ListUnspentResult
   243  	for _, acctName := range fromAccts {
   244  		uns, err := w.spvw.Unspents(ctx, acctName)
   245  		if err != nil {
   246  			return fmt.Errorf("error listing unspent outputs for acct %q: %w", acctName, err)
   247  		}
   248  		unspents = append(unspents, uns...)
   249  	}
   250  	if len(unspents) == 0 {
   251  		return nil
   252  	}
   253  	var coinsToTransfer asset.Coins
   254  	for _, unspent := range unspents {
   255  		txHash, err := chainhash.NewHashFromStr(unspent.TxID)
   256  		if err != nil {
   257  			return fmt.Errorf("error decoding txid: %w", err)
   258  		}
   259  		v := toAtoms(unspent.Amount)
   260  		op := newOutput(txHash, unspent.Vout, v, unspent.Tree)
   261  		coinsToTransfer = append(coinsToTransfer, op)
   262  	}
   263  
   264  	tx, totalSent, err := w.sendAll(coinsToTransfer, toAcct)
   265  	if err != nil {
   266  		return fmt.Errorf("unable to transfer all funds from %+v accounts: %v", fromAccts, err)
   267  	} else {
   268  		w.log.Infof("Transferred %s from %+v accounts to %s account in tx %s.",
   269  			dcrutil.Amount(totalSent), fromAccts, toAcct, tx.TxHash())
   270  	}
   271  	return nil
   272  }
   273  
   274  // birthdayBlockHeight performs a binary search for the last block with a
   275  // timestamp lower than the provided birthday.
   276  func (w *NativeWallet) birthdayBlockHeight(ctx context.Context, bday uint64) int32 {
   277  	tipHeight := w.cachedBestBlock().height
   278  	var err error
   279  	firstBlockAfterBday := sort.Search(int(tipHeight), func(blockHeightI int) bool {
   280  		if err != nil { // if we see any errors, just give up.
   281  			return false
   282  		}
   283  		var blockHash *chainhash.Hash
   284  		if blockHash, err = w.spvw.GetBlockHash(ctx, int64(blockHeightI)); err != nil {
   285  			w.log.Errorf("Error getting block hash for height %d: %v", blockHeightI, err)
   286  			return false
   287  		}
   288  		stamp, err := w.spvw.BlockTimestamp(ctx, blockHash)
   289  		if err != nil {
   290  			w.log.Errorf("Error getting block header for hash %s: %v", blockHash, err)
   291  			return false
   292  		}
   293  		return uint64(stamp.Unix()) >= bday
   294  	})
   295  	if err != nil {
   296  		w.log.Errorf("Error encountered searching for birthday block: %v", err)
   297  		firstBlockAfterBday = 1
   298  	}
   299  	if firstBlockAfterBday == int(tipHeight) {
   300  		w.log.Errorf("Birthday %d is from the future", bday)
   301  		return 0
   302  	}
   303  
   304  	if firstBlockAfterBday == 0 {
   305  		return 0
   306  	}
   307  	return int32(firstBlockAfterBday - 1)
   308  }
   309  
   310  // Rescan initiates a rescan of the wallet from height 0. Rescan only blocks
   311  // long enough for the first asynchronous update, either an error or after the
   312  // first 2000 blocks are scanned.
   313  func (w *NativeWallet) Rescan(ctx context.Context, bday uint64) (err error) {
   314  	// Make sure we don't already have one running.
   315  	w.rescan.Lock()
   316  	rescanInProgress := w.rescan.progress != nil
   317  	if !rescanInProgress {
   318  		w.rescan.progress = &rescanProgress{}
   319  	}
   320  	w.rescan.Unlock()
   321  	if rescanInProgress {
   322  		return errors.New("rescan already in progress")
   323  	}
   324  
   325  	if bday == 0 {
   326  		bday = defaultWalletBirthdayUnix
   327  	}
   328  	bdayHeight := w.birthdayBlockHeight(ctx, bday)
   329  	// Add a little buffer.
   330  	const blockBufferN = 100
   331  	if bdayHeight >= blockBufferN {
   332  		bdayHeight -= blockBufferN
   333  	} else {
   334  		bdayHeight = 0
   335  	}
   336  
   337  	setProgress := func(height int32) {
   338  		w.rescan.Lock()
   339  		w.rescan.progress = &rescanProgress{scannedThrough: int64(height)}
   340  		w.rescan.Unlock()
   341  	}
   342  
   343  	c := make(chan wallet.RescanProgress)
   344  	go w.spvw.rescan(ctx, bdayHeight, c) // RescanProgressWithHeight will defer close(c)
   345  
   346  	// First update will either be an error or a report of the first 2000
   347  	// blocks. We can block until we get one.
   348  	errC := make(chan error, 1)
   349  	sendErr := func(err error) {
   350  		select {
   351  		case errC <- err:
   352  		default:
   353  		}
   354  		if err == nil {
   355  			w.receiveTxLastQuery.Store(0)
   356  		} else {
   357  			w.log.Errorf("Error encountered in rescan: %v", err)
   358  		}
   359  	}
   360  
   361  	w.wg.Add(1)
   362  	go func() {
   363  		defer func() {
   364  			w.rescan.Lock()
   365  			lastUpdate := w.rescan.progress
   366  			w.rescan.progress = nil
   367  			w.rescan.Unlock()
   368  			if lastUpdate != nil && lastUpdate.scannedThrough > 0 {
   369  				w.log.Infof("Completed rescan of %d blocks", lastUpdate.scannedThrough)
   370  			}
   371  			w.wg.Done()
   372  		}()
   373  		// Rescans are quick. Timeouts > a second are probably too high, but
   374  		// we'll give ample buffer.
   375  		timeout := time.After(time.Minute)
   376  		for {
   377  			select {
   378  			case u, open := <-c:
   379  				if !open { // channel was closed. rescan is finished.
   380  					if timeout == nil {
   381  						sendErr(nil)
   382  					} else {
   383  						// We never saw an update.
   384  						sendErr(errors.New("rescan finished without a progress update"))
   385  					}
   386  					return
   387  				}
   388  				sendErr(u.Err) // Hopefully nil, causing Rescan to return nil.
   389  				timeout = nil  // Any update cancels timeout.
   390  				if u.Err == nil {
   391  					setProgress(u.ScannedThrough)
   392  				}
   393  			case <-timeout:
   394  				sendErr(errors.New("rescan never sent progress updates"))
   395  				return
   396  			case <-ctx.Done():
   397  				sendErr(ctx.Err())
   398  				return
   399  			}
   400  		}
   401  	}()
   402  	return <-errC
   403  }