github.com/diamondburned/arikawa/v2@v2.1.0/bot/ctx_test.go (about) 1 package bot 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strconv" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/diamondburned/arikawa/v2/discord" 13 "github.com/diamondburned/arikawa/v2/gateway" 14 "github.com/diamondburned/arikawa/v2/state" 15 "github.com/diamondburned/arikawa/v2/state/store" 16 "github.com/diamondburned/arikawa/v2/utils/handler" 17 ) 18 19 type testc struct { 20 Ctx *Context 21 Return chan interface{} 22 Counter uint64 23 Typed int8 24 } 25 26 func (t *testc) Setup(sub *Subcommand) { 27 sub.AddMiddleware([]string{"*", "GetCounter"}, func(v interface{}) { 28 t.Counter++ 29 }) 30 sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) { 31 t.Counter++ 32 }) 33 // stub middleware for testing 34 sub.AddMiddleware(t.OnTyping, func(*gateway.TypingStartEvent) { 35 t.Typed = 2 36 }) 37 sub.Hide(t.Hidden) 38 } 39 func (t *testc) Hidden(*gateway.MessageCreateEvent) {} 40 func (t *testc) Noop(*gateway.MessageCreateEvent) {} 41 func (t *testc) GetCounter(*gateway.MessageCreateEvent) { 42 t.Return <- strconv.FormatUint(t.Counter, 10) 43 } 44 func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error { 45 t.Return <- args 46 return errors.New("oh no") 47 } 48 func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) { 49 t.Return <- []string(*c) 50 } 51 func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) { 52 t.Return <- c[len(c)-1] 53 } 54 func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) { 55 t.Return <- c 56 } 57 func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) { 58 t.Return <- c 59 } 60 func (t *testc) NoArgs(*gateway.MessageCreateEvent) error { 61 return errors.New("passed") 62 } 63 func (t *testc) OnTyping(*gateway.TypingStartEvent) { 64 t.Typed-- 65 } 66 67 func TestNewContext(t *testing.T) { 68 var s = &state.State{ 69 Cabinet: store.NoopCabinet, 70 } 71 72 c, err := New(s, &testc{}) 73 if err != nil { 74 t.Fatal("Failed to create new context:", err) 75 } 76 77 if !reflect.DeepEqual(c.Subcommands(), c.subcommands) { 78 t.Fatal("Subcommands mismatch.") 79 } 80 } 81 82 func TestContext(t *testing.T) { 83 var given = &testc{} 84 var s = &state.State{ 85 Cabinet: store.NoopCabinet, 86 Handler: handler.New(), 87 } 88 89 sub, err := NewSubcommand(given) 90 if err != nil { 91 t.Fatal("Failed to create subcommand:", err) 92 } 93 94 var ctx = &Context{ 95 Name: "arikawa/bot test", 96 Description: "Just a test.", 97 98 Subcommand: sub, 99 State: s, 100 ParseArgs: DefaultArgsParser(), 101 } 102 103 t.Run("init commands", func(t *testing.T) { 104 if err := ctx.Subcommand.InitCommands(ctx); err != nil { 105 t.Fatal("Failed to init commands:", err) 106 } 107 108 if given.Ctx != ctx { 109 t.Fatal("given Context field has invalid pointer") 110 } 111 }) 112 113 t.Run("find commands", func(t *testing.T) { 114 cmd := ctx.FindCommand("", "NoArgs") 115 if cmd == nil { 116 t.Fatal("Failed to find NoArgs") 117 } 118 }) 119 120 t.Run("middleware", func(t *testing.T) { 121 ctx.HasPrefix = NewPrefix("pls do ") 122 123 // This should trigger the middleware first. 124 if err := expect(ctx, given, "3", "pls do getCounter"); err != nil { 125 t.Fatal("Unexpected error:", err) 126 } 127 }) 128 129 t.Run("derive intents", func(t *testing.T) { 130 intents := ctx.DeriveIntents() 131 132 assertIntents := func(target gateway.Intents, name string) { 133 if !intents.Has(target) { 134 t.Error("Derived intents do not have", name) 135 } 136 } 137 138 assertIntents(gateway.IntentGuildMessages, "guild messages") 139 assertIntents(gateway.IntentDirectMessages, "direct messages") 140 assertIntents(gateway.IntentGuildMessageTyping, "guild typing") 141 assertIntents(gateway.IntentDirectMessageTyping, "direct message typing") 142 }) 143 144 t.Run("typing event", func(t *testing.T) { 145 typing := &gateway.TypingStartEvent{} 146 147 if err := ctx.callCmd(typing); err != nil { 148 t.Fatal("Failed to call with TypingStart:", err) 149 } 150 151 // -1 none ran 152 if given.Typed != 1 { 153 t.Fatal("Typed bool is false") 154 } 155 }) 156 157 t.Run("call command", func(t *testing.T) { 158 // Set a custom prefix 159 ctx.HasPrefix = NewPrefix("~") 160 161 var ( 162 send = "hacka doll no. 3" 163 expects = []string{"hacka", "doll", "no.", "3"} 164 ) 165 166 if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" { 167 t.Fatal("Unexpected error:", err) 168 } 169 }) 170 171 t.Run("call command rawarguments", func(t *testing.T) { 172 ctx.HasPrefix = NewPrefix("!") 173 expects := RawArguments("just things") 174 175 if err := expect(ctx, given, expects, "!content just things"); err != nil { 176 t.Fatal("Unexpected call error:", err) 177 } 178 }) 179 180 t.Run("call command custom manual parser", func(t *testing.T) { 181 ctx.HasPrefix = NewPrefix("!") 182 expects := []string{"arg1", ":)"} 183 184 if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil { 185 t.Fatal("Unexpected call error:", err) 186 } 187 }) 188 189 t.Run("call command custom variadic parser", func(t *testing.T) { 190 ctx.HasPrefix = NewPrefix("!") 191 expects := &customParsed{true} 192 193 if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil { 194 t.Fatal("Unexpected call error:", err) 195 } 196 }) 197 198 t.Run("call command custom trailing manual parser", func(t *testing.T) { 199 ctx.HasPrefix = NewPrefix("!") 200 expects := ArgumentParts{"arikawa"} 201 202 if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil { 203 t.Fatal("Unexpected call error:", err) 204 } 205 206 if expects.Length() != 1 { 207 t.Fatal("Unexpected ArgumentParts length.") 208 } 209 if expects.After(1)+expects.After(2)+expects.After(-1) != "" { 210 t.Fatal("Unexpected ArgumentsParts after.") 211 } 212 if expects.String() != "arikawa" { 213 t.Fatal("Unexpected ArgumentsParts string.") 214 } 215 if expects.Arg(0) != "arikawa" { 216 t.Fatal("Unexpected ArgumentParts arg 0") 217 } 218 if expects.Arg(1) != "" { 219 t.Fatal("Unexpected ArgumentParts arg 1") 220 } 221 }) 222 223 testMessage := func(content string) error { 224 // Mock a messageCreate event 225 m := &gateway.MessageCreateEvent{ 226 Message: discord.Message{ 227 Content: content, 228 }, 229 } 230 231 return ctx.callCmd(m) 232 } 233 234 t.Run("call command without args", func(t *testing.T) { 235 ctx.HasPrefix = NewPrefix("") 236 237 if err := testMessage("noArgs"); err.Error() != "passed" { 238 t.Fatal("unexpected error:", err) 239 } 240 }) 241 242 // Test error cases 243 244 t.Run("call unknown command", func(t *testing.T) { 245 ctx.HasPrefix = NewPrefix("joe pls ") 246 247 err := testMessage("joe pls no") 248 249 if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") { 250 t.Fatal("unexpected error:", err) 251 } 252 }) 253 254 // Test subcommands 255 256 t.Run("register subcommand", func(t *testing.T) { 257 ctx.HasPrefix = NewPrefix("run ") 258 259 sub := &testc{} 260 ctx.MustRegisterSubcommand(sub) 261 262 if err := testMessage("run testc noop"); err != nil { 263 t.Fatal("Unexpected error:", err) 264 } 265 266 expects := RawArguments("hackadoll no. 3") 267 268 if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil { 269 t.Fatal("Unexpected call error:", err) 270 } 271 272 if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil { 273 t.Fatal("Failed to find subcommand Noop") 274 } 275 }) 276 277 t.Run("register subcommand custom", func(t *testing.T) { 278 ctx.MustRegisterSubcommand(&testc{}, "arikawa", "a") 279 }) 280 281 t.Run("duplicate subcommand", func(t *testing.T) { 282 _, err := ctx.RegisterSubcommand(&testc{}, "arikawa") 283 if err := err.Error(); !strings.Contains(err, "duplicate") { 284 t.Fatal("Unexpected error:", err) 285 } 286 287 _, err = ctx.RegisterSubcommand(&testc{}, "a") 288 if err := err.Error(); !strings.Contains(err, "duplicate") { 289 t.Fatal("Unexpected error:", err) 290 } 291 }) 292 293 t.Run("help", func(t *testing.T) { 294 ctx.MustRegisterSubcommand(&testc{}, "helper") 295 296 h := ctx.Help() 297 if h == "" { 298 t.Fatal("Empty help?") 299 } 300 301 if strings.Contains(h, "hidden") { 302 t.Fatal("Hidden command shown in help.") 303 } 304 305 if !strings.Contains(h, "arikawa/bot test") { 306 t.Fatal("Name not found.") 307 } 308 if !strings.Contains(h, "Just a test.") { 309 t.Fatal("Description not found.") 310 } 311 if !strings.Contains(h, "**a**") { 312 t.Fatal("arikawa alias `a' not found.") 313 } 314 }) 315 316 t.Run("start", func(t *testing.T) { 317 cancel := ctx.Start() 318 defer cancel() 319 320 ctx.HasPrefix = NewPrefix("!") 321 given.Return = make(chan interface{}) 322 323 ctx.Handler.Call(&gateway.MessageCreateEvent{ 324 Message: discord.Message{ 325 Content: "!content hime arikawa best trap", 326 }, 327 }) 328 329 if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" { 330 t.Fatal("Unexpected content:", c) 331 } 332 }) 333 } 334 335 func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) { 336 var v interface{} 337 if call = sendMsg(ctx, given, &v, content); call != nil { 338 return 339 } 340 if !reflect.DeepEqual(v, expects) { 341 return fmt.Errorf("returned argument is invalid: %v", v) 342 } 343 return nil 344 } 345 346 func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) { 347 // Return channel for testing 348 ret := make(chan interface{}) 349 given.Return = ret 350 351 // Mock a messageCreate event 352 m := &gateway.MessageCreateEvent{ 353 Message: discord.Message{ 354 Content: content, 355 }, 356 } 357 358 var callCh = make(chan error) 359 go func() { 360 callCh <- ctx.Call(m) 361 }() 362 363 select { 364 case arg := <-ret: 365 call = <-callCh 366 reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg)) 367 return 368 369 case call = <-callCh: 370 return fmt.Errorf("expected return before error: %w", call) 371 372 case <-time.After(time.Second): 373 return errors.New("timed out while waiting") 374 } 375 } 376 377 func BenchmarkConstructor(b *testing.B) { 378 var s = &state.State{ 379 Cabinet: store.NoopCabinet, 380 } 381 382 for i := 0; i < b.N; i++ { 383 _, _ = New(s, &testc{}) 384 } 385 } 386 387 func BenchmarkCall(b *testing.B) { 388 var given = &testc{} 389 var s = &state.State{ 390 Cabinet: store.NoopCabinet, 391 } 392 393 sub, _ := NewSubcommand(given) 394 395 var ctx = &Context{ 396 Subcommand: sub, 397 State: s, 398 HasPrefix: NewPrefix("~"), 399 ParseArgs: DefaultArgsParser(), 400 } 401 402 m := &gateway.MessageCreateEvent{ 403 Message: discord.Message{ 404 Content: "~noop", 405 }, 406 } 407 408 b.ResetTimer() 409 410 for i := 0; i < b.N; i++ { 411 ctx.callCmd(m) 412 } 413 } 414 415 func BenchmarkHelp(b *testing.B) { 416 var given = &testc{} 417 var s = &state.State{ 418 Cabinet: store.NoopCabinet, 419 } 420 421 sub, _ := NewSubcommand(given) 422 423 var ctx = &Context{ 424 Subcommand: sub, 425 State: s, 426 HasPrefix: NewPrefix("~"), 427 ParseArgs: DefaultArgsParser(), 428 } 429 430 b.ResetTimer() 431 432 for i := 0; i < b.N; i++ { 433 _ = ctx.Help() 434 } 435 }