github.com/diamondburned/arikawa/v2@v2.1.0/bot/ctx_test.go (about)

     1  package bot
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/diamondburned/arikawa/v2/discord"
    13  	"github.com/diamondburned/arikawa/v2/gateway"
    14  	"github.com/diamondburned/arikawa/v2/state"
    15  	"github.com/diamondburned/arikawa/v2/state/store"
    16  	"github.com/diamondburned/arikawa/v2/utils/handler"
    17  )
    18  
    19  type testc struct {
    20  	Ctx     *Context
    21  	Return  chan interface{}
    22  	Counter uint64
    23  	Typed   int8
    24  }
    25  
    26  func (t *testc) Setup(sub *Subcommand) {
    27  	sub.AddMiddleware([]string{"*", "GetCounter"}, func(v interface{}) {
    28  		t.Counter++
    29  	})
    30  	sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
    31  		t.Counter++
    32  	})
    33  	// stub middleware for testing
    34  	sub.AddMiddleware(t.OnTyping, func(*gateway.TypingStartEvent) {
    35  		t.Typed = 2
    36  	})
    37  	sub.Hide(t.Hidden)
    38  }
    39  func (t *testc) Hidden(*gateway.MessageCreateEvent) {}
    40  func (t *testc) Noop(*gateway.MessageCreateEvent)   {}
    41  func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
    42  	t.Return <- strconv.FormatUint(t.Counter, 10)
    43  }
    44  func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
    45  	t.Return <- args
    46  	return errors.New("oh no")
    47  }
    48  func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) {
    49  	t.Return <- []string(*c)
    50  }
    51  func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
    52  	t.Return <- c[len(c)-1]
    53  }
    54  func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) {
    55  	t.Return <- c
    56  }
    57  func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
    58  	t.Return <- c
    59  }
    60  func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
    61  	return errors.New("passed")
    62  }
    63  func (t *testc) OnTyping(*gateway.TypingStartEvent) {
    64  	t.Typed--
    65  }
    66  
    67  func TestNewContext(t *testing.T) {
    68  	var s = &state.State{
    69  		Cabinet: store.NoopCabinet,
    70  	}
    71  
    72  	c, err := New(s, &testc{})
    73  	if err != nil {
    74  		t.Fatal("Failed to create new context:", err)
    75  	}
    76  
    77  	if !reflect.DeepEqual(c.Subcommands(), c.subcommands) {
    78  		t.Fatal("Subcommands mismatch.")
    79  	}
    80  }
    81  
    82  func TestContext(t *testing.T) {
    83  	var given = &testc{}
    84  	var s = &state.State{
    85  		Cabinet: store.NoopCabinet,
    86  		Handler: handler.New(),
    87  	}
    88  
    89  	sub, err := NewSubcommand(given)
    90  	if err != nil {
    91  		t.Fatal("Failed to create subcommand:", err)
    92  	}
    93  
    94  	var ctx = &Context{
    95  		Name:        "arikawa/bot test",
    96  		Description: "Just a test.",
    97  
    98  		Subcommand: sub,
    99  		State:      s,
   100  		ParseArgs:  DefaultArgsParser(),
   101  	}
   102  
   103  	t.Run("init commands", func(t *testing.T) {
   104  		if err := ctx.Subcommand.InitCommands(ctx); err != nil {
   105  			t.Fatal("Failed to init commands:", err)
   106  		}
   107  
   108  		if given.Ctx != ctx {
   109  			t.Fatal("given Context field has invalid pointer")
   110  		}
   111  	})
   112  
   113  	t.Run("find commands", func(t *testing.T) {
   114  		cmd := ctx.FindCommand("", "NoArgs")
   115  		if cmd == nil {
   116  			t.Fatal("Failed to find NoArgs")
   117  		}
   118  	})
   119  
   120  	t.Run("middleware", func(t *testing.T) {
   121  		ctx.HasPrefix = NewPrefix("pls do ")
   122  
   123  		// This should trigger the middleware first.
   124  		if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
   125  			t.Fatal("Unexpected error:", err)
   126  		}
   127  	})
   128  
   129  	t.Run("derive intents", func(t *testing.T) {
   130  		intents := ctx.DeriveIntents()
   131  
   132  		assertIntents := func(target gateway.Intents, name string) {
   133  			if !intents.Has(target) {
   134  				t.Error("Derived intents do not have", name)
   135  			}
   136  		}
   137  
   138  		assertIntents(gateway.IntentGuildMessages, "guild messages")
   139  		assertIntents(gateway.IntentDirectMessages, "direct messages")
   140  		assertIntents(gateway.IntentGuildMessageTyping, "guild typing")
   141  		assertIntents(gateway.IntentDirectMessageTyping, "direct message typing")
   142  	})
   143  
   144  	t.Run("typing event", func(t *testing.T) {
   145  		typing := &gateway.TypingStartEvent{}
   146  
   147  		if err := ctx.callCmd(typing); err != nil {
   148  			t.Fatal("Failed to call with TypingStart:", err)
   149  		}
   150  
   151  		// -1 none ran
   152  		if given.Typed != 1 {
   153  			t.Fatal("Typed bool is false")
   154  		}
   155  	})
   156  
   157  	t.Run("call command", func(t *testing.T) {
   158  		// Set a custom prefix
   159  		ctx.HasPrefix = NewPrefix("~")
   160  
   161  		var (
   162  			send    = "hacka doll no. 3"
   163  			expects = []string{"hacka", "doll", "no.", "3"}
   164  		)
   165  
   166  		if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" {
   167  			t.Fatal("Unexpected error:", err)
   168  		}
   169  	})
   170  
   171  	t.Run("call command rawarguments", func(t *testing.T) {
   172  		ctx.HasPrefix = NewPrefix("!")
   173  		expects := RawArguments("just things")
   174  
   175  		if err := expect(ctx, given, expects, "!content just things"); err != nil {
   176  			t.Fatal("Unexpected call error:", err)
   177  		}
   178  	})
   179  
   180  	t.Run("call command custom manual parser", func(t *testing.T) {
   181  		ctx.HasPrefix = NewPrefix("!")
   182  		expects := []string{"arg1", ":)"}
   183  
   184  		if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil {
   185  			t.Fatal("Unexpected call error:", err)
   186  		}
   187  	})
   188  
   189  	t.Run("call command custom variadic parser", func(t *testing.T) {
   190  		ctx.HasPrefix = NewPrefix("!")
   191  		expects := &customParsed{true}
   192  
   193  		if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil {
   194  			t.Fatal("Unexpected call error:", err)
   195  		}
   196  	})
   197  
   198  	t.Run("call command custom trailing manual parser", func(t *testing.T) {
   199  		ctx.HasPrefix = NewPrefix("!")
   200  		expects := ArgumentParts{"arikawa"}
   201  
   202  		if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil {
   203  			t.Fatal("Unexpected call error:", err)
   204  		}
   205  
   206  		if expects.Length() != 1 {
   207  			t.Fatal("Unexpected ArgumentParts length.")
   208  		}
   209  		if expects.After(1)+expects.After(2)+expects.After(-1) != "" {
   210  			t.Fatal("Unexpected ArgumentsParts after.")
   211  		}
   212  		if expects.String() != "arikawa" {
   213  			t.Fatal("Unexpected ArgumentsParts string.")
   214  		}
   215  		if expects.Arg(0) != "arikawa" {
   216  			t.Fatal("Unexpected ArgumentParts arg 0")
   217  		}
   218  		if expects.Arg(1) != "" {
   219  			t.Fatal("Unexpected ArgumentParts arg 1")
   220  		}
   221  	})
   222  
   223  	testMessage := func(content string) error {
   224  		// Mock a messageCreate event
   225  		m := &gateway.MessageCreateEvent{
   226  			Message: discord.Message{
   227  				Content: content,
   228  			},
   229  		}
   230  
   231  		return ctx.callCmd(m)
   232  	}
   233  
   234  	t.Run("call command without args", func(t *testing.T) {
   235  		ctx.HasPrefix = NewPrefix("")
   236  
   237  		if err := testMessage("noArgs"); err.Error() != "passed" {
   238  			t.Fatal("unexpected error:", err)
   239  		}
   240  	})
   241  
   242  	// Test error cases
   243  
   244  	t.Run("call unknown command", func(t *testing.T) {
   245  		ctx.HasPrefix = NewPrefix("joe pls ")
   246  
   247  		err := testMessage("joe pls no")
   248  
   249  		if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") {
   250  			t.Fatal("unexpected error:", err)
   251  		}
   252  	})
   253  
   254  	// Test subcommands
   255  
   256  	t.Run("register subcommand", func(t *testing.T) {
   257  		ctx.HasPrefix = NewPrefix("run ")
   258  
   259  		sub := &testc{}
   260  		ctx.MustRegisterSubcommand(sub)
   261  
   262  		if err := testMessage("run testc noop"); err != nil {
   263  			t.Fatal("Unexpected error:", err)
   264  		}
   265  
   266  		expects := RawArguments("hackadoll no. 3")
   267  
   268  		if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil {
   269  			t.Fatal("Unexpected call error:", err)
   270  		}
   271  
   272  		if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil {
   273  			t.Fatal("Failed to find subcommand Noop")
   274  		}
   275  	})
   276  
   277  	t.Run("register subcommand custom", func(t *testing.T) {
   278  		ctx.MustRegisterSubcommand(&testc{}, "arikawa", "a")
   279  	})
   280  
   281  	t.Run("duplicate subcommand", func(t *testing.T) {
   282  		_, err := ctx.RegisterSubcommand(&testc{}, "arikawa")
   283  		if err := err.Error(); !strings.Contains(err, "duplicate") {
   284  			t.Fatal("Unexpected error:", err)
   285  		}
   286  
   287  		_, err = ctx.RegisterSubcommand(&testc{}, "a")
   288  		if err := err.Error(); !strings.Contains(err, "duplicate") {
   289  			t.Fatal("Unexpected error:", err)
   290  		}
   291  	})
   292  
   293  	t.Run("help", func(t *testing.T) {
   294  		ctx.MustRegisterSubcommand(&testc{}, "helper")
   295  
   296  		h := ctx.Help()
   297  		if h == "" {
   298  			t.Fatal("Empty help?")
   299  		}
   300  
   301  		if strings.Contains(h, "hidden") {
   302  			t.Fatal("Hidden command shown in help.")
   303  		}
   304  
   305  		if !strings.Contains(h, "arikawa/bot test") {
   306  			t.Fatal("Name not found.")
   307  		}
   308  		if !strings.Contains(h, "Just a test.") {
   309  			t.Fatal("Description not found.")
   310  		}
   311  		if !strings.Contains(h, "**a**") {
   312  			t.Fatal("arikawa alias `a' not found.")
   313  		}
   314  	})
   315  
   316  	t.Run("start", func(t *testing.T) {
   317  		cancel := ctx.Start()
   318  		defer cancel()
   319  
   320  		ctx.HasPrefix = NewPrefix("!")
   321  		given.Return = make(chan interface{})
   322  
   323  		ctx.Handler.Call(&gateway.MessageCreateEvent{
   324  			Message: discord.Message{
   325  				Content: "!content hime arikawa best trap",
   326  			},
   327  		})
   328  
   329  		if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" {
   330  			t.Fatal("Unexpected content:", c)
   331  		}
   332  	})
   333  }
   334  
   335  func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) {
   336  	var v interface{}
   337  	if call = sendMsg(ctx, given, &v, content); call != nil {
   338  		return
   339  	}
   340  	if !reflect.DeepEqual(v, expects) {
   341  		return fmt.Errorf("returned argument is invalid: %v", v)
   342  	}
   343  	return nil
   344  }
   345  
   346  func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) {
   347  	// Return channel for testing
   348  	ret := make(chan interface{})
   349  	given.Return = ret
   350  
   351  	// Mock a messageCreate event
   352  	m := &gateway.MessageCreateEvent{
   353  		Message: discord.Message{
   354  			Content: content,
   355  		},
   356  	}
   357  
   358  	var callCh = make(chan error)
   359  	go func() {
   360  		callCh <- ctx.Call(m)
   361  	}()
   362  
   363  	select {
   364  	case arg := <-ret:
   365  		call = <-callCh
   366  		reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg))
   367  		return
   368  
   369  	case call = <-callCh:
   370  		return fmt.Errorf("expected return before error: %w", call)
   371  
   372  	case <-time.After(time.Second):
   373  		return errors.New("timed out while waiting")
   374  	}
   375  }
   376  
   377  func BenchmarkConstructor(b *testing.B) {
   378  	var s = &state.State{
   379  		Cabinet: store.NoopCabinet,
   380  	}
   381  
   382  	for i := 0; i < b.N; i++ {
   383  		_, _ = New(s, &testc{})
   384  	}
   385  }
   386  
   387  func BenchmarkCall(b *testing.B) {
   388  	var given = &testc{}
   389  	var s = &state.State{
   390  		Cabinet: store.NoopCabinet,
   391  	}
   392  
   393  	sub, _ := NewSubcommand(given)
   394  
   395  	var ctx = &Context{
   396  		Subcommand: sub,
   397  		State:      s,
   398  		HasPrefix:  NewPrefix("~"),
   399  		ParseArgs:  DefaultArgsParser(),
   400  	}
   401  
   402  	m := &gateway.MessageCreateEvent{
   403  		Message: discord.Message{
   404  			Content: "~noop",
   405  		},
   406  	}
   407  
   408  	b.ResetTimer()
   409  
   410  	for i := 0; i < b.N; i++ {
   411  		ctx.callCmd(m)
   412  	}
   413  }
   414  
   415  func BenchmarkHelp(b *testing.B) {
   416  	var given = &testc{}
   417  	var s = &state.State{
   418  		Cabinet: store.NoopCabinet,
   419  	}
   420  
   421  	sub, _ := NewSubcommand(given)
   422  
   423  	var ctx = &Context{
   424  		Subcommand: sub,
   425  		State:      s,
   426  		HasPrefix:  NewPrefix("~"),
   427  		ParseArgs:  DefaultArgsParser(),
   428  	}
   429  
   430  	b.ResetTimer()
   431  
   432  	for i := 0; i < b.N; i++ {
   433  		_ = ctx.Help()
   434  	}
   435  }