github.com/diamondburned/arikawa/v2@v2.1.0/bot/extras/middlewares/middlewares_test.go (about) 1 package middlewares 2 3 import ( 4 "errors" 5 "testing" 6 7 "github.com/diamondburned/arikawa/v2/bot" 8 "github.com/diamondburned/arikawa/v2/discord" 9 "github.com/diamondburned/arikawa/v2/gateway" 10 "github.com/diamondburned/arikawa/v2/session" 11 "github.com/diamondburned/arikawa/v2/state" 12 "github.com/diamondburned/arikawa/v2/state/store" 13 ) 14 15 func TestAdminOnly(t *testing.T) { 16 var ctx = &bot.Context{ 17 State: &state.State{ 18 Session: &session.Session{ 19 Gateway: &gateway.Gateway{ 20 Identifier: &gateway.Identifier{ 21 IdentifyData: gateway.IdentifyData{ 22 Intents: gateway.IntentGuilds | gateway.IntentGuildMembers, 23 }, 24 }, 25 }, 26 }, 27 Cabinet: mockCabinet(), 28 }, 29 } 30 var middleware = AdminOnly(ctx) 31 32 t.Run("allow message", func(t *testing.T) { 33 var msg = &gateway.MessageCreateEvent{ 34 Message: discord.Message{ 35 ID: 1, 36 ChannelID: 69420, 37 Author: discord.User{ID: 69420}, 38 }, 39 } 40 expectNil(t, middleware(msg)) 41 }) 42 43 t.Run("deny message", func(t *testing.T) { 44 var msg = &gateway.MessageCreateEvent{ 45 Message: discord.Message{ 46 ID: 2, 47 ChannelID: 1337, 48 Author: discord.User{ID: 1337}, 49 }, 50 } 51 expectBreak(t, middleware(msg)) 52 var pin = &gateway.ChannelPinsUpdateEvent{ 53 ChannelID: 120, 54 } 55 expectBreak(t, middleware(pin)) 56 var tpg = &gateway.TypingStartEvent{} 57 expectBreak(t, middleware(tpg)) 58 }) 59 } 60 61 func TestGuildOnly(t *testing.T) { 62 var ctx = &bot.Context{ 63 State: &state.State{ 64 Session: &session.Session{ 65 Gateway: &gateway.Gateway{ 66 Identifier: &gateway.Identifier{ 67 IdentifyData: gateway.IdentifyData{ 68 Intents: gateway.IntentGuilds, 69 }, 70 }, 71 }, 72 }, 73 Cabinet: mockCabinet(), 74 }, 75 } 76 var middleware = GuildOnly(ctx) 77 78 t.Run("allow message with GuildID", func(t *testing.T) { 79 var msg = &gateway.MessageCreateEvent{ 80 Message: discord.Message{ 81 ID: 3, 82 GuildID: 1337, 83 }, 84 } 85 expectNil(t, middleware(msg)) 86 }) 87 88 t.Run("allow message with ChannelID", func(t *testing.T) { 89 var msg = &gateway.MessageCreateEvent{ 90 Message: discord.Message{ 91 ID: 3, 92 ChannelID: 69420, 93 }, 94 } 95 expectNil(t, middleware(msg)) 96 }) 97 98 t.Run("deny message", func(t *testing.T) { 99 var msg = &gateway.MessageCreateEvent{ 100 Message: discord.Message{ 101 ID: 1, 102 ChannelID: 12, 103 }, 104 } 105 expectBreak(t, middleware(msg)) 106 107 var msg2 = &gateway.MessageCreateEvent{} 108 expectBreak(t, middleware(msg2)) 109 }) 110 } 111 112 func expectNil(t *testing.T, err error) { 113 t.Helper() 114 if err != nil { 115 t.Fatal("Unexpected error:", err) 116 } 117 } 118 119 func expectBreak(t *testing.T, err error) { 120 t.Helper() 121 if errors.Is(err, bot.Break) { 122 return 123 } 124 if err != nil { 125 t.Fatal("Unexpected error:", err) 126 } 127 t.Fatal("Expected error, got nothing.") 128 } 129 130 // BenchmarkGuildOnly runs a message through the GuildOnly middleware to 131 // calculate the overhead of reflection. 132 func BenchmarkGuildOnly(b *testing.B) { 133 var ctx = &bot.Context{ 134 State: &state.State{ 135 Cabinet: mockCabinet(), 136 }, 137 } 138 var middleware = GuildOnly(ctx) 139 var msg = &gateway.MessageCreateEvent{ 140 Message: discord.Message{ 141 ID: 3, 142 GuildID: 1337, 143 }, 144 } 145 146 b.ResetTimer() 147 148 for i := 0; i < b.N; i++ { 149 if err := middleware(msg); err != nil { 150 b.Fatal("Unexpected error:", err) 151 } 152 } 153 } 154 155 // BenchmarkAdminOnly runs a message through the GuildOnly middleware to 156 // calculate the overhead of reflection. 157 func BenchmarkAdminOnly(b *testing.B) { 158 var ctx = &bot.Context{ 159 State: &state.State{ 160 Cabinet: mockCabinet(), 161 }, 162 } 163 var middleware = AdminOnly(ctx) 164 var msg = &gateway.MessageCreateEvent{ 165 Message: discord.Message{ 166 ID: 1, 167 ChannelID: 1337, 168 Author: discord.User{ID: 69420}, 169 }, 170 } 171 172 b.ResetTimer() 173 174 for i := 0; i < b.N; i++ { 175 if err := middleware(msg); err != nil { 176 b.Fatal("Unexpected error:", err) 177 } 178 } 179 } 180 181 type mockStore struct { 182 store.NoopStore 183 } 184 185 func mockCabinet() store.Cabinet { 186 c := store.NoopCabinet 187 c.GuildStore = &mockStore{} 188 c.MemberStore = &mockStore{} 189 c.ChannelStore = &mockStore{} 190 191 return c 192 } 193 194 func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) { 195 return &discord.Guild{ 196 ID: id, 197 Roles: []discord.Role{{ 198 ID: 69420, 199 Permissions: discord.PermissionAdministrator, 200 }}, 201 }, nil 202 } 203 204 func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) { 205 return &discord.Member{ 206 User: discord.User{ID: userID}, 207 RoleIDs: []discord.RoleID{discord.RoleID(userID)}, 208 }, nil 209 } 210 211 // Channel returns a channel with a guildID for #69420. 212 func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) { 213 if id == 69420 { 214 return &discord.Channel{ 215 ID: id, 216 GuildID: 1337, 217 }, nil 218 } 219 220 return &discord.Channel{ 221 ID: id, 222 }, nil 223 }