github.com/diamondburned/arikawa@v1.3.14/bot/ctx_test.go (about) 1 package bot 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strconv" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/diamondburned/arikawa/discord" 13 "github.com/diamondburned/arikawa/gateway" 14 "github.com/diamondburned/arikawa/state" 15 "github.com/diamondburned/arikawa/utils/handler" 16 ) 17 18 type testc struct { 19 Ctx *Context 20 Return chan interface{} 21 Counter uint64 22 Typed int8 23 } 24 25 func (t *testc) Setup(sub *Subcommand) { 26 sub.AddMiddleware("*,GetCounter", func(v interface{}) { 27 t.Counter++ 28 }) 29 sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) { 30 t.Counter++ 31 }) 32 // stub middleware for testing 33 sub.AddMiddleware("OnTyping", func(*gateway.TypingStartEvent) { 34 t.Typed = 2 35 }) 36 sub.Hide("Hidden") 37 } 38 func (t *testc) Hidden(*gateway.MessageCreateEvent) {} 39 func (t *testc) Noop(*gateway.MessageCreateEvent) {} 40 func (t *testc) GetCounter(*gateway.MessageCreateEvent) { 41 t.Return <- strconv.FormatUint(t.Counter, 10) 42 } 43 func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error { 44 t.Return <- args 45 return errors.New("oh no") 46 } 47 func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) { 48 t.Return <- []string(*c) 49 } 50 func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) { 51 t.Return <- c[len(c)-1] 52 } 53 func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) { 54 t.Return <- c 55 } 56 func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) { 57 t.Return <- c 58 } 59 func (t *testc) NoArgs(*gateway.MessageCreateEvent) error { 60 return errors.New("passed") 61 } 62 func (t *testc) OnTyping(*gateway.TypingStartEvent) { 63 t.Typed-- 64 } 65 66 func TestNewContext(t *testing.T) { 67 var s = &state.State{ 68 Store: state.NewDefaultStore(nil), 69 } 70 71 c, err := New(s, &testc{}) 72 if err != nil { 73 t.Fatal("Failed to create new context:", err) 74 } 75 76 if !reflect.DeepEqual(c.Subcommands(), c.subcommands) { 77 t.Fatal("Subcommands mismatch.") 78 } 79 } 80 81 func TestContext(t *testing.T) { 82 var given = &testc{} 83 var s = &state.State{ 84 Store: state.NewDefaultStore(nil), 85 Handler: handler.New(), 86 } 87 88 sub, err := NewSubcommand(given) 89 if err != nil { 90 t.Fatal("Failed to create subcommand:", err) 91 } 92 93 var ctx = &Context{ 94 Name: "arikawa/bot test", 95 Description: "Just a test.", 96 97 Subcommand: sub, 98 State: s, 99 ParseArgs: DefaultArgsParser(), 100 } 101 102 t.Run("init commands", func(t *testing.T) { 103 if err := ctx.Subcommand.InitCommands(ctx); err != nil { 104 t.Fatal("Failed to init commands:", err) 105 } 106 107 if given.Ctx == nil { 108 t.Fatal("given'sub Context field is nil") 109 } 110 111 if given.Ctx.State.Store == nil { 112 t.Fatal("given'sub State is nil") 113 } 114 }) 115 116 t.Run("find commands", func(t *testing.T) { 117 cmd := ctx.FindCommand("", "NoArgs") 118 if cmd == nil { 119 t.Fatal("Failed to find NoArgs") 120 } 121 }) 122 123 t.Run("help", func(t *testing.T) { 124 ctx.MustRegisterSubcommandCustom(&testc{}, "helper") 125 126 h := ctx.Help() 127 if h == "" { 128 t.Fatal("Empty help?") 129 } 130 131 if strings.Contains(h, "hidden") { 132 t.Fatal("Hidden command shown in help.") 133 } 134 135 if !strings.Contains(h, "arikawa/bot test") { 136 t.Fatal("Name not found.") 137 } 138 if !strings.Contains(h, "Just a test.") { 139 t.Fatal("Description not found.") 140 } 141 }) 142 143 t.Run("middleware", func(t *testing.T) { 144 ctx.HasPrefix = NewPrefix("pls do ") 145 146 // This should trigger the middleware first. 147 if err := expect(ctx, given, "3", "pls do getCounter"); err != nil { 148 t.Fatal("Unexpected error:", err) 149 } 150 }) 151 152 t.Run("typing event", func(t *testing.T) { 153 typing := &gateway.TypingStartEvent{} 154 155 if err := ctx.callCmd(typing); err != nil { 156 t.Fatal("Failed to call with TypingStart:", err) 157 } 158 159 // -1 none ran 160 if given.Typed != 1 { 161 t.Fatal("Typed bool is false") 162 } 163 }) 164 165 t.Run("call command", func(t *testing.T) { 166 // Set a custom prefix 167 ctx.HasPrefix = NewPrefix("~") 168 169 var ( 170 send = "hacka doll no. 3" 171 expects = []string{"hacka", "doll", "no.", "3"} 172 ) 173 174 if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" { 175 t.Fatal("Unexpected error:", err) 176 } 177 }) 178 179 t.Run("call command rawarguments", func(t *testing.T) { 180 ctx.HasPrefix = NewPrefix("!") 181 expects := RawArguments("just things") 182 183 if err := expect(ctx, given, expects, "!content just things"); err != nil { 184 t.Fatal("Unexpected call error:", err) 185 } 186 }) 187 188 t.Run("call command custom manual parser", func(t *testing.T) { 189 ctx.HasPrefix = NewPrefix("!") 190 expects := []string{"arg1", ":)"} 191 192 if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil { 193 t.Fatal("Unexpected call error:", err) 194 } 195 }) 196 197 t.Run("call command custom variadic parser", func(t *testing.T) { 198 ctx.HasPrefix = NewPrefix("!") 199 expects := &customParsed{true} 200 201 if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil { 202 t.Fatal("Unexpected call error:", err) 203 } 204 }) 205 206 t.Run("call command custom trailing manual parser", func(t *testing.T) { 207 ctx.HasPrefix = NewPrefix("!") 208 expects := ArgumentParts{"arikawa"} 209 210 if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil { 211 t.Fatal("Unexpected call error:", err) 212 } 213 214 if expects.Length() != 1 { 215 t.Fatal("Unexpected ArgumentParts length.") 216 } 217 if expects.After(1)+expects.After(2)+expects.After(-1) != "" { 218 t.Fatal("Unexpected ArgumentsParts after.") 219 } 220 if expects.String() != "arikawa" { 221 t.Fatal("Unexpected ArgumentsParts string.") 222 } 223 if expects.Arg(0) != "arikawa" { 224 t.Fatal("Unexpected ArgumentParts arg 0") 225 } 226 if expects.Arg(1) != "" { 227 t.Fatal("Unexpected ArgumentParts arg 1") 228 } 229 }) 230 231 testMessage := func(content string) error { 232 // Mock a messageCreate event 233 m := &gateway.MessageCreateEvent{ 234 Message: discord.Message{ 235 Content: content, 236 }, 237 } 238 239 return ctx.callCmd(m) 240 } 241 242 t.Run("call command without args", func(t *testing.T) { 243 ctx.HasPrefix = NewPrefix("") 244 245 if err := testMessage("noArgs"); err.Error() != "passed" { 246 t.Fatal("unexpected error:", err) 247 } 248 }) 249 250 // Test error cases 251 252 t.Run("call unknown command", func(t *testing.T) { 253 ctx.HasPrefix = NewPrefix("joe pls ") 254 255 err := testMessage("joe pls no") 256 257 if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") { 258 t.Fatal("unexpected error:", err) 259 } 260 }) 261 262 // Test subcommands 263 264 t.Run("register subcommand", func(t *testing.T) { 265 ctx.HasPrefix = NewPrefix("run ") 266 267 sub := &testc{} 268 ctx.MustRegisterSubcommand(sub) 269 270 if err := testMessage("run testc noop"); err != nil { 271 t.Fatal("Unexpected error:", err) 272 } 273 274 expects := RawArguments("hackadoll no. 3") 275 276 if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil { 277 t.Fatal("Unexpected call error:", err) 278 } 279 280 if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil { 281 t.Fatal("Failed to find subcommand Noop") 282 } 283 }) 284 285 t.Run("register subcommand custom", func(t *testing.T) { 286 ctx.MustRegisterSubcommandCustom(&testc{}, "arikawa") 287 }) 288 289 t.Run("duplicate subcommand", func(t *testing.T) { 290 _, err := ctx.RegisterSubcommandCustom(&testc{}, "arikawa") 291 if err := err.Error(); !strings.Contains(err, "duplicate") { 292 t.Fatal("Unexpected error:", err) 293 } 294 }) 295 296 t.Run("start", func(t *testing.T) { 297 cancel := ctx.Start() 298 defer cancel() 299 300 ctx.HasPrefix = NewPrefix("!") 301 given.Return = make(chan interface{}) 302 303 ctx.Handler.Call(&gateway.MessageCreateEvent{ 304 Message: discord.Message{ 305 Content: "!content hime arikawa best trap", 306 }, 307 }) 308 309 if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" { 310 t.Fatal("Unexpected content:", c) 311 } 312 }) 313 } 314 315 func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) { 316 var v interface{} 317 if call = sendMsg(ctx, given, &v, content); call != nil { 318 return 319 } 320 if !reflect.DeepEqual(v, expects) { 321 return fmt.Errorf("returned argument is invalid: %v", v) 322 } 323 return nil 324 } 325 326 func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) { 327 // Return channel for testing 328 ret := make(chan interface{}) 329 given.Return = ret 330 331 // Mock a messageCreate event 332 m := &gateway.MessageCreateEvent{ 333 Message: discord.Message{ 334 Content: content, 335 }, 336 } 337 338 var callCh = make(chan error) 339 go func() { 340 callCh <- ctx.Call(m) 341 }() 342 343 select { 344 case arg := <-ret: 345 call = <-callCh 346 reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg)) 347 return 348 349 case call = <-callCh: 350 return fmt.Errorf("expected return before error: %w", call) 351 352 case <-time.After(time.Second): 353 return errors.New("timed out while waiting") 354 } 355 } 356 357 func BenchmarkConstructor(b *testing.B) { 358 var s = &state.State{ 359 Store: state.NewDefaultStore(nil), 360 } 361 362 for i := 0; i < b.N; i++ { 363 _, _ = New(s, &testc{}) 364 } 365 } 366 367 func BenchmarkCall(b *testing.B) { 368 var given = &testc{} 369 var s = &state.State{ 370 Store: state.NewDefaultStore(nil), 371 } 372 373 sub, _ := NewSubcommand(given) 374 375 var ctx = &Context{ 376 Subcommand: sub, 377 State: s, 378 HasPrefix: NewPrefix("~"), 379 ParseArgs: DefaultArgsParser(), 380 } 381 382 m := &gateway.MessageCreateEvent{ 383 Message: discord.Message{ 384 Content: "~noop", 385 }, 386 } 387 388 b.ResetTimer() 389 390 for i := 0; i < b.N; i++ { 391 ctx.callCmd(m) 392 } 393 } 394 395 func BenchmarkHelp(b *testing.B) { 396 var given = &testc{} 397 var s = &state.State{ 398 Store: state.NewDefaultStore(nil), 399 } 400 401 sub, _ := NewSubcommand(given) 402 403 var ctx = &Context{ 404 Subcommand: sub, 405 State: s, 406 HasPrefix: NewPrefix("~"), 407 ParseArgs: DefaultArgsParser(), 408 } 409 410 b.ResetTimer() 411 412 for i := 0; i < b.N; i++ { 413 _ = ctx.Help() 414 } 415 }