github.com/mit-dci/lit@v0.0.0-20221102210550-8c3d3b49f2ce/qln/justicetx.go (about)

     1  package qln
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"github.com/boltdb/bolt"
     7  	"github.com/mit-dci/lit/btcutil/txscript"
     8  	"github.com/mit-dci/lit/consts"
     9  	"github.com/mit-dci/lit/lnutil"
    10  	"github.com/mit-dci/lit/logging"
    11  	"github.com/mit-dci/lit/sig64"
    12  	"github.com/mit-dci/lit/wire"
    13  )
    14  
    15  /*
    16  functions relating to the "justice transaction" (aka penalty transaction)
    17  
    18  
    19  because we're using the sipa/schnorr delinearization, we don't need to vary the PKH
    20  anymore.  We can hand over 1 point per commit & figure everything out from that.
    21  */
    22  
    23  type JusticeTx struct {
    24  	Sig  [64]byte
    25  	Txid [16]byte
    26  	Amt  int64
    27  	Data [32]byte
    28  	Pkh  [20]byte
    29  	Idx  uint64
    30  }
    31  
    32  func (jte *JusticeTx) ToBytes() ([]byte, error) {
    33  	var buf bytes.Buffer
    34  
    35  	// write the sig
    36  	_, err := buf.Write(jte.Sig[:])
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	// write tx id of the bad tx
    42  	_, err = buf.Write(jte.Txid[:])
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	// write the delta for this tx
    47  	_, err = buf.Write(lnutil.I64tB(jte.Amt)[:])
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	// then the data
    53  	_, err = buf.Write(jte.Data[:])
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	// done
    59  	return buf.Bytes(), nil
    60  }
    61  
    62  func JusticeTxFromBytes(jte []byte) (JusticeTx, error) {
    63  	var r JusticeTx
    64  	if len(jte) < 120 || len(jte) > 120 {
    65  		return r, fmt.Errorf("JusticeTx data %d bytes, expect 116", len(jte))
    66  	}
    67  
    68  	copy(r.Sig[:], jte[:64])
    69  	copy(r.Txid[:], jte[64:80])
    70  	r.Amt = lnutil.BtI64(jte[80:88])
    71  	copy(r.Data[:], jte[88:])
    72  
    73  	return r, nil
    74  }
    75  
    76  // BuildWatchTxidSig builds the partial txid and signature pair which can
    77  // be exported to the watchtower.
    78  // This get a channel that is 1 state old.  So we can produce a signature.
    79  func (nd *LitNode) BuildJusticeSig(q *Qchan) error {
    80  
    81  	if nd.SubWallet[q.Coin()] == nil {
    82  		return fmt.Errorf("Not connected to coin type %d\n", q.Coin())
    83  	}
    84  
    85  	// in this function, "bad" refers to the hypothetical transaction spending the
    86  	// com tx.  "justice" is the tx spending the bad tx
    87  
    88  	fee := int64(consts.JusticeTxBump * nd.SubWallet[q.Coin()].Fee())
    89  
    90  	// first we need the keys in the bad script.  Start by getting the elk-scalar
    91  	// we should have it at the "current" state number
    92  	elk, err := q.ElkRcv.AtIndex(q.State.StateIdx)
    93  	if err != nil {
    94  		return err
    95  	}
    96  	// build elkpoint, and rewind the channel's remote elkpoint by one state
    97  	// get elk scalar
    98  	elkScalar := lnutil.ElkScalar(elk)
    99  	// get elk point
   100  	elkPoint := lnutil.ElkPointFromHash(elk)
   101  	// overwrite remote elkpoint in channel state
   102  	q.State.ElkPoint = elkPoint
   103  
   104  	// make pubkeys, build script
   105  	badRevokePub := lnutil.CombinePubs(q.MyHAKDBase, elkPoint)
   106  	badTimeoutPub := lnutil.AddPubsEZ(q.TheirHAKDBase, elkPoint)
   107  	script := lnutil.CommitScript(badRevokePub, badTimeoutPub, q.Delay)
   108  	scriptHashOutScript := lnutil.P2WSHify(script)
   109  
   110  	// TODO: we have to build justics txs for each of the HTLCs too
   111  
   112  	// build the bad tx (redundant as we just build most of it...
   113  	badTx, _, _, err := q.BuildStateTxs(false)
   114  	if err != nil {
   115  		return err
   116  	}
   117  	var badAmt int64
   118  	badIdx := uint32(len(badTx.TxOut) + 1)
   119  
   120  	logging.Infof("made revpub %x timeout pub %x\nscript:%x\nhash %x\n",
   121  		badRevokePub[:], badTimeoutPub[:], script, scriptHashOutScript)
   122  	// figure out which output to bring justice to
   123  	for i, out := range badTx.TxOut {
   124  		logging.Infof("txout %d pkscript %x\n", i, out.PkScript)
   125  		if bytes.Equal(out.PkScript, scriptHashOutScript) {
   126  			badIdx = uint32(i)
   127  			badAmt = out.Value
   128  			break
   129  		}
   130  	}
   131  	if badIdx > uint32(len(badTx.TxOut)) {
   132  		return fmt.Errorf("BuildWatchTxidSig couldn't find revocable SH output")
   133  	}
   134  
   135  	// make a keygen to get the private HAKD base scalar
   136  	kg := q.KeyGen
   137  	kg.Step[2] = UseChannelHAKDBase
   138  	// get HAKD base scalar
   139  	privBase, err := nd.SubWallet[q.Coin()].GetPriv(kg)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	// combine elk & HAKD base to make signing key
   144  	combinedPrivKey := lnutil.CombinePrivKeyWithBytes(privBase, elkScalar[:])
   145  
   146  	// get badtxid
   147  	badTxid := badTx.TxHash()
   148  	// make bad outpoint
   149  	badOP := wire.NewOutPoint(&badTxid, badIdx)
   150  	// make the justice txin, empty sig / witness
   151  	justiceIn := wire.NewTxIn(badOP, nil, nil)
   152  	justiceIn.Sequence = 1
   153  	// make justice output script
   154  	justiceScript := lnutil.DirectWPKHScriptFromPKH(q.WatchRefundAdr)
   155  	// make justice txout
   156  	justiceOut := wire.NewTxOut(badAmt-fee, justiceScript)
   157  
   158  	justiceTx := wire.NewMsgTx()
   159  	// set to version 2, though might not matter as no CSV is used
   160  	justiceTx.Version = 2
   161  
   162  	// add inputs and outputs
   163  	justiceTx.AddTxIn(justiceIn)
   164  	justiceTx.AddTxOut(justiceOut)
   165  
   166  	jtxid := justiceTx.TxHash()
   167  	logging.Infof("made justice tx %s\n", jtxid.String())
   168  
   169  	// get hashcache for signing
   170  	hCache := txscript.NewTxSigHashes(justiceTx)
   171  
   172  	// sign with combined key.  Justice txs always have only 1 input, so txin is 0
   173  	bigSig, err := txscript.RawTxInWitnessSignature(
   174  		justiceTx, hCache, 0, badAmt, script, txscript.SigHashAll, combinedPrivKey)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	// truncate sig (last byte is sighash type, always sighashAll)
   179  	bigSig = bigSig[:len(bigSig)-1]
   180  
   181  	sig, err := sig64.SigCompress(bigSig)
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	var jte JusticeTx
   187  	copy(jte.Sig[:], sig[:])
   188  	copy(jte.Txid[:], badTxid[:16])
   189  	jte.Data = q.State.Data
   190  	jte.Amt = q.State.MyAmt
   191  
   192  	justiceBytes, err := jte.ToBytes()
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	var justiceBytesFixed [120]byte
   198  	copy(justiceBytesFixed[:], justiceBytes[:120])
   199  
   200  	return nd.SaveJusticeSig(q.State.StateIdx, q.WatchRefundAdr, justiceBytesFixed)
   201  }
   202  
   203  // SaveJusticeSig save the txid/sig of a justice transaction to the db.  Pretty
   204  // straightforward
   205  func (nd *LitNode) SaveJusticeSig(comnum uint64, pkh [20]byte, txidsig [120]byte) error {
   206  	return nd.LitDB.Update(func(btx *bolt.Tx) error {
   207  		sigs := btx.Bucket(BKTWatch)
   208  		if sigs == nil {
   209  			return fmt.Errorf("no justice bucket")
   210  		}
   211  		// one bucket per refund PKH
   212  		justBkt, err := sigs.CreateBucketIfNotExists(pkh[:])
   213  		if err != nil {
   214  			return err
   215  		}
   216  
   217  		return justBkt.Put(lnutil.U64tB(comnum), txidsig[:])
   218  	})
   219  }
   220  
   221  func (nd *LitNode) LoadJusticeSig(comnum uint64, pkh [20]byte) (JusticeTx, error) {
   222  	var txidsig JusticeTx
   223  
   224  	err := nd.LitDB.View(func(btx *bolt.Tx) error {
   225  		sigs := btx.Bucket(BKTWatch)
   226  		if sigs == nil {
   227  			return fmt.Errorf("no justice bucket")
   228  		}
   229  		// one bucket per refund PKH
   230  		justBkt := sigs.Bucket(pkh[:])
   231  		if justBkt == nil {
   232  			return fmt.Errorf("pkh %x not in justice bucket", pkh)
   233  		}
   234  		sigbytes := justBkt.Get(lnutil.U64tB(comnum))
   235  		if sigbytes == nil {
   236  			return fmt.Errorf("state %d not in db under pkh %x", comnum, pkh)
   237  		}
   238  
   239  		var err error
   240  		txidsig, err = JusticeTxFromBytes(sigbytes)
   241  		if err != nil {
   242  			return err
   243  		}
   244  
   245  		return nil
   246  	})
   247  
   248  	return txidsig, err
   249  }
   250  
   251  func (nd *LitNode) DumpJusticeDB() ([]JusticeTx, error) {
   252  	var txs []JusticeTx
   253  
   254  	err := nd.LitDB.View(func(btx *bolt.Tx) error {
   255  		sigs := btx.Bucket(BKTWatch)
   256  		if sigs == nil {
   257  			return fmt.Errorf("no justice bucket")
   258  		}
   259  
   260  		// go through all pkh buckets
   261  		return sigs.ForEach(func(k, _ []byte) error {
   262  			pkhBucket := sigs.Bucket(k)
   263  			if pkhBucket == nil {
   264  				return fmt.Errorf("%x not a bucket", k)
   265  			}
   266  			return pkhBucket.ForEach(func(idx, txidsig []byte) error {
   267  				var jtx JusticeTx
   268  				jtx, err := JusticeTxFromBytes(txidsig)
   269  				if err != nil {
   270  					return err
   271  				}
   272  
   273  				copy(jtx.Pkh[:], k[:20])
   274  				jtx.Idx = lnutil.BtU64(idx)
   275  
   276  				txs = append(txs, jtx)
   277  
   278  				return nil
   279  			})
   280  		})
   281  	})
   282  	return txs, err
   283  }
   284  
   285  func (nd *LitNode) ShowJusticeDB() (string, error) {
   286  	var s string
   287  
   288  	err := nd.LitDB.View(func(btx *bolt.Tx) error {
   289  		sigs := btx.Bucket(BKTWatch)
   290  		if sigs == nil {
   291  			return fmt.Errorf("no justice bucket")
   292  		}
   293  
   294  		// go through all pkh buckets
   295  		return sigs.ForEach(func(k, _ []byte) error {
   296  			s += fmt.Sprintf("Channel refunding to pkh %x\n", k)
   297  			pkhBucket := sigs.Bucket(k)
   298  			if pkhBucket == nil {
   299  				return fmt.Errorf("%x not a bucket", k)
   300  			}
   301  			return pkhBucket.ForEach(func(idx, txidsig []byte) error {
   302  				s += fmt.Sprintf("\tidx %x\t txidsig: %x\n", idx, txidsig[:80])
   303  				return nil
   304  			})
   305  		})
   306  	})
   307  	return s, err
   308  }
   309  
   310  // SendWatch syncs up the remote watchtower with all justice signatures
   311  func (nd *LitNode) SyncWatch(qc *Qchan, watchPeer uint32) error {
   312  
   313  	if !nd.ConnectedToPeer(watchPeer) {
   314  		return fmt.Errorf("SyncWatch: not connected to peer %d", watchPeer)
   315  	}
   316  	// if watchUpTo isn't 2 behind the state number, there's nothing to send
   317  	// kindof confusing inequality: can't send state 0 info to watcher when at
   318  	// state 1.  State 0 needs special handling.
   319  	if qc.State.WatchUpTo+2 > qc.State.StateIdx || qc.State.StateIdx < 2 {
   320  		return fmt.Errorf("Channel at state %d, up to %d exported, nothing to do",
   321  			qc.State.StateIdx, qc.State.WatchUpTo)
   322  	}
   323  	// send initial description if we haven't sent anything yet
   324  	if qc.State.WatchUpTo == 0 {
   325  		desc := lnutil.NewWatchDescMsg(watchPeer, qc.Coin(),
   326  			qc.WatchRefundAdr, qc.Delay, consts.JusticeFee, qc.TheirHAKDBase, qc.MyHAKDBase)
   327  
   328  		nd.tmpSendLitMsg(desc)
   329  		// after sending description, must send at least states 0 and 1.
   330  		err := nd.SendWatchComMsg(qc, 0, watchPeer)
   331  		if err != nil {
   332  			return err
   333  		}
   334  		err = nd.SendWatchComMsg(qc, 1, watchPeer)
   335  		if err != nil {
   336  			return err
   337  		}
   338  		qc.State.WatchUpTo = 1
   339  	}
   340  	// send messages to get up to 1 less than current state
   341  	for qc.State.WatchUpTo < qc.State.StateIdx-1 {
   342  		// increment watchupto number
   343  		qc.State.WatchUpTo++
   344  		err := nd.SendWatchComMsg(qc, qc.State.WatchUpTo, watchPeer)
   345  		if err != nil {
   346  			return err
   347  		}
   348  	}
   349  	// save updated WatchUpTo number
   350  	return nd.SaveQchanState(qc)
   351  }
   352  
   353  // send WatchComMsg generates and sends the ComMsg to a watchtower
   354  func (nd *LitNode) SendWatchComMsg(qc *Qchan, idx uint64, watchPeer uint32) error {
   355  	// retrieve the sig data from db
   356  	txidsig, err := nd.LoadJusticeSig(idx, qc.WatchRefundAdr)
   357  	if err != nil {
   358  		return err
   359  	}
   360  	// get the elkrem
   361  	elk, err := qc.ElkRcv.AtIndex(idx)
   362  	if err != nil {
   363  		return err
   364  	}
   365  
   366  	comMsg := lnutil.NewComMsg(
   367  		watchPeer, qc.Coin(), qc.WatchRefundAdr, *elk, txidsig.Txid, txidsig.Sig)
   368  
   369  	nd.tmpSendLitMsg(comMsg)
   370  	return err
   371  }