github.com/diamondburned/arikawa@v1.3.14/bot/extras/middlewares/middlewares_test.go (about) 1 package middlewares 2 3 import ( 4 "errors" 5 "testing" 6 7 "github.com/diamondburned/arikawa/bot" 8 "github.com/diamondburned/arikawa/discord" 9 "github.com/diamondburned/arikawa/gateway" 10 "github.com/diamondburned/arikawa/state" 11 ) 12 13 func TestAdminOnly(t *testing.T) { 14 var ctx = &bot.Context{ 15 State: &state.State{ 16 Store: &mockStore{}, 17 }, 18 } 19 var middleware = AdminOnly(ctx) 20 21 t.Run("allow message", func(t *testing.T) { 22 var msg = &gateway.MessageCreateEvent{ 23 Message: discord.Message{ 24 ID: 1, 25 ChannelID: 1337, 26 Author: discord.User{ID: 69420}, 27 }, 28 } 29 expectNil(t, middleware(msg)) 30 }) 31 32 t.Run("deny message", func(t *testing.T) { 33 var msg = &gateway.MessageCreateEvent{ 34 Message: discord.Message{ 35 ID: 2, 36 ChannelID: 1337, 37 Author: discord.User{ID: 1337}, 38 }, 39 } 40 expectBreak(t, middleware(msg)) 41 var pin = &gateway.ChannelPinsUpdateEvent{ 42 ChannelID: 120, 43 } 44 expectBreak(t, middleware(pin)) 45 var tpg = &gateway.TypingStartEvent{} 46 expectBreak(t, middleware(tpg)) 47 }) 48 } 49 50 func TestGuildOnly(t *testing.T) { 51 var ctx = &bot.Context{ 52 State: &state.State{ 53 Store: &mockStore{}, 54 }, 55 } 56 var middleware = GuildOnly(ctx) 57 58 t.Run("allow message with GuildID", func(t *testing.T) { 59 var msg = &gateway.MessageCreateEvent{ 60 Message: discord.Message{ 61 ID: 3, 62 GuildID: 1337, 63 }, 64 } 65 expectNil(t, middleware(msg)) 66 }) 67 68 t.Run("allow message with ChannelID", func(t *testing.T) { 69 var msg = &gateway.MessageCreateEvent{ 70 Message: discord.Message{ 71 ID: 3, 72 ChannelID: 69420, 73 }, 74 } 75 expectNil(t, middleware(msg)) 76 }) 77 78 t.Run("deny message", func(t *testing.T) { 79 var msg = &gateway.MessageCreateEvent{ 80 Message: discord.Message{ 81 ID: 1, 82 ChannelID: 12, 83 }, 84 } 85 expectBreak(t, middleware(msg)) 86 87 var msg2 = &gateway.MessageCreateEvent{} 88 expectBreak(t, middleware(msg2)) 89 }) 90 } 91 92 func expectNil(t *testing.T, err error) { 93 t.Helper() 94 if err != nil { 95 t.Fatal("Unexpected error:", err) 96 } 97 } 98 99 func expectBreak(t *testing.T, err error) { 100 t.Helper() 101 if errors.Is(err, bot.Break) { 102 return 103 } 104 if err != nil { 105 t.Fatal("Unexpected error:", err) 106 } 107 t.Fatal("Expected error, got nothing.") 108 } 109 110 // BenchmarkGuildOnly runs a message through the GuildOnly middleware to 111 // calculate the overhead of reflection. 112 func BenchmarkGuildOnly(b *testing.B) { 113 var ctx = &bot.Context{ 114 State: &state.State{ 115 Store: &mockStore{}, 116 }, 117 } 118 var middleware = GuildOnly(ctx) 119 var msg = &gateway.MessageCreateEvent{ 120 Message: discord.Message{ 121 ID: 3, 122 GuildID: 1337, 123 }, 124 } 125 126 b.ResetTimer() 127 128 for i := 0; i < b.N; i++ { 129 if err := middleware(msg); err != nil { 130 b.Fatal("Unexpected error:", err) 131 } 132 } 133 } 134 135 // BenchmarkAdminOnly runs a message through the GuildOnly middleware to 136 // calculate the overhead of reflection. 137 func BenchmarkAdminOnly(b *testing.B) { 138 var ctx = &bot.Context{ 139 State: &state.State{ 140 Store: &mockStore{}, 141 }, 142 } 143 var middleware = AdminOnly(ctx) 144 var msg = &gateway.MessageCreateEvent{ 145 Message: discord.Message{ 146 ID: 1, 147 ChannelID: 1337, 148 Author: discord.User{ID: 69420}, 149 }, 150 } 151 152 b.ResetTimer() 153 154 for i := 0; i < b.N; i++ { 155 if err := middleware(msg); err != nil { 156 b.Fatal("Unexpected error:", err) 157 } 158 } 159 } 160 161 type mockStore struct { 162 state.NoopStore 163 } 164 165 func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) { 166 return &discord.Guild{ 167 ID: id, 168 Roles: []discord.Role{{ 169 ID: 69420, 170 Permissions: discord.PermissionAdministrator, 171 }}, 172 }, nil 173 } 174 175 func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) { 176 return &discord.Member{ 177 User: discord.User{ID: userID}, 178 RoleIDs: []discord.RoleID{discord.RoleID(userID)}, 179 }, nil 180 } 181 182 // Channel returns a channel with a guildID for #69420. 183 func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) { 184 if id == 69420 { 185 return &discord.Channel{ 186 ID: id, 187 GuildID: 1337, 188 }, nil 189 } 190 191 return &discord.Channel{ 192 ID: id, 193 }, nil 194 }