github.com/diamondburned/arikawa@v1.3.14/bot/ctx_test.go (about)

     1  package bot
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/diamondburned/arikawa/discord"
    13  	"github.com/diamondburned/arikawa/gateway"
    14  	"github.com/diamondburned/arikawa/state"
    15  	"github.com/diamondburned/arikawa/utils/handler"
    16  )
    17  
    18  type testc struct {
    19  	Ctx     *Context
    20  	Return  chan interface{}
    21  	Counter uint64
    22  	Typed   int8
    23  }
    24  
    25  func (t *testc) Setup(sub *Subcommand) {
    26  	sub.AddMiddleware("*,GetCounter", func(v interface{}) {
    27  		t.Counter++
    28  	})
    29  	sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
    30  		t.Counter++
    31  	})
    32  	// stub middleware for testing
    33  	sub.AddMiddleware("OnTyping", func(*gateway.TypingStartEvent) {
    34  		t.Typed = 2
    35  	})
    36  	sub.Hide("Hidden")
    37  }
    38  func (t *testc) Hidden(*gateway.MessageCreateEvent) {}
    39  func (t *testc) Noop(*gateway.MessageCreateEvent)   {}
    40  func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
    41  	t.Return <- strconv.FormatUint(t.Counter, 10)
    42  }
    43  func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
    44  	t.Return <- args
    45  	return errors.New("oh no")
    46  }
    47  func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) {
    48  	t.Return <- []string(*c)
    49  }
    50  func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
    51  	t.Return <- c[len(c)-1]
    52  }
    53  func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) {
    54  	t.Return <- c
    55  }
    56  func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
    57  	t.Return <- c
    58  }
    59  func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
    60  	return errors.New("passed")
    61  }
    62  func (t *testc) OnTyping(*gateway.TypingStartEvent) {
    63  	t.Typed--
    64  }
    65  
    66  func TestNewContext(t *testing.T) {
    67  	var s = &state.State{
    68  		Store: state.NewDefaultStore(nil),
    69  	}
    70  
    71  	c, err := New(s, &testc{})
    72  	if err != nil {
    73  		t.Fatal("Failed to create new context:", err)
    74  	}
    75  
    76  	if !reflect.DeepEqual(c.Subcommands(), c.subcommands) {
    77  		t.Fatal("Subcommands mismatch.")
    78  	}
    79  }
    80  
    81  func TestContext(t *testing.T) {
    82  	var given = &testc{}
    83  	var s = &state.State{
    84  		Store:   state.NewDefaultStore(nil),
    85  		Handler: handler.New(),
    86  	}
    87  
    88  	sub, err := NewSubcommand(given)
    89  	if err != nil {
    90  		t.Fatal("Failed to create subcommand:", err)
    91  	}
    92  
    93  	var ctx = &Context{
    94  		Name:        "arikawa/bot test",
    95  		Description: "Just a test.",
    96  
    97  		Subcommand: sub,
    98  		State:      s,
    99  		ParseArgs:  DefaultArgsParser(),
   100  	}
   101  
   102  	t.Run("init commands", func(t *testing.T) {
   103  		if err := ctx.Subcommand.InitCommands(ctx); err != nil {
   104  			t.Fatal("Failed to init commands:", err)
   105  		}
   106  
   107  		if given.Ctx == nil {
   108  			t.Fatal("given'sub Context field is nil")
   109  		}
   110  
   111  		if given.Ctx.State.Store == nil {
   112  			t.Fatal("given'sub State is nil")
   113  		}
   114  	})
   115  
   116  	t.Run("find commands", func(t *testing.T) {
   117  		cmd := ctx.FindCommand("", "NoArgs")
   118  		if cmd == nil {
   119  			t.Fatal("Failed to find NoArgs")
   120  		}
   121  	})
   122  
   123  	t.Run("help", func(t *testing.T) {
   124  		ctx.MustRegisterSubcommandCustom(&testc{}, "helper")
   125  
   126  		h := ctx.Help()
   127  		if h == "" {
   128  			t.Fatal("Empty help?")
   129  		}
   130  
   131  		if strings.Contains(h, "hidden") {
   132  			t.Fatal("Hidden command shown in help.")
   133  		}
   134  
   135  		if !strings.Contains(h, "arikawa/bot test") {
   136  			t.Fatal("Name not found.")
   137  		}
   138  		if !strings.Contains(h, "Just a test.") {
   139  			t.Fatal("Description not found.")
   140  		}
   141  	})
   142  
   143  	t.Run("middleware", func(t *testing.T) {
   144  		ctx.HasPrefix = NewPrefix("pls do ")
   145  
   146  		// This should trigger the middleware first.
   147  		if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
   148  			t.Fatal("Unexpected error:", err)
   149  		}
   150  	})
   151  
   152  	t.Run("typing event", func(t *testing.T) {
   153  		typing := &gateway.TypingStartEvent{}
   154  
   155  		if err := ctx.callCmd(typing); err != nil {
   156  			t.Fatal("Failed to call with TypingStart:", err)
   157  		}
   158  
   159  		// -1 none ran
   160  		if given.Typed != 1 {
   161  			t.Fatal("Typed bool is false")
   162  		}
   163  	})
   164  
   165  	t.Run("call command", func(t *testing.T) {
   166  		// Set a custom prefix
   167  		ctx.HasPrefix = NewPrefix("~")
   168  
   169  		var (
   170  			send    = "hacka doll no. 3"
   171  			expects = []string{"hacka", "doll", "no.", "3"}
   172  		)
   173  
   174  		if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" {
   175  			t.Fatal("Unexpected error:", err)
   176  		}
   177  	})
   178  
   179  	t.Run("call command rawarguments", func(t *testing.T) {
   180  		ctx.HasPrefix = NewPrefix("!")
   181  		expects := RawArguments("just things")
   182  
   183  		if err := expect(ctx, given, expects, "!content just things"); err != nil {
   184  			t.Fatal("Unexpected call error:", err)
   185  		}
   186  	})
   187  
   188  	t.Run("call command custom manual parser", func(t *testing.T) {
   189  		ctx.HasPrefix = NewPrefix("!")
   190  		expects := []string{"arg1", ":)"}
   191  
   192  		if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil {
   193  			t.Fatal("Unexpected call error:", err)
   194  		}
   195  	})
   196  
   197  	t.Run("call command custom variadic parser", func(t *testing.T) {
   198  		ctx.HasPrefix = NewPrefix("!")
   199  		expects := &customParsed{true}
   200  
   201  		if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil {
   202  			t.Fatal("Unexpected call error:", err)
   203  		}
   204  	})
   205  
   206  	t.Run("call command custom trailing manual parser", func(t *testing.T) {
   207  		ctx.HasPrefix = NewPrefix("!")
   208  		expects := ArgumentParts{"arikawa"}
   209  
   210  		if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil {
   211  			t.Fatal("Unexpected call error:", err)
   212  		}
   213  
   214  		if expects.Length() != 1 {
   215  			t.Fatal("Unexpected ArgumentParts length.")
   216  		}
   217  		if expects.After(1)+expects.After(2)+expects.After(-1) != "" {
   218  			t.Fatal("Unexpected ArgumentsParts after.")
   219  		}
   220  		if expects.String() != "arikawa" {
   221  			t.Fatal("Unexpected ArgumentsParts string.")
   222  		}
   223  		if expects.Arg(0) != "arikawa" {
   224  			t.Fatal("Unexpected ArgumentParts arg 0")
   225  		}
   226  		if expects.Arg(1) != "" {
   227  			t.Fatal("Unexpected ArgumentParts arg 1")
   228  		}
   229  	})
   230  
   231  	testMessage := func(content string) error {
   232  		// Mock a messageCreate event
   233  		m := &gateway.MessageCreateEvent{
   234  			Message: discord.Message{
   235  				Content: content,
   236  			},
   237  		}
   238  
   239  		return ctx.callCmd(m)
   240  	}
   241  
   242  	t.Run("call command without args", func(t *testing.T) {
   243  		ctx.HasPrefix = NewPrefix("")
   244  
   245  		if err := testMessage("noArgs"); err.Error() != "passed" {
   246  			t.Fatal("unexpected error:", err)
   247  		}
   248  	})
   249  
   250  	// Test error cases
   251  
   252  	t.Run("call unknown command", func(t *testing.T) {
   253  		ctx.HasPrefix = NewPrefix("joe pls ")
   254  
   255  		err := testMessage("joe pls no")
   256  
   257  		if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") {
   258  			t.Fatal("unexpected error:", err)
   259  		}
   260  	})
   261  
   262  	// Test subcommands
   263  
   264  	t.Run("register subcommand", func(t *testing.T) {
   265  		ctx.HasPrefix = NewPrefix("run ")
   266  
   267  		sub := &testc{}
   268  		ctx.MustRegisterSubcommand(sub)
   269  
   270  		if err := testMessage("run testc noop"); err != nil {
   271  			t.Fatal("Unexpected error:", err)
   272  		}
   273  
   274  		expects := RawArguments("hackadoll no. 3")
   275  
   276  		if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil {
   277  			t.Fatal("Unexpected call error:", err)
   278  		}
   279  
   280  		if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil {
   281  			t.Fatal("Failed to find subcommand Noop")
   282  		}
   283  	})
   284  
   285  	t.Run("register subcommand custom", func(t *testing.T) {
   286  		ctx.MustRegisterSubcommandCustom(&testc{}, "arikawa")
   287  	})
   288  
   289  	t.Run("duplicate subcommand", func(t *testing.T) {
   290  		_, err := ctx.RegisterSubcommandCustom(&testc{}, "arikawa")
   291  		if err := err.Error(); !strings.Contains(err, "duplicate") {
   292  			t.Fatal("Unexpected error:", err)
   293  		}
   294  	})
   295  
   296  	t.Run("start", func(t *testing.T) {
   297  		cancel := ctx.Start()
   298  		defer cancel()
   299  
   300  		ctx.HasPrefix = NewPrefix("!")
   301  		given.Return = make(chan interface{})
   302  
   303  		ctx.Handler.Call(&gateway.MessageCreateEvent{
   304  			Message: discord.Message{
   305  				Content: "!content hime arikawa best trap",
   306  			},
   307  		})
   308  
   309  		if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" {
   310  			t.Fatal("Unexpected content:", c)
   311  		}
   312  	})
   313  }
   314  
   315  func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) {
   316  	var v interface{}
   317  	if call = sendMsg(ctx, given, &v, content); call != nil {
   318  		return
   319  	}
   320  	if !reflect.DeepEqual(v, expects) {
   321  		return fmt.Errorf("returned argument is invalid: %v", v)
   322  	}
   323  	return nil
   324  }
   325  
   326  func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) {
   327  	// Return channel for testing
   328  	ret := make(chan interface{})
   329  	given.Return = ret
   330  
   331  	// Mock a messageCreate event
   332  	m := &gateway.MessageCreateEvent{
   333  		Message: discord.Message{
   334  			Content: content,
   335  		},
   336  	}
   337  
   338  	var callCh = make(chan error)
   339  	go func() {
   340  		callCh <- ctx.Call(m)
   341  	}()
   342  
   343  	select {
   344  	case arg := <-ret:
   345  		call = <-callCh
   346  		reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg))
   347  		return
   348  
   349  	case call = <-callCh:
   350  		return fmt.Errorf("expected return before error: %w", call)
   351  
   352  	case <-time.After(time.Second):
   353  		return errors.New("timed out while waiting")
   354  	}
   355  }
   356  
   357  func BenchmarkConstructor(b *testing.B) {
   358  	var s = &state.State{
   359  		Store: state.NewDefaultStore(nil),
   360  	}
   361  
   362  	for i := 0; i < b.N; i++ {
   363  		_, _ = New(s, &testc{})
   364  	}
   365  }
   366  
   367  func BenchmarkCall(b *testing.B) {
   368  	var given = &testc{}
   369  	var s = &state.State{
   370  		Store: state.NewDefaultStore(nil),
   371  	}
   372  
   373  	sub, _ := NewSubcommand(given)
   374  
   375  	var ctx = &Context{
   376  		Subcommand: sub,
   377  		State:      s,
   378  		HasPrefix:  NewPrefix("~"),
   379  		ParseArgs:  DefaultArgsParser(),
   380  	}
   381  
   382  	m := &gateway.MessageCreateEvent{
   383  		Message: discord.Message{
   384  			Content: "~noop",
   385  		},
   386  	}
   387  
   388  	b.ResetTimer()
   389  
   390  	for i := 0; i < b.N; i++ {
   391  		ctx.callCmd(m)
   392  	}
   393  }
   394  
   395  func BenchmarkHelp(b *testing.B) {
   396  	var given = &testc{}
   397  	var s = &state.State{
   398  		Store: state.NewDefaultStore(nil),
   399  	}
   400  
   401  	sub, _ := NewSubcommand(given)
   402  
   403  	var ctx = &Context{
   404  		Subcommand: sub,
   405  		State:      s,
   406  		HasPrefix:  NewPrefix("~"),
   407  		ParseArgs:  DefaultArgsParser(),
   408  	}
   409  
   410  	b.ResetTimer()
   411  
   412  	for i := 0; i < b.N; i++ {
   413  		_ = ctx.Help()
   414  	}
   415  }