github.com/datachainlab/cross@v0.2.2/x/packets/middlewares_test.go (about) 1 package packets 2 3 import ( 4 "fmt" 5 "strconv" 6 "testing" 7 8 "github.com/cosmos/cosmos-sdk/codec" 9 sdk "github.com/cosmos/cosmos-sdk/types" 10 capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types" 11 "github.com/cosmos/ibc-go/modules/core/exported" 12 "github.com/stretchr/testify/require" 13 ) 14 15 var codecm codec.Codec 16 17 func TestMiddleware(t *testing.T) { 18 require := require.New(t) 19 m := NewCounterPacketMiddleware() 20 ctx := makeMockContext() 21 22 mps := &memPacketSender{} 23 _, ps, err := m.HandleMsg(ctx, nil, mps) 24 require.NoError(err) 25 outp := newTestPacket(Header{}, &TestPacketDataPayload{}) 26 require.NoError(ps.SendPacket(ctx, nil, outp)) 27 count, ok := getCount(outp.pd.Header) 28 require.True(ok) 29 require.Equal(uint32(1), count) 30 31 mps = &memPacketSender{} 32 as := NewBasicACKSender() 33 _, ps, as, err = m.HandlePacket(ctx, outp, mps, as) 34 require.NoError(err) 35 require.NoError(ps.SendPacket(ctx, nil, outp)) 36 count, ok = getCount(outp.pd.Header) 37 require.True(ok) 38 require.Equal(uint32(2), count) 39 } 40 41 func makeMockContext() sdk.Context { 42 return sdk.Context{} 43 } 44 45 func (p TestPacketDataPayload) Type() string { 46 return "cross/packets/test" 47 } 48 49 type testPacket struct { 50 exported.PacketI 51 pd PacketData 52 payload PacketDataPayload 53 } 54 55 func newTestPacket(header Header, payload PacketDataPayload) *testPacket { 56 return &testPacket{ 57 pd: NewPacketData(&header, payload), 58 payload: payload, 59 } 60 } 61 62 var _ IncomingPacket = (*testPacket)(nil) 63 var _ OutgoingPacket = (*testPacket)(nil) 64 65 func (p testPacket) PacketData() PacketData { 66 return p.pd 67 } 68 69 func (p testPacket) Header() Header { 70 return p.pd.Header 71 } 72 73 func (p testPacket) Payload() PacketDataPayload { 74 return p.payload 75 } 76 77 func (p *testPacket) SetPacketData(header Header, payload PacketDataPayload) { 78 *p = *newTestPacket(header, payload) 79 } 80 81 type memPacketSender struct { 82 packet *OutgoingPacket 83 } 84 85 func (s *memPacketSender) SendPacket( 86 ctx sdk.Context, 87 channelCap *capabilitytypes.Capability, 88 packet OutgoingPacket, 89 ) error { 90 s.packet = &packet 91 return nil 92 } 93 94 type counterPacketMiddleware struct{} 95 96 var _ PacketMiddleware = (*counterPacketMiddleware)(nil) 97 98 // NewCounterPacketMiddleware returns counterPacketMiddleware 99 func NewCounterPacketMiddleware() PacketMiddleware { 100 return counterPacketMiddleware{} 101 } 102 103 // HandleMsg implements PacketMiddleware.HandleMsg 104 func (m counterPacketMiddleware) HandleMsg(ctx sdk.Context, msg sdk.Msg, ps PacketSender) (sdk.Context, PacketSender, error) { 105 return ctx, newPacketSender(1, ps), nil 106 } 107 108 // HandlePacket implements PacketMiddleware.HandlePacket 109 func (m counterPacketMiddleware) HandlePacket(ctx sdk.Context, ip IncomingPacket, ps PacketSender, as ACKSender) (sdk.Context, PacketSender, ACKSender, error) { 110 var next uint32 111 count, found := getCount(ip.Header()) 112 if found { 113 next = count + 1 114 } else { 115 next = 1 116 } 117 return ctx, newPacketSender(next, ps), newACKSender(next, as), nil 118 } 119 120 // HandlePacket implements PacketMiddleware.HandleACK 121 func (m counterPacketMiddleware) HandleACK(ctx sdk.Context, ip IncomingPacket, ack IncomingPacketAcknowledgement, ps PacketSender) (sdk.Context, PacketSender, error) { 122 return ctx, ps, nil 123 } 124 125 type packetSender struct { 126 count uint32 127 next PacketSender 128 } 129 130 var _ PacketSender = (*packetSender)(nil) 131 132 func newPacketSender(count uint32, next PacketSender) PacketSender { 133 return packetSender{count: count, next: next} 134 } 135 136 func (ps packetSender) SendPacket( 137 ctx sdk.Context, 138 channelCap *capabilitytypes.Capability, 139 packet OutgoingPacket, 140 ) error { 141 h := packet.Header() 142 setCount(&h, ps.count) 143 packet.SetPacketData(h, packet.Payload()) 144 return ps.next.SendPacket(ctx, channelCap, packet) 145 } 146 147 type ackSender struct { 148 count uint32 149 next ACKSender 150 } 151 152 var _ ACKSender = (*ackSender)(nil) 153 154 func newACKSender(count uint32, next ACKSender) ACKSender { 155 return &ackSender{count: count, next: next} 156 } 157 158 func (as ackSender) SendACK(ctx sdk.Context, ack OutgoingPacketAcknowledgement) error { 159 h := ack.Header() 160 setCount(&h, as.count) 161 ack.SetData(h, ack.Payload()) 162 return nil 163 } 164 165 const testHeaderKey = "count" 166 167 func setCount(h *Header, count uint32) { 168 h.Set(testHeaderKey, []byte(fmt.Sprint(count))) 169 } 170 171 func getCount(h Header) (uint32, bool) { 172 v, ok := h.Get(testHeaderKey) 173 if !ok { 174 return 0, false 175 } 176 i, err := strconv.Atoi(string(v)) 177 if err != nil { 178 panic(err) 179 } 180 return uint32(i), true 181 }