decred.org/dcrwallet/v3@v3.1.0/wallet/coinjoin.go (about)

     1  package wallet
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/subtle"
     7  
     8  	"decred.org/dcrwallet/v3/errors"
     9  	"decred.org/dcrwallet/v3/wallet/walletdb"
    10  	"github.com/decred/dcrd/dcrec"
    11  	"github.com/decred/dcrd/dcrutil/v4"
    12  	"github.com/decred/dcrd/txscript/v4"
    13  	"github.com/decred/dcrd/txscript/v4/sign"
    14  	"github.com/decred/dcrd/txscript/v4/stdaddr"
    15  	"github.com/decred/dcrd/txscript/v4/stdscript"
    16  	"github.com/decred/dcrd/wire"
    17  )
    18  
    19  type missingGenError struct{}
    20  
    21  var errMissingGen missingGenError
    22  
    23  func (missingGenError) Error() string   { return "coinjoin is missing gen output" }
    24  func (missingGenError) MissingMessage() {}
    25  
    26  type csppJoin struct {
    27  	tx            *wire.MsgTx
    28  	txInputs      map[wire.OutPoint]int
    29  	myPrevScripts [][]byte
    30  	myIns         []*wire.TxIn
    31  	change        *wire.TxOut
    32  	mcount        int
    33  	genScripts    [][]byte
    34  	genIndex      []int
    35  	amount        int64
    36  	wallet        *Wallet
    37  	mixAccount    uint32
    38  	mixBranch     uint32
    39  
    40  	ctx context.Context
    41  }
    42  
    43  func (w *Wallet) newCsppJoin(ctx context.Context, change *wire.TxOut, amount dcrutil.Amount, mixAccount, mixBranch uint32, mcount int) *csppJoin {
    44  	cj := &csppJoin{
    45  		tx:         &wire.MsgTx{Version: 1},
    46  		change:     change,
    47  		mcount:     mcount,
    48  		amount:     int64(amount),
    49  		wallet:     w,
    50  		mixAccount: mixAccount,
    51  		mixBranch:  mixBranch,
    52  		ctx:        ctx,
    53  	}
    54  	if change != nil {
    55  		cj.tx.TxOut = append(cj.tx.TxOut, change)
    56  	}
    57  	return cj
    58  }
    59  
    60  func (c *csppJoin) addTxIn(prevScript []byte, in *wire.TxIn) {
    61  	c.tx.TxIn = append(c.tx.TxIn, in)
    62  	c.myPrevScripts = append(c.myPrevScripts, prevScript)
    63  	c.myIns = append(c.myIns, in)
    64  }
    65  
    66  func (c *csppJoin) Gen() ([][]byte, error) {
    67  	const op errors.Op = "cspp.Gen"
    68  	gen := make([][]byte, c.mcount)
    69  	c.genScripts = make([][]byte, c.mcount)
    70  	var updates []func(walletdb.ReadWriteTx) error
    71  	for i := 0; i < c.mcount; i++ {
    72  		persist := c.wallet.deferPersistReturnedChild(c.ctx, &updates)
    73  		const accountName = "" // not used, so can be faked.
    74  		mixAddr, err := c.wallet.nextAddress(c.ctx, op, persist,
    75  			accountName, c.mixAccount, c.mixBranch, WithGapPolicyIgnore())
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  		version, script := mixAddr.PaymentScript()
    80  		if version != 0 {
    81  			return nil, errors.E("expected script version 0")
    82  		}
    83  		hash160er, ok := mixAddr.(stdaddr.Hash160er)
    84  		if !ok {
    85  			return nil, errors.E("address does not have Hash160 method")
    86  		}
    87  		c.genScripts[i] = script
    88  		gen[i] = hash160er.Hash160()[:]
    89  	}
    90  	err := walletdb.Update(c.ctx, c.wallet.db, func(dbtx walletdb.ReadWriteTx) error {
    91  		for _, f := range updates {
    92  			if err := f(dbtx); err != nil {
    93  				return err
    94  			}
    95  		}
    96  		return nil
    97  	})
    98  	if err != nil {
    99  		return nil, errors.E(op, err)
   100  	}
   101  	return gen, nil
   102  }
   103  
   104  func (c *csppJoin) Confirm() error {
   105  	const op errors.Op = "cspp.Confirm"
   106  	err := walletdb.View(c.ctx, c.wallet.db, func(dbtx walletdb.ReadTx) error {
   107  		addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
   108  		for outx, in := range c.myIns {
   109  			outScript := c.myPrevScripts[outx]
   110  			index, ok := c.txInputs[in.PreviousOutPoint]
   111  			if !ok {
   112  				return errors.E("coinjoin is missing inputs")
   113  			}
   114  			in = c.tx.TxIn[index]
   115  
   116  			const scriptVersion = 0
   117  			_, addrs := stdscript.ExtractAddrs(scriptVersion, outScript, c.wallet.chainParams)
   118  			if len(addrs) != 1 {
   119  				continue
   120  			}
   121  			apkh, ok := addrs[0].(*stdaddr.AddressPubKeyHashEcdsaSecp256k1V0)
   122  			if !ok {
   123  				return errors.E(errors.Bug, "previous output is not P2PKH")
   124  			}
   125  			privKey, done, err := c.wallet.manager.PrivateKey(addrmgrNs, apkh)
   126  			if err != nil {
   127  				return err
   128  			}
   129  			defer done()
   130  			sigscript, err := sign.SignatureScript(c.tx, index, outScript,
   131  				txscript.SigHashAll, privKey.Serialize(), dcrec.STEcdsaSecp256k1, true)
   132  			if err != nil {
   133  				return errors.E(errors.Op("txscript.SignatureScript"), err)
   134  			}
   135  			in.SignatureScript = sigscript
   136  		}
   137  		return nil
   138  	})
   139  	if err != nil {
   140  		return errors.E(op, err)
   141  	}
   142  	return nil
   143  }
   144  
   145  func (c *csppJoin) mixOutputIndexes() []int {
   146  	return c.genIndex
   147  }
   148  
   149  func (c *csppJoin) MarshalBinary() ([]byte, error) {
   150  	buf := new(bytes.Buffer)
   151  	buf.Grow(c.tx.SerializeSize())
   152  	err := c.tx.Serialize(buf)
   153  	return buf.Bytes(), err
   154  }
   155  
   156  func (c *csppJoin) UnmarshalBinary(b []byte) error {
   157  	tx := new(wire.MsgTx)
   158  	err := tx.Deserialize(bytes.NewReader(b))
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	// Ensure all unmixed inputs, unmixed outputs, and mixed outputs exist.
   164  	// Mixed outputs must be searched in constant time to avoid sidechannel leakage.
   165  	txInputs := make(map[wire.OutPoint]int, len(tx.TxIn))
   166  	for i, in := range tx.TxIn {
   167  		txInputs[in.PreviousOutPoint] = i
   168  	}
   169  	var n int
   170  	for _, in := range c.myIns {
   171  		if index, ok := txInputs[in.PreviousOutPoint]; ok {
   172  			other := tx.TxIn[index]
   173  			if in.Sequence != other.Sequence || in.ValueIn != other.ValueIn {
   174  				break
   175  			}
   176  			n++
   177  		}
   178  	}
   179  	if n != len(c.myIns) {
   180  		return errors.E("coinjoin is missing inputs")
   181  	}
   182  	if c.change != nil {
   183  		var hasChange bool
   184  		for _, out := range tx.TxOut {
   185  			if out.Value != c.change.Value {
   186  				continue
   187  			}
   188  			if out.Version != c.change.Version {
   189  				continue
   190  			}
   191  			if !bytes.Equal(out.PkScript, c.change.PkScript) {
   192  				continue
   193  			}
   194  			hasChange = true
   195  			break
   196  		}
   197  		if !hasChange {
   198  			return errors.E("coinjoin is missing change")
   199  		}
   200  	}
   201  	indexes, err := constantTimeOutputSearch(tx, c.amount, 0, c.genScripts)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	c.tx = tx
   207  	c.txInputs = txInputs
   208  	c.genIndex = indexes
   209  	return nil
   210  }
   211  
   212  // constantTimeOutputSearch searches for the output indexes of mixed outputs to
   213  // verify inclusion in a coinjoin.  It is constant time such that, for each
   214  // searched script, all outputs with equal value, script versions, and script
   215  // lengths matching the searched output are checked in constant time.
   216  func constantTimeOutputSearch(tx *wire.MsgTx, value int64, scriptVer uint16, scripts [][]byte) ([]int, error) {
   217  	var scan []int
   218  	for i, out := range tx.TxOut {
   219  		if out.Value != value {
   220  			continue
   221  		}
   222  		if out.Version != scriptVer {
   223  			continue
   224  		}
   225  		if len(out.PkScript) != len(scripts[0]) {
   226  			continue
   227  		}
   228  		scan = append(scan, i)
   229  	}
   230  	indexes := make([]int, 0, len(scan))
   231  	var missing int
   232  	for _, s := range scripts {
   233  		idx := -1
   234  		for _, i := range scan {
   235  			eq := subtle.ConstantTimeCompare(tx.TxOut[i].PkScript, s)
   236  			idx = subtle.ConstantTimeSelect(eq, i, idx)
   237  		}
   238  		indexes = append(indexes, idx)
   239  		eq := subtle.ConstantTimeEq(int32(idx), -1)
   240  		missing = subtle.ConstantTimeSelect(eq, 1, missing)
   241  	}
   242  	if missing == 1 {
   243  		return nil, errMissingGen
   244  	}
   245  	return indexes, nil
   246  }