github.com/datachainlab/cross@v0.2.2/x/packets/middlewares_test.go (about)

     1  package packets
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"testing"
     7  
     8  	"github.com/cosmos/cosmos-sdk/codec"
     9  	sdk "github.com/cosmos/cosmos-sdk/types"
    10  	capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"
    11  	"github.com/cosmos/ibc-go/modules/core/exported"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  var codecm codec.Codec
    16  
    17  func TestMiddleware(t *testing.T) {
    18  	require := require.New(t)
    19  	m := NewCounterPacketMiddleware()
    20  	ctx := makeMockContext()
    21  
    22  	mps := &memPacketSender{}
    23  	_, ps, err := m.HandleMsg(ctx, nil, mps)
    24  	require.NoError(err)
    25  	outp := newTestPacket(Header{}, &TestPacketDataPayload{})
    26  	require.NoError(ps.SendPacket(ctx, nil, outp))
    27  	count, ok := getCount(outp.pd.Header)
    28  	require.True(ok)
    29  	require.Equal(uint32(1), count)
    30  
    31  	mps = &memPacketSender{}
    32  	as := NewBasicACKSender()
    33  	_, ps, as, err = m.HandlePacket(ctx, outp, mps, as)
    34  	require.NoError(err)
    35  	require.NoError(ps.SendPacket(ctx, nil, outp))
    36  	count, ok = getCount(outp.pd.Header)
    37  	require.True(ok)
    38  	require.Equal(uint32(2), count)
    39  }
    40  
    41  func makeMockContext() sdk.Context {
    42  	return sdk.Context{}
    43  }
    44  
    45  func (p TestPacketDataPayload) Type() string {
    46  	return "cross/packets/test"
    47  }
    48  
    49  type testPacket struct {
    50  	exported.PacketI
    51  	pd      PacketData
    52  	payload PacketDataPayload
    53  }
    54  
    55  func newTestPacket(header Header, payload PacketDataPayload) *testPacket {
    56  	return &testPacket{
    57  		pd:      NewPacketData(&header, payload),
    58  		payload: payload,
    59  	}
    60  }
    61  
    62  var _ IncomingPacket = (*testPacket)(nil)
    63  var _ OutgoingPacket = (*testPacket)(nil)
    64  
    65  func (p testPacket) PacketData() PacketData {
    66  	return p.pd
    67  }
    68  
    69  func (p testPacket) Header() Header {
    70  	return p.pd.Header
    71  }
    72  
    73  func (p testPacket) Payload() PacketDataPayload {
    74  	return p.payload
    75  }
    76  
    77  func (p *testPacket) SetPacketData(header Header, payload PacketDataPayload) {
    78  	*p = *newTestPacket(header, payload)
    79  }
    80  
    81  type memPacketSender struct {
    82  	packet *OutgoingPacket
    83  }
    84  
    85  func (s *memPacketSender) SendPacket(
    86  	ctx sdk.Context,
    87  	channelCap *capabilitytypes.Capability,
    88  	packet OutgoingPacket,
    89  ) error {
    90  	s.packet = &packet
    91  	return nil
    92  }
    93  
    94  type counterPacketMiddleware struct{}
    95  
    96  var _ PacketMiddleware = (*counterPacketMiddleware)(nil)
    97  
    98  // NewCounterPacketMiddleware returns counterPacketMiddleware
    99  func NewCounterPacketMiddleware() PacketMiddleware {
   100  	return counterPacketMiddleware{}
   101  }
   102  
   103  // HandleMsg implements PacketMiddleware.HandleMsg
   104  func (m counterPacketMiddleware) HandleMsg(ctx sdk.Context, msg sdk.Msg, ps PacketSender) (sdk.Context, PacketSender, error) {
   105  	return ctx, newPacketSender(1, ps), nil
   106  }
   107  
   108  // HandlePacket implements PacketMiddleware.HandlePacket
   109  func (m counterPacketMiddleware) HandlePacket(ctx sdk.Context, ip IncomingPacket, ps PacketSender, as ACKSender) (sdk.Context, PacketSender, ACKSender, error) {
   110  	var next uint32
   111  	count, found := getCount(ip.Header())
   112  	if found {
   113  		next = count + 1
   114  	} else {
   115  		next = 1
   116  	}
   117  	return ctx, newPacketSender(next, ps), newACKSender(next, as), nil
   118  }
   119  
   120  // HandlePacket implements PacketMiddleware.HandleACK
   121  func (m counterPacketMiddleware) HandleACK(ctx sdk.Context, ip IncomingPacket, ack IncomingPacketAcknowledgement, ps PacketSender) (sdk.Context, PacketSender, error) {
   122  	return ctx, ps, nil
   123  }
   124  
   125  type packetSender struct {
   126  	count uint32
   127  	next  PacketSender
   128  }
   129  
   130  var _ PacketSender = (*packetSender)(nil)
   131  
   132  func newPacketSender(count uint32, next PacketSender) PacketSender {
   133  	return packetSender{count: count, next: next}
   134  }
   135  
   136  func (ps packetSender) SendPacket(
   137  	ctx sdk.Context,
   138  	channelCap *capabilitytypes.Capability,
   139  	packet OutgoingPacket,
   140  ) error {
   141  	h := packet.Header()
   142  	setCount(&h, ps.count)
   143  	packet.SetPacketData(h, packet.Payload())
   144  	return ps.next.SendPacket(ctx, channelCap, packet)
   145  }
   146  
   147  type ackSender struct {
   148  	count uint32
   149  	next  ACKSender
   150  }
   151  
   152  var _ ACKSender = (*ackSender)(nil)
   153  
   154  func newACKSender(count uint32, next ACKSender) ACKSender {
   155  	return &ackSender{count: count, next: next}
   156  }
   157  
   158  func (as ackSender) SendACK(ctx sdk.Context, ack OutgoingPacketAcknowledgement) error {
   159  	h := ack.Header()
   160  	setCount(&h, as.count)
   161  	ack.SetData(h, ack.Payload())
   162  	return nil
   163  }
   164  
   165  const testHeaderKey = "count"
   166  
   167  func setCount(h *Header, count uint32) {
   168  	h.Set(testHeaderKey, []byte(fmt.Sprint(count)))
   169  }
   170  
   171  func getCount(h Header) (uint32, bool) {
   172  	v, ok := h.Get(testHeaderKey)
   173  	if !ok {
   174  		return 0, false
   175  	}
   176  	i, err := strconv.Atoi(string(v))
   177  	if err != nil {
   178  		panic(err)
   179  	}
   180  	return uint32(i), true
   181  }