github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/wallet/wallet_test.go (about)

     1  package wallet
     2  
     3  import (
     4  	"encoding/json"
     5  	"io/ioutil"
     6  	"os"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/bytom/bytom/account"
    12  	"github.com/bytom/bytom/asset"
    13  	"github.com/bytom/bytom/blockchain/pseudohsm"
    14  	"github.com/bytom/bytom/blockchain/signers"
    15  	"github.com/bytom/bytom/blockchain/txbuilder"
    16  	"github.com/bytom/bytom/config"
    17  	"github.com/bytom/bytom/consensus"
    18  	"github.com/bytom/bytom/crypto/ed25519/chainkd"
    19  	"github.com/bytom/bytom/database"
    20  	dbm "github.com/bytom/bytom/database/leveldb"
    21  	"github.com/bytom/bytom/event"
    22  	"github.com/bytom/bytom/protocol"
    23  	"github.com/bytom/bytom/protocol/bc"
    24  	"github.com/bytom/bytom/protocol/bc/types"
    25  )
    26  
    27  func TestEncodeDecodeGlobalTxIndex(t *testing.T) {
    28  	want := &struct {
    29  		BlockHash bc.Hash
    30  		Position  uint64
    31  	}{
    32  		BlockHash: bc.NewHash([32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}),
    33  		Position:  1,
    34  	}
    35  
    36  	globalTxIdx := calcGlobalTxIndex(&want.BlockHash, want.Position)
    37  	blockHashGot, positionGot := parseGlobalTxIdx(globalTxIdx)
    38  	if *blockHashGot != want.BlockHash {
    39  		t.Errorf("blockHash mismatch. Get: %v. Expect: %v", *blockHashGot, want.BlockHash)
    40  	}
    41  
    42  	if positionGot != want.Position {
    43  		t.Errorf("position mismatch. Get: %v. Expect: %v", positionGot, want.Position)
    44  	}
    45  }
    46  
    47  func TestWalletVersion(t *testing.T) {
    48  	// prepare wallet
    49  	dirPath, err := ioutil.TempDir(".", "")
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	defer os.RemoveAll(dirPath)
    54  
    55  	testDB := dbm.NewDB("testdb", "leveldb", "temp")
    56  	defer os.RemoveAll("temp")
    57  
    58  	dispatcher := event.NewDispatcher()
    59  	w := mockWallet(testDB, nil, nil, nil, dispatcher, false)
    60  
    61  	// legacy status test case
    62  	type legacyStatusInfo struct {
    63  		WorkHeight uint64
    64  		WorkHash   bc.Hash
    65  		BestHeight uint64
    66  		BestHash   bc.Hash
    67  	}
    68  	rawWallet, err := json.Marshal(legacyStatusInfo{})
    69  	if err != nil {
    70  		t.Fatal("Marshal legacyStatusInfo")
    71  	}
    72  
    73  	w.DB.Set(walletKey, rawWallet)
    74  	rawWallet = w.DB.Get(walletKey)
    75  	if rawWallet == nil {
    76  		t.Fatal("fail to load wallet StatusInfo")
    77  	}
    78  
    79  	if err := json.Unmarshal(rawWallet, &w.status); err != nil {
    80  		t.Fatal(err)
    81  	}
    82  
    83  	if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
    84  		t.Fatal("fail to detect legacy wallet version")
    85  	}
    86  
    87  	// lower wallet version test case
    88  	lowerVersion := StatusInfo{Version: currentVersion - 1}
    89  	rawWallet, err = json.Marshal(lowerVersion)
    90  	if err != nil {
    91  		t.Fatal("save wallet info")
    92  	}
    93  
    94  	w.DB.Set(walletKey, rawWallet)
    95  	rawWallet = w.DB.Get(walletKey)
    96  	if rawWallet == nil {
    97  		t.Fatal("fail to load wallet StatusInfo")
    98  	}
    99  
   100  	if err := json.Unmarshal(rawWallet, &w.status); err != nil {
   101  		t.Fatal(err)
   102  	}
   103  
   104  	if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
   105  		t.Fatal("fail to detect expired wallet version")
   106  	}
   107  }
   108  
   109  func TestWalletUpdate(t *testing.T) {
   110  	dirPath, err := ioutil.TempDir(".", "")
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  	defer os.RemoveAll(dirPath)
   115  
   116  	testDB := dbm.NewDB("testdb", "leveldb", "temp")
   117  	defer os.RemoveAll("temp")
   118  
   119  	store := database.NewStore(testDB)
   120  	dispatcher := event.NewDispatcher()
   121  	txPool := protocol.NewTxPool(store, dispatcher)
   122  
   123  	chain, err := protocol.NewChain(store, txPool)
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  
   128  	accountManager := account.NewManager(testDB, chain)
   129  	hsm, err := pseudohsm.New(dirPath)
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  
   134  	xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
   135  	if err != nil {
   136  		t.Fatal(err)
   137  	}
   138  
   139  	testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  
   144  	controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
   145  	if err != nil {
   146  		t.Fatal(err)
   147  	}
   148  
   149  	controlProg.KeyIndex = 1
   150  
   151  	reg := asset.NewRegistry(testDB, chain)
   152  	asset, err := reg.Define([]chainkd.XPub{xpub1.XPub}, 1, nil, 0, "TESTASSET", nil)
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  
   157  	utxos := []*account.UTXO{}
   158  	btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
   159  	utxos = append(utxos, btmUtxo)
   160  	OtherUtxo := mockUTXO(controlProg, &asset.AssetID)
   161  	utxos = append(utxos, OtherUtxo)
   162  
   163  	_, txData, err := mockTxData(utxos, testAccount)
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  
   168  	tx := types.NewTx(*txData)
   169  	block := mockSingleBlock(tx)
   170  	txStatus := bc.NewTransactionStatus()
   171  	txStatus.SetStatus(0, false)
   172  	txStatus.SetStatus(1, false)
   173  	store.SaveBlock(block, txStatus)
   174  
   175  	w := mockWallet(testDB, accountManager, reg, chain, dispatcher, true)
   176  	err = w.AttachBlock(block)
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  
   181  	if _, err := w.GetTransactionByTxID(tx.ID.String()); err != nil {
   182  		t.Fatal(err)
   183  	}
   184  
   185  	wants, err := w.GetTransactions("")
   186  	if len(wants) != 1 {
   187  		t.Fatal(err)
   188  	}
   189  
   190  	if wants[0].ID != tx.ID {
   191  		t.Fatal("account txID mismatch")
   192  	}
   193  
   194  	for position, tx := range block.Transactions {
   195  		get := w.DB.Get(calcGlobalTxIndexKey(tx.ID.String()))
   196  		bh := block.BlockHeader.Hash()
   197  		expect := calcGlobalTxIndex(&bh, uint64(position))
   198  		if !reflect.DeepEqual(get, expect) {
   199  			t.Fatalf("position#%d: compare retrieved globalTxIdx err", position)
   200  		}
   201  	}
   202  }
   203  
   204  func TestRescanWallet(t *testing.T) {
   205  	// prepare wallet & db
   206  	dirPath, err := ioutil.TempDir(".", "")
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  	defer os.RemoveAll(dirPath)
   211  
   212  	testDB := dbm.NewDB("testdb", "leveldb", "temp")
   213  	defer os.RemoveAll("temp")
   214  
   215  	store := database.NewStore(testDB)
   216  	dispatcher := event.NewDispatcher()
   217  	txPool := protocol.NewTxPool(store, dispatcher)
   218  	chain, err := protocol.NewChain(store, txPool)
   219  	if err != nil {
   220  		t.Fatal(err)
   221  	}
   222  
   223  	statusInfo := StatusInfo{
   224  		Version:  currentVersion,
   225  		WorkHash: bc.Hash{V0: 0xff},
   226  	}
   227  	rawWallet, err := json.Marshal(statusInfo)
   228  	if err != nil {
   229  		t.Fatal("save wallet info")
   230  	}
   231  
   232  	w := mockWallet(testDB, nil, nil, chain, dispatcher, false)
   233  	w.DB.Set(walletKey, rawWallet)
   234  	rawWallet = w.DB.Get(walletKey)
   235  	if rawWallet == nil {
   236  		t.Fatal("fail to load wallet StatusInfo")
   237  	}
   238  
   239  	if err := json.Unmarshal(rawWallet, &w.status); err != nil {
   240  		t.Fatal(err)
   241  	}
   242  
   243  	// rescan wallet
   244  	if err := w.loadWalletInfo(); err != nil {
   245  		t.Fatal(err)
   246  	}
   247  
   248  	block := config.GenesisBlock()
   249  	if w.status.WorkHash != block.Hash() {
   250  		t.Fatal("reattach from genesis block")
   251  	}
   252  }
   253  
   254  func TestMemPoolTxQueryLoop(t *testing.T) {
   255  	dirPath, err := ioutil.TempDir(".", "")
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	defer os.RemoveAll(dirPath)
   260  
   261  	testDB := dbm.NewDB("testdb", "leveldb", dirPath)
   262  
   263  	store := database.NewStore(testDB)
   264  	dispatcher := event.NewDispatcher()
   265  	txPool := protocol.NewTxPool(store, dispatcher)
   266  
   267  	chain, err := protocol.NewChain(store, txPool)
   268  	if err != nil {
   269  		t.Fatal(err)
   270  	}
   271  
   272  	accountManager := account.NewManager(testDB, chain)
   273  	hsm, err := pseudohsm.New(dirPath)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  
   278  	xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
   279  	if err != nil {
   280  		t.Fatal(err)
   281  	}
   282  
   283  	testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  
   288  	controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
   289  	if err != nil {
   290  		t.Fatal(err)
   291  	}
   292  
   293  	controlProg.KeyIndex = 1
   294  
   295  	reg := asset.NewRegistry(testDB, chain)
   296  	asset, err := reg.Define([]chainkd.XPub{xpub1.XPub}, 1, nil, 0, "TESTASSET", nil)
   297  	if err != nil {
   298  		t.Fatal(err)
   299  	}
   300  
   301  	utxos := []*account.UTXO{}
   302  	btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
   303  	utxos = append(utxos, btmUtxo)
   304  	OtherUtxo := mockUTXO(controlProg, &asset.AssetID)
   305  	utxos = append(utxos, OtherUtxo)
   306  
   307  	_, txData, err := mockTxData(utxos, testAccount)
   308  	if err != nil {
   309  		t.Fatal(err)
   310  	}
   311  
   312  	tx := types.NewTx(*txData)
   313  	//block := mockSingleBlock(tx)
   314  	txStatus := bc.NewTransactionStatus()
   315  	txStatus.SetStatus(0, false)
   316  	w, err := NewWallet(testDB, accountManager, reg, hsm, chain, dispatcher, false)
   317  	go w.memPoolTxQueryLoop()
   318  	w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgNewTx}})
   319  	time.Sleep(time.Millisecond * 10)
   320  	if _, err = w.GetUnconfirmedTxByTxID(tx.ID.String()); err != nil {
   321  		t.Fatal("disaptch new tx msg error:", err)
   322  	}
   323  	w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgRemoveTx}})
   324  	time.Sleep(time.Millisecond * 10)
   325  	txs, err := w.GetUnconfirmedTxs(testAccount.ID)
   326  	if err != nil {
   327  		t.Fatal("get unconfirmed tx error:", err)
   328  	}
   329  
   330  	if len(txs) != 0 {
   331  		t.Fatal("disaptch remove tx msg error")
   332  	}
   333  
   334  	w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: 2}})
   335  }
   336  
   337  func mockUTXO(controlProg *account.CtrlProgram, assetID *bc.AssetID) *account.UTXO {
   338  	utxo := &account.UTXO{}
   339  	utxo.OutputID = bc.Hash{V0: 1}
   340  	utxo.SourceID = bc.Hash{V0: 2}
   341  	utxo.AssetID = *assetID
   342  	utxo.Amount = 1000000000
   343  	utxo.SourcePos = 0
   344  	utxo.ControlProgram = controlProg.ControlProgram
   345  	utxo.AccountID = controlProg.AccountID
   346  	utxo.Address = controlProg.Address
   347  	utxo.ControlProgramIndex = controlProg.KeyIndex
   348  	return utxo
   349  }
   350  
   351  func mockTxData(utxos []*account.UTXO, testAccount *account.Account) (*txbuilder.Template, *types.TxData, error) {
   352  	tplBuilder := txbuilder.NewBuilder(time.Now())
   353  
   354  	for _, utxo := range utxos {
   355  		txInput, sigInst, err := account.UtxoToInputs(testAccount.Signer, utxo)
   356  		if err != nil {
   357  			return nil, nil, err
   358  		}
   359  		tplBuilder.AddInput(txInput, sigInst)
   360  
   361  		out := &types.TxOutput{}
   362  		if utxo.AssetID == *consensus.BTMAssetID {
   363  			out = types.NewTxOutput(utxo.AssetID, 100, utxo.ControlProgram)
   364  		} else {
   365  			out = types.NewTxOutput(utxo.AssetID, utxo.Amount, utxo.ControlProgram)
   366  		}
   367  		tplBuilder.AddOutput(out)
   368  	}
   369  
   370  	return tplBuilder.Build()
   371  }
   372  
   373  func mockWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) *Wallet {
   374  	wallet := &Wallet{
   375  		DB:              walletDB,
   376  		AccountMgr:      account,
   377  		AssetReg:        asset,
   378  		chain:           chain,
   379  		RecoveryMgr:     newRecoveryManager(walletDB, account),
   380  		eventDispatcher: dispatcher,
   381  		TxIndexFlag:     txIndexFlag,
   382  	}
   383  	wallet.txMsgSub, _ = wallet.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
   384  	return wallet
   385  }
   386  
   387  func mockSingleBlock(tx *types.Tx) *types.Block {
   388  	return &types.Block{
   389  		BlockHeader: types.BlockHeader{
   390  			Version: 1,
   391  			Height:  1,
   392  			Bits:    2305843009230471167,
   393  		},
   394  		Transactions: []*types.Tx{config.GenesisTx(), tx},
   395  	}
   396  }