github.com/lino-network/lino@v0.6.11/x/auth/ante.go (about)

     1  package auth
     2  
     3  import (
     4  	"fmt"
     5  
     6  	sdk "github.com/cosmos/cosmos-sdk/types"
     7  	"github.com/cosmos/cosmos-sdk/x/auth"
     8  
     9  	"github.com/lino-network/lino/types"
    10  	acc "github.com/lino-network/lino/x/account"
    11  	"github.com/lino-network/lino/x/bandwidth"
    12  	"github.com/tendermint/tendermint/crypto"
    13  )
    14  
    15  const (
    16  	maxMemoCharacters = 100
    17  )
    18  
    19  // getAccOrAddrSignersFromMsg allows AddrMsg to override signers
    20  func getAccOrAddrSignersFromMsg(msg sdk.Msg) []types.AccOrAddr {
    21  	switch v := msg.(type) {
    22  	case types.AddrMsg:
    23  		return v.GetAccOrAddrSigners()
    24  	default:
    25  		rst := make([]types.AccOrAddr, 0)
    26  		for _, signer := range msg.GetSigners() {
    27  			rst = append(rst, types.NewAccOrAddrFromAcc(types.AccountKey(signer)))
    28  		}
    29  		return rst
    30  	}
    31  }
    32  
    33  type msgAndSigs struct {
    34  	msg     sdk.Msg
    35  	signers []types.AccOrAddr
    36  	sigs    []auth.StdSignature
    37  }
    38  
    39  func validateAndExtract(stdTx auth.StdTx) ([]msgAndSigs, sdk.Error) {
    40  	// validate memo
    41  	if len(stdTx.GetMemo()) > maxMemoCharacters {
    42  		return nil, sdk.ErrMemoTooLarge(fmt.Sprintf(
    43  			"maximum number of characters is %d but received %d characters",
    44  			maxMemoCharacters, len(stdTx.GetMemo())))
    45  	}
    46  
    47  	// validate sigs
    48  	// 1. that there are signatures.
    49  	// 2. no more than limit.
    50  	var sigs = stdTx.GetSignatures()
    51  	if len(sigs) == 0 {
    52  		return nil, ErrNoSignatures()
    53  	}
    54  	if len(sigs) > types.TxSigLimit {
    55  		return nil, sdk.ErrTooManySignatures(fmt.Sprintf(
    56  			"signatures: %d, limit: %d",
    57  			len(sigs), types.TxSigLimit))
    58  	}
    59  
    60  	// extract signers
    61  	msgs := stdTx.GetMsgs()
    62  	rst := make([]msgAndSigs, len(msgs))
    63  	for i, msg := range msgs {
    64  		signers := getAccOrAddrSignersFromMsg(msg)
    65  		nSigRequired := len(signers)
    66  		if len(sigs) < nSigRequired {
    67  			return nil, ErrWrongNumberOfSigners()
    68  		}
    69  		rst[i] = msgAndSigs{
    70  			msg:     msg,
    71  			signers: signers,
    72  			sigs:    sigs[:nSigRequired],
    73  		}
    74  		sigs = sigs[nSigRequired:]
    75  	}
    76  	if len(sigs) != 0 {
    77  		return nil, ErrWrongNumberOfSigners()
    78  	}
    79  
    80  	return rst, nil
    81  }
    82  
    83  type signBytesFactory = func(seq uint64) []byte
    84  
    85  // NewAnteHandler - return an AnteHandler
    86  func NewAnteHandler(am acc.AccountKeeper, bm bandwidth.BandwidthKeeper) sdk.AnteHandler {
    87  	return func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, sdk.Result, bool) {
    88  		stdTx, ok := tx.(auth.StdTx)
    89  		if !ok {
    90  			return ctx, ErrIncorrectStdTxType().Result(), true
    91  		}
    92  		msgAndSigs, err := validateAndExtract(stdTx)
    93  		if err != nil {
    94  			return ctx, err.Result(), true
    95  		}
    96  
    97  		// signbyte creator returns the bytes that should be signed.
    98  		signBytesCreator := func(seq uint64) []byte {
    99  			return auth.StdSignBytes(
   100  				ctx.ChainID(), uint64(0), seq, stdTx.Fee, stdTx.GetMsgs(), stdTx.GetMemo())
   101  		}
   102  
   103  		// validate each msg.
   104  		for _, msgSigs := range msgAndSigs {
   105  			if err := validateMsg(ctx, am, bm, msgSigs, signBytesCreator, stdTx.Fee); err != nil {
   106  				return ctx, err.Result(), true
   107  			}
   108  		}
   109  
   110  		return ctx, sdk.Result{}, false
   111  	}
   112  }
   113  
   114  func validateMsg(ctx sdk.Context, am acc.AccountKeeper, bm bandwidth.BandwidthKeeper, msgSigs msgAndSigs, signBytesCreator signBytesFactory, fee auth.StdFee) sdk.Error {
   115  	// validate each signature.
   116  	paid := false
   117  	for i, signer := range msgSigs.signers {
   118  		sig := msgSigs.sigs[i]
   119  		var signerAddr sdk.AccAddress
   120  		if signer.IsAddr {
   121  			err := checkAddrSigner(ctx, am, signer.Addr, sig.PubKey, paid)
   122  			if err != nil {
   123  				return err
   124  			}
   125  			signerAddr = signer.Addr
   126  		} else {
   127  			var err sdk.Error
   128  			signerAddr, err = checkAccountSigner(ctx, am, signer.AccountKey, sig.PubKey)
   129  			if err != nil {
   130  				return err
   131  			}
   132  		}
   133  
   134  		// 1. verify seq.
   135  		seq, err := am.GetSequence(ctx, signerAddr)
   136  		if err != nil {
   137  			return err
   138  		}
   139  		// 2. verify signature
   140  		signBytes := signBytesCreator(seq)
   141  		if !sig.PubKey.VerifyBytes(signBytes, sig.Signature) {
   142  			return ErrUnverifiedBytes(fmt.Sprintf(
   143  				"signature verification failed, chain-id:%v, seq:%d",
   144  				ctx.ChainID(), seq))
   145  		}
   146  		// 3. increase seq
   147  		if err := am.IncreaseSequenceByOne(ctx, signerAddr); err != nil {
   148  			return err
   149  		}
   150  		// 4. only pay fee in the end.
   151  		// only the first signer pays the fee
   152  		if !paid {
   153  			if err := bm.CheckBandwidth(ctx, signerAddr, fee); err != nil {
   154  				return err
   155  			}
   156  		}
   157  		paid = true
   158  	}
   159  	return nil
   160  }
   161  
   162  func checkAddrSigner(ctx sdk.Context, am acc.AccountKeeper, addr sdk.AccAddress, signKey crypto.PubKey, isPaid bool) sdk.Error {
   163  	// if signer is address
   164  	if err := am.CheckSigningPubKeyOwnerByAddress(ctx, addr, signKey, isPaid); err != nil {
   165  		return err
   166  	}
   167  	return nil
   168  }
   169  
   170  // this function return the actual signer of the msg.
   171  func checkAccountSigner(ctx sdk.Context, am acc.AccountKeeper, msgSigner types.AccountKey, signKey crypto.PubKey) (signerAddr sdk.AccAddress, err sdk.Error) {
   172  	// check public key is valid to sign this msg
   173  	// return signer is the actual signer of the msg
   174  	signer, err := am.CheckSigningPubKeyOwner(ctx, msgSigner, signKey)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	// get address of actual signer.
   179  	return am.GetAddress(ctx, signer)
   180  }