decred.org/dcrwallet/v3@v3.1.0/wallet/coinjoin.go (about) 1 package wallet 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/subtle" 7 8 "decred.org/dcrwallet/v3/errors" 9 "decred.org/dcrwallet/v3/wallet/walletdb" 10 "github.com/decred/dcrd/dcrec" 11 "github.com/decred/dcrd/dcrutil/v4" 12 "github.com/decred/dcrd/txscript/v4" 13 "github.com/decred/dcrd/txscript/v4/sign" 14 "github.com/decred/dcrd/txscript/v4/stdaddr" 15 "github.com/decred/dcrd/txscript/v4/stdscript" 16 "github.com/decred/dcrd/wire" 17 ) 18 19 type missingGenError struct{} 20 21 var errMissingGen missingGenError 22 23 func (missingGenError) Error() string { return "coinjoin is missing gen output" } 24 func (missingGenError) MissingMessage() {} 25 26 type csppJoin struct { 27 tx *wire.MsgTx 28 txInputs map[wire.OutPoint]int 29 myPrevScripts [][]byte 30 myIns []*wire.TxIn 31 change *wire.TxOut 32 mcount int 33 genScripts [][]byte 34 genIndex []int 35 amount int64 36 wallet *Wallet 37 mixAccount uint32 38 mixBranch uint32 39 40 ctx context.Context 41 } 42 43 func (w *Wallet) newCsppJoin(ctx context.Context, change *wire.TxOut, amount dcrutil.Amount, mixAccount, mixBranch uint32, mcount int) *csppJoin { 44 cj := &csppJoin{ 45 tx: &wire.MsgTx{Version: 1}, 46 change: change, 47 mcount: mcount, 48 amount: int64(amount), 49 wallet: w, 50 mixAccount: mixAccount, 51 mixBranch: mixBranch, 52 ctx: ctx, 53 } 54 if change != nil { 55 cj.tx.TxOut = append(cj.tx.TxOut, change) 56 } 57 return cj 58 } 59 60 func (c *csppJoin) addTxIn(prevScript []byte, in *wire.TxIn) { 61 c.tx.TxIn = append(c.tx.TxIn, in) 62 c.myPrevScripts = append(c.myPrevScripts, prevScript) 63 c.myIns = append(c.myIns, in) 64 } 65 66 func (c *csppJoin) Gen() ([][]byte, error) { 67 const op errors.Op = "cspp.Gen" 68 gen := make([][]byte, c.mcount) 69 c.genScripts = make([][]byte, c.mcount) 70 var updates []func(walletdb.ReadWriteTx) error 71 for i := 0; i < c.mcount; i++ { 72 persist := c.wallet.deferPersistReturnedChild(c.ctx, &updates) 73 const accountName = "" // not used, so can be faked. 74 mixAddr, err := c.wallet.nextAddress(c.ctx, op, persist, 75 accountName, c.mixAccount, c.mixBranch, WithGapPolicyIgnore()) 76 if err != nil { 77 return nil, err 78 } 79 version, script := mixAddr.PaymentScript() 80 if version != 0 { 81 return nil, errors.E("expected script version 0") 82 } 83 hash160er, ok := mixAddr.(stdaddr.Hash160er) 84 if !ok { 85 return nil, errors.E("address does not have Hash160 method") 86 } 87 c.genScripts[i] = script 88 gen[i] = hash160er.Hash160()[:] 89 } 90 err := walletdb.Update(c.ctx, c.wallet.db, func(dbtx walletdb.ReadWriteTx) error { 91 for _, f := range updates { 92 if err := f(dbtx); err != nil { 93 return err 94 } 95 } 96 return nil 97 }) 98 if err != nil { 99 return nil, errors.E(op, err) 100 } 101 return gen, nil 102 } 103 104 func (c *csppJoin) Confirm() error { 105 const op errors.Op = "cspp.Confirm" 106 err := walletdb.View(c.ctx, c.wallet.db, func(dbtx walletdb.ReadTx) error { 107 addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey) 108 for outx, in := range c.myIns { 109 outScript := c.myPrevScripts[outx] 110 index, ok := c.txInputs[in.PreviousOutPoint] 111 if !ok { 112 return errors.E("coinjoin is missing inputs") 113 } 114 in = c.tx.TxIn[index] 115 116 const scriptVersion = 0 117 _, addrs := stdscript.ExtractAddrs(scriptVersion, outScript, c.wallet.chainParams) 118 if len(addrs) != 1 { 119 continue 120 } 121 apkh, ok := addrs[0].(*stdaddr.AddressPubKeyHashEcdsaSecp256k1V0) 122 if !ok { 123 return errors.E(errors.Bug, "previous output is not P2PKH") 124 } 125 privKey, done, err := c.wallet.manager.PrivateKey(addrmgrNs, apkh) 126 if err != nil { 127 return err 128 } 129 defer done() 130 sigscript, err := sign.SignatureScript(c.tx, index, outScript, 131 txscript.SigHashAll, privKey.Serialize(), dcrec.STEcdsaSecp256k1, true) 132 if err != nil { 133 return errors.E(errors.Op("txscript.SignatureScript"), err) 134 } 135 in.SignatureScript = sigscript 136 } 137 return nil 138 }) 139 if err != nil { 140 return errors.E(op, err) 141 } 142 return nil 143 } 144 145 func (c *csppJoin) mixOutputIndexes() []int { 146 return c.genIndex 147 } 148 149 func (c *csppJoin) MarshalBinary() ([]byte, error) { 150 buf := new(bytes.Buffer) 151 buf.Grow(c.tx.SerializeSize()) 152 err := c.tx.Serialize(buf) 153 return buf.Bytes(), err 154 } 155 156 func (c *csppJoin) UnmarshalBinary(b []byte) error { 157 tx := new(wire.MsgTx) 158 err := tx.Deserialize(bytes.NewReader(b)) 159 if err != nil { 160 return err 161 } 162 163 // Ensure all unmixed inputs, unmixed outputs, and mixed outputs exist. 164 // Mixed outputs must be searched in constant time to avoid sidechannel leakage. 165 txInputs := make(map[wire.OutPoint]int, len(tx.TxIn)) 166 for i, in := range tx.TxIn { 167 txInputs[in.PreviousOutPoint] = i 168 } 169 var n int 170 for _, in := range c.myIns { 171 if index, ok := txInputs[in.PreviousOutPoint]; ok { 172 other := tx.TxIn[index] 173 if in.Sequence != other.Sequence || in.ValueIn != other.ValueIn { 174 break 175 } 176 n++ 177 } 178 } 179 if n != len(c.myIns) { 180 return errors.E("coinjoin is missing inputs") 181 } 182 if c.change != nil { 183 var hasChange bool 184 for _, out := range tx.TxOut { 185 if out.Value != c.change.Value { 186 continue 187 } 188 if out.Version != c.change.Version { 189 continue 190 } 191 if !bytes.Equal(out.PkScript, c.change.PkScript) { 192 continue 193 } 194 hasChange = true 195 break 196 } 197 if !hasChange { 198 return errors.E("coinjoin is missing change") 199 } 200 } 201 indexes, err := constantTimeOutputSearch(tx, c.amount, 0, c.genScripts) 202 if err != nil { 203 return err 204 } 205 206 c.tx = tx 207 c.txInputs = txInputs 208 c.genIndex = indexes 209 return nil 210 } 211 212 // constantTimeOutputSearch searches for the output indexes of mixed outputs to 213 // verify inclusion in a coinjoin. It is constant time such that, for each 214 // searched script, all outputs with equal value, script versions, and script 215 // lengths matching the searched output are checked in constant time. 216 func constantTimeOutputSearch(tx *wire.MsgTx, value int64, scriptVer uint16, scripts [][]byte) ([]int, error) { 217 var scan []int 218 for i, out := range tx.TxOut { 219 if out.Value != value { 220 continue 221 } 222 if out.Version != scriptVer { 223 continue 224 } 225 if len(out.PkScript) != len(scripts[0]) { 226 continue 227 } 228 scan = append(scan, i) 229 } 230 indexes := make([]int, 0, len(scan)) 231 var missing int 232 for _, s := range scripts { 233 idx := -1 234 for _, i := range scan { 235 eq := subtle.ConstantTimeCompare(tx.TxOut[i].PkScript, s) 236 idx = subtle.ConstantTimeSelect(eq, i, idx) 237 } 238 indexes = append(indexes, idx) 239 eq := subtle.ConstantTimeEq(int32(idx), -1) 240 missing = subtle.ConstantTimeSelect(eq, 1, missing) 241 } 242 if missing == 1 { 243 return nil, errMissingGen 244 } 245 return indexes, nil 246 }