github.com/0xsequence/ethkit@v1.25.0/ethwallet/ethwallet.go (about)

     1  package ethwallet
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"fmt"
     8  	"math/big"
     9  
    10  	"github.com/0xsequence/ethkit/ethrpc"
    11  	"github.com/0xsequence/ethkit/ethtxn"
    12  	"github.com/0xsequence/ethkit/go-ethereum/accounts"
    13  	"github.com/0xsequence/ethkit/go-ethereum/accounts/abi/bind"
    14  	"github.com/0xsequence/ethkit/go-ethereum/common"
    15  	"github.com/0xsequence/ethkit/go-ethereum/common/hexutil"
    16  	"github.com/0xsequence/ethkit/go-ethereum/core/types"
    17  	"github.com/0xsequence/ethkit/go-ethereum/crypto"
    18  )
    19  
    20  var DefaultWalletOptions = WalletOptions{
    21  	DerivationPath:             "m/44'/60'/0'/0/0",
    22  	RandomWalletEntropyBitSize: EntropyBitSize12WordMnemonic,
    23  }
    24  
// Wallet holds an HD node (the signing key material) and, optionally, an RPC
// provider used for chain queries and for sending transactions.
type Wallet struct {
	// hdnode holds the key material; the Self* derivation methods re-derive it
	// in place.
	hdnode         *HDNode
	// provider is the RPC provider attached via SetProvider; nil until set.
	provider       *ethrpc.Provider
	// walletProvider is created lazily by SetProvider and returned by Provider;
	// nil until SetProvider has been called.
	walletProvider *WalletProvider
}
    30  
// WalletOptions configures NewWalletFromRandomEntropy.
type WalletOptions struct {
	// DerivationPath is an HD derivation path string (e.g. "m/44'/60'/0'/0/0")
	// that is parsed with ParseDerivationPath.
	DerivationPath             string
	// RandomWalletEntropyBitSize is the entropy size, in bits, passed to
	// NewHDNodeFromRandomEntropy.
	RandomWalletEntropyBitSize int
}
    35  
    36  func NewWalletFromPrivateKey(key string) (*Wallet, error) {
    37  	hdnode, err := NewHDNodeFromPrivateKey(key)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	return &Wallet{hdnode: hdnode}, nil
    42  }
    43  
    44  func NewWalletFromHDNode(hdnode *HDNode, optPath ...accounts.DerivationPath) (*Wallet, error) {
    45  	var err error
    46  	derivationPath := DefaultBaseDerivationPath
    47  	if len(optPath) > 0 {
    48  		derivationPath = optPath[0]
    49  	}
    50  
    51  	err = hdnode.DerivePath(derivationPath)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return &Wallet{hdnode: hdnode}, nil
    57  }
    58  
    59  func NewWalletFromRandomEntropy(options ...WalletOptions) (*Wallet, error) {
    60  	opts := DefaultWalletOptions
    61  	if len(options) > 0 {
    62  		opts = options[0]
    63  	}
    64  
    65  	derivationPath, err := ParseDerivationPath(opts.DerivationPath)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	hdnode, err := NewHDNodeFromRandomEntropy(opts.RandomWalletEntropyBitSize, &derivationPath)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	wallet, err := NewWalletFromHDNode(hdnode, derivationPath)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return wallet, nil
    80  }
    81  
    82  func NewWalletFromMnemonic(mnemonic string, optPath ...string) (*Wallet, error) {
    83  	var err error
    84  	derivationPath := DefaultBaseDerivationPath
    85  	if len(optPath) > 0 {
    86  		derivationPath, err = ParseDerivationPath(optPath[0])
    87  		if err != nil {
    88  			return nil, err
    89  		}
    90  	}
    91  
    92  	hdnode, err := NewHDNodeFromMnemonic(mnemonic, &derivationPath)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	wallet, err := NewWalletFromHDNode(hdnode, derivationPath)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return wallet, nil
   102  }
   103  
   104  func (w *Wallet) Clone() (*Wallet, error) {
   105  	hdnode, err := w.hdnode.Clone()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	return &Wallet{
   110  		hdnode: hdnode, provider: w.provider,
   111  	}, nil
   112  }
   113  
   114  func (w *Wallet) Transactor(ctx context.Context) (*bind.TransactOpts, error) {
   115  	var chainID *big.Int
   116  	if w.provider != nil {
   117  		var err error
   118  		chainID, err = w.provider.ChainID(ctx)
   119  		if err != nil {
   120  			return nil, err
   121  		}
   122  	}
   123  	return w.TransactorForChainID(chainID)
   124  }
   125  
   126  func (w *Wallet) TransactorForChainID(chainID *big.Int) (*bind.TransactOpts, error) {
   127  	if chainID == nil {
   128  		// This is deprecated and will log a warning since it uses the original Homestead signer
   129  		return bind.NewKeyedTransactor(w.hdnode.PrivateKey()), nil
   130  	} else {
   131  		return bind.NewKeyedTransactorWithChainID(w.hdnode.PrivateKey(), chainID)
   132  	}
   133  }
   134  
   135  func (w *Wallet) GetProvider() *ethrpc.Provider {
   136  	return w.provider
   137  }
   138  
   139  func (w *Wallet) SetProvider(provider *ethrpc.Provider) {
   140  	w.provider = provider
   141  
   142  	if w.walletProvider == nil {
   143  		w.walletProvider = &WalletProvider{wallet: w}
   144  	}
   145  	w.walletProvider.provider = provider
   146  }
   147  
   148  func (w *Wallet) Provider() *WalletProvider {
   149  	return w.walletProvider
   150  }
   151  
   152  func (w *Wallet) SelfDerivePath(path accounts.DerivationPath) (common.Address, error) {
   153  	err := w.hdnode.DerivePath(path)
   154  	if err != nil {
   155  		return common.Address{}, err
   156  	}
   157  	return w.hdnode.Address(), nil
   158  }
   159  
   160  func (w *Wallet) DerivePath(path accounts.DerivationPath) (*Wallet, common.Address, error) {
   161  	wallet, err := w.Clone()
   162  	if err != nil {
   163  		return nil, common.Address{}, err
   164  	}
   165  	address, err := wallet.SelfDerivePath(path)
   166  	return wallet, address, err
   167  }
   168  
   169  func (w *Wallet) SelfDerivePathFromString(path string) (common.Address, error) {
   170  	err := w.hdnode.DerivePathFromString(path)
   171  	if err != nil {
   172  		return common.Address{}, err
   173  	}
   174  	return w.hdnode.Address(), nil
   175  }
   176  
   177  func (w *Wallet) DerivePathFromString(path string) (*Wallet, common.Address, error) {
   178  	wallet, err := w.Clone()
   179  	if err != nil {
   180  		return nil, common.Address{}, err
   181  	}
   182  	address, err := wallet.SelfDerivePathFromString(path)
   183  	return wallet, address, err
   184  }
   185  
   186  func (w *Wallet) SelfDeriveAccountIndex(accountIndex uint32) (common.Address, error) {
   187  	err := w.hdnode.DeriveAccountIndex(accountIndex)
   188  	if err != nil {
   189  		return common.Address{}, err
   190  	}
   191  	return w.hdnode.Address(), nil
   192  }
   193  
   194  func (w *Wallet) DeriveAccountIndex(accountIndex uint32) (*Wallet, common.Address, error) {
   195  	wallet, err := w.Clone()
   196  	if err != nil {
   197  		return nil, common.Address{}, err
   198  	}
   199  	address, err := wallet.SelfDeriveAccountIndex(accountIndex)
   200  	return wallet, address, err
   201  }
   202  
   203  func (w *Wallet) Address() common.Address {
   204  	return w.hdnode.Address()
   205  }
   206  
   207  func (w *Wallet) HDNode() *HDNode {
   208  	return w.hdnode
   209  }
   210  
   211  func (w *Wallet) PrivateKey() *ecdsa.PrivateKey {
   212  	return w.hdnode.PrivateKey()
   213  }
   214  
   215  func (w *Wallet) PublicKey() *ecdsa.PublicKey {
   216  	return w.hdnode.PublicKey()
   217  }
   218  
   219  func (w *Wallet) PrivateKeyHex() string {
   220  	privateKeyBytes := crypto.FromECDSA(w.hdnode.PrivateKey())
   221  	return hexutil.Encode(privateKeyBytes)
   222  }
   223  
   224  func (w *Wallet) PublicKeyHex() string {
   225  	publicKeyBytes := crypto.FromECDSAPub(w.hdnode.PublicKey())
   226  	return hexutil.Encode(publicKeyBytes)
   227  }
   228  
   229  func (w *Wallet) GetBalance(ctx context.Context) (*big.Int, error) {
   230  	return w.GetProvider().BalanceAt(ctx, w.Address(), nil)
   231  }
   232  
   233  func (w *Wallet) GetNonce(ctx context.Context) (uint64, error) {
   234  	return w.GetProvider().NonceAt(ctx, w.Address(), nil)
   235  }
   236  
   237  func (w *Wallet) SignTx(tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) {
   238  	signer := types.LatestSignerForChainID(chainID)
   239  	signedTx, err := types.SignTx(tx, signer, w.hdnode.PrivateKey())
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	msg, err := signedTx.AsMessage(signer, nil)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	sender := msg.From()
   250  	if sender != w.hdnode.Address() {
   251  		return nil, fmt.Errorf("signer mismatch: expected %s, got %s", w.hdnode.Address().Hex(), sender.Hex())
   252  	}
   253  
   254  	return signedTx, nil
   255  }
   256  
   257  func (w *Wallet) SignMessage(message []byte) ([]byte, error) {
   258  	message191 := []byte("\x19Ethereum Signed Message:\n")
   259  	if !bytes.HasPrefix(message, message191) {
   260  		mlen := fmt.Sprintf("%d", len(message))
   261  		message191 = append(message191, []byte(mlen)...)
   262  		message191 = append(message191, message...)
   263  	} else {
   264  		message191 = message
   265  	}
   266  
   267  	h := crypto.Keccak256(message191)
   268  
   269  	sig, err := crypto.Sign(h, w.hdnode.PrivateKey())
   270  	if err != nil {
   271  		return []byte{}, err
   272  	}
   273  	sig[64] += 27
   274  
   275  	return sig, nil
   276  }
   277  
   278  func (w *Wallet) SignData(data []byte) ([]byte, error) {
   279  	h := crypto.Keccak256(data)
   280  
   281  	sig, err := crypto.Sign(h, w.hdnode.PrivateKey())
   282  	if err != nil {
   283  		return []byte{}, err
   284  	}
   285  	sig[64] += 27
   286  
   287  	return sig, nil
   288  }
   289  
   290  func (w *Wallet) IsValidSignature(msg, sig []byte) (bool, error) {
   291  	recoveredAddress, err := RecoverAddress(msg, sig)
   292  	if err != nil {
   293  		return false, err
   294  	}
   295  	if recoveredAddress == w.Address() {
   296  		return true, nil
   297  	}
   298  	return false, fmt.Errorf("signature does not match recovered address for this message")
   299  }
   300  
   301  func (w *Wallet) IsValidSignatureOfDigest(digest, sig []byte) (bool, error) {
   302  	recoveredAddress, err := RecoverAddressFromDigest(digest, sig)
   303  	if err != nil {
   304  		return false, err
   305  	}
   306  	if recoveredAddress == w.Address() {
   307  		return true, nil
   308  	}
   309  	return false, fmt.Errorf("signature does not match recovered address for this message digest")
   310  }
   311  
   312  func (w *Wallet) NewTransaction(ctx context.Context, txnRequest *ethtxn.TransactionRequest) (*types.Transaction, error) {
   313  	if txnRequest == nil {
   314  		return nil, fmt.Errorf("ethwallet: txnRequest is required")
   315  	}
   316  
   317  	provider := w.GetProvider()
   318  	if provider == nil {
   319  		return nil, fmt.Errorf("ethwallet: provider is not set")
   320  	}
   321  
   322  	chainID, err := provider.ChainID(ctx)
   323  	if err != nil {
   324  		return nil, fmt.Errorf("ethwallet: %w", err)
   325  	}
   326  
   327  	txnRequest.From = w.Address()
   328  
   329  	rawTx, err := ethtxn.NewTransaction(ctx, provider, txnRequest)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	signedTx, err := w.SignTx(rawTx, chainID)
   335  	if err != nil {
   336  		return nil, fmt.Errorf("ethwallet: %w", err)
   337  	}
   338  
   339  	return signedTx, nil
   340  }
   341  
   342  func (w *Wallet) SendTransaction(ctx context.Context, signedTx *types.Transaction) (*types.Transaction, ethtxn.WaitReceipt, error) {
   343  	provider := w.GetProvider()
   344  	if provider == nil {
   345  		return nil, nil, fmt.Errorf("ethwallet (SendTransaction): provider is not set")
   346  	}
   347  	return ethtxn.SendTransaction(ctx, provider, signedTx)
   348  }