github.com/Finschia/finschia-sdk@v0.49.1/x/auth/tx/direct_test.go (about)

     1  package tx
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/Finschia/finschia-sdk/codec"
    10  	codectypes "github.com/Finschia/finschia-sdk/codec/types"
    11  	"github.com/Finschia/finschia-sdk/testutil/testdata"
    12  	sdk "github.com/Finschia/finschia-sdk/types"
    13  	txtypes "github.com/Finschia/finschia-sdk/types/tx"
    14  	signingtypes "github.com/Finschia/finschia-sdk/types/tx/signing"
    15  	"github.com/Finschia/finschia-sdk/x/auth/signing"
    16  )
    17  
    18  func TestDirectModeHandler(t *testing.T) {
    19  	privKey, pubkey, addr := testdata.KeyTestPubAddr()
    20  	interfaceRegistry := codectypes.NewInterfaceRegistry()
    21  	interfaceRegistry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
    22  	marshaler := codec.NewProtoCodec(interfaceRegistry)
    23  
    24  	txConfig := NewTxConfig(marshaler, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT})
    25  	txBuilder := txConfig.NewTxBuilder()
    26  
    27  	memo := "sometestmemo"
    28  	msgs := []sdk.Msg{testdata.NewTestMsg(addr)}
    29  	accSeq := uint64(2) // Arbitrary account sequence
    30  	any, err := codectypes.NewAnyWithValue(pubkey)
    31  	require.NoError(t, err)
    32  
    33  	var signerInfo []*txtypes.SignerInfo
    34  	signerInfo = append(signerInfo, &txtypes.SignerInfo{
    35  		PublicKey: any,
    36  		ModeInfo: &txtypes.ModeInfo{
    37  			Sum: &txtypes.ModeInfo_Single_{
    38  				Single: &txtypes.ModeInfo_Single{
    39  					Mode: signingtypes.SignMode_SIGN_MODE_DIRECT,
    40  				},
    41  			},
    42  		},
    43  		Sequence: accSeq,
    44  	})
    45  
    46  	sigData := &signingtypes.SingleSignatureData{
    47  		SignMode: signingtypes.SignMode_SIGN_MODE_DIRECT,
    48  	}
    49  	sig := signingtypes.SignatureV2{
    50  		PubKey:   pubkey,
    51  		Data:     sigData,
    52  		Sequence: accSeq,
    53  	}
    54  
    55  	fee := txtypes.Fee{Amount: sdk.NewCoins(sdk.NewInt64Coin("atom", 150)), GasLimit: 20000}
    56  
    57  	err = txBuilder.SetMsgs(msgs...)
    58  	require.NoError(t, err)
    59  	txBuilder.SetMemo(memo)
    60  	txBuilder.SetFeeAmount(fee.Amount)
    61  	txBuilder.SetGasLimit(fee.GasLimit)
    62  
    63  	err = txBuilder.SetSignatures(sig)
    64  	require.NoError(t, err)
    65  
    66  	t.Log("verify modes and default-mode")
    67  	modeHandler := txConfig.SignModeHandler()
    68  	require.Equal(t, modeHandler.DefaultMode(), signingtypes.SignMode_SIGN_MODE_DIRECT)
    69  	require.Len(t, modeHandler.Modes(), 1)
    70  
    71  	signingData := signing.SignerData{
    72  		ChainID:       "test-chain",
    73  		AccountNumber: 1,
    74  	}
    75  
    76  	signBytes, err := modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
    77  
    78  	require.NoError(t, err)
    79  	require.NotNil(t, signBytes)
    80  
    81  	authInfo := &txtypes.AuthInfo{
    82  		Fee:         &fee,
    83  		SignerInfos: signerInfo,
    84  	}
    85  
    86  	authInfoBytes := marshaler.MustMarshal(authInfo)
    87  
    88  	anys := make([]*codectypes.Any, len(msgs))
    89  
    90  	for i, msg := range msgs {
    91  		var err error
    92  		anys[i], err = codectypes.NewAnyWithValue(msg)
    93  		if err != nil {
    94  			panic(err)
    95  		}
    96  	}
    97  
    98  	txBody := &txtypes.TxBody{
    99  		Memo:     memo,
   100  		Messages: anys,
   101  	}
   102  	bodyBytes := marshaler.MustMarshal(txBody)
   103  
   104  	t.Log("verify GetSignBytes with generating sign bytes by marshaling SignDoc")
   105  	signDoc := txtypes.SignDoc{
   106  		AccountNumber: 1,
   107  		AuthInfoBytes: authInfoBytes,
   108  		BodyBytes:     bodyBytes,
   109  		ChainId:       "test-chain",
   110  	}
   111  
   112  	expectedSignBytes, err := signDoc.Marshal()
   113  	require.NoError(t, err)
   114  	require.Equal(t, expectedSignBytes, signBytes)
   115  
   116  	t.Log("verify that setting signature doesn't change sign bytes")
   117  	sigData.Signature, err = privKey.Sign(signBytes)
   118  	require.NoError(t, err)
   119  	err = txBuilder.SetSignatures(sig)
   120  	require.NoError(t, err)
   121  	signBytes, err = modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
   122  	require.NoError(t, err)
   123  	require.Equal(t, expectedSignBytes, signBytes)
   124  
   125  	t.Log("verify GetSignBytes with false txBody data")
   126  	signDoc.BodyBytes = []byte("dfafdasfds")
   127  	expectedSignBytes, err = signDoc.Marshal()
   128  	require.NoError(t, err)
   129  	require.NotEqual(t, expectedSignBytes, signBytes)
   130  }
   131  
   132  func TestDirectModeHandler_nonDIRECT_MODE(t *testing.T) {
   133  	invalidModes := []signingtypes.SignMode{
   134  		signingtypes.SignMode_SIGN_MODE_TEXTUAL,
   135  		signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON,
   136  		signingtypes.SignMode_SIGN_MODE_UNSPECIFIED,
   137  	}
   138  	for _, invalidMode := range invalidModes {
   139  		t.Run(invalidMode.String(), func(t *testing.T) {
   140  			var dh signModeDirectHandler
   141  			var signingData signing.SignerData
   142  			_, err := dh.GetSignBytes(invalidMode, signingData, nil)
   143  			require.Error(t, err)
   144  			wantErr := fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_DIRECT, invalidMode)
   145  			require.Equal(t, err, wantErr)
   146  		})
   147  	}
   148  }
   149  
   150  type nonProtoTx int
   151  
   152  func (npt *nonProtoTx) GetMsgs() []sdk.Msg   { return nil }
   153  func (npt *nonProtoTx) ValidateBasic() error { return nil }
   154  
   155  var _ sdk.Tx = (*nonProtoTx)(nil)
   156  
   157  func TestDirectModeHandler_nonProtoTx(t *testing.T) {
   158  	var dh signModeDirectHandler
   159  	var signingData signing.SignerData
   160  	tx := new(nonProtoTx)
   161  	_, err := dh.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, tx)
   162  	require.Error(t, err)
   163  	wantErr := fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
   164  	require.Equal(t, err, wantErr)
   165  }