github.com/diamondburned/arikawa@v1.3.14/bot/ctx_call.go (about)

     1  package bot
     2  
     3  import (
     4  	"reflect"
     5  	"strings"
     6  
     7  	"github.com/diamondburned/arikawa/api"
     8  	"github.com/diamondburned/arikawa/discord"
     9  	"github.com/diamondburned/arikawa/gateway"
    10  	"github.com/pkg/errors"
    11  )
    12  
// Break is a sentinel error that middlewares can return to stop the chain of
// execution without reporting a failure: errNoBreak turns it into a nil error
// before it is handed back to the caller.
var Break = errors.New("break middleware chain, non-fatal")
    16  
    17  // filterEventType filters all commands and subcommands into a 2D slice,
    18  // structured so that a Break would only exit out the nested slice.
    19  func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
    20  	// Find the main context first.
    21  	callers = append(callers, ctx.eventCallers(evT))
    22  
    23  	for _, sub := range ctx.subcommands {
    24  		// Find subcommands second.
    25  		callers = append(callers, sub.eventCallers(evT))
    26  	}
    27  
    28  	return
    29  }
    30  
    31  func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
    32  	evV := reflect.ValueOf(ev)
    33  	evT := evV.Type()
    34  
    35  	var callers [][]caller
    36  
    37  	// Hit the cache
    38  	t, ok := ctx.typeCache.Load(evT)
    39  	if ok {
    40  		callers = t.([][]caller)
    41  	} else {
    42  		callers = ctx.filterEventType(evT)
    43  		ctx.typeCache.Store(evT, callers)
    44  	}
    45  
    46  	for _, subcallers := range callers {
    47  		for _, c := range subcallers {
    48  			_, err := c.call(evV)
    49  			if err != nil {
    50  				// Only count as an error if it's not Break.
    51  				if err = errNoBreak(err); err != nil {
    52  					bottomError = err
    53  				}
    54  
    55  				// Break the caller loop only for this subcommand.
    56  				break
    57  			}
    58  		}
    59  	}
    60  
    61  	var msc *gateway.MessageCreateEvent
    62  
    63  	// We call the messages later, since we want MessageCreate middlewares to
    64  	// run as well.
    65  	switch {
    66  	case evT == typeMessageCreate:
    67  		msc = ev.(*gateway.MessageCreateEvent)
    68  
    69  	case evT == typeMessageUpdate && ctx.EditableCommands:
    70  		up := ev.(*gateway.MessageUpdateEvent)
    71  		// Message updates could have empty contents when only their embeds are
    72  		// filled. We don't need that here.
    73  		if up.Content == "" {
    74  			return nil
    75  		}
    76  
    77  		// Query the updated message.
    78  		m, err := ctx.Store.Message(up.ChannelID, up.ID)
    79  		if err != nil {
    80  			// It's probably safe to ignore this.
    81  			return nil
    82  		}
    83  
    84  		// Treat the message update as a message create event to avoid breaking
    85  		// changes.
    86  		msc = &gateway.MessageCreateEvent{Message: *m, Member: up.Member}
    87  
    88  		// Fill up member, if available.
    89  		if m.GuildID.IsValid() && up.Member == nil {
    90  			if mem, err := ctx.Store.Member(m.GuildID, m.Author.ID); err == nil {
    91  				msc.Member = mem
    92  			}
    93  		}
    94  
    95  		// Update the reflect value as well.
    96  		evV = reflect.ValueOf(msc)
    97  
    98  	default:
    99  		// Unknown event, return.
   100  		return nil
   101  	}
   102  
   103  	// There's no need for an errNoBreak here, as the method already checked
   104  	// for that.
   105  	return ctx.callMessageCreate(msc, evV)
   106  }
   107  
   108  func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error {
   109  	// check if bot
   110  	if !ctx.AllowBot && mc.Author.Bot {
   111  		return nil
   112  	}
   113  
   114  	// check if prefix
   115  	pf, ok := ctx.HasPrefix(mc)
   116  	if !ok {
   117  		return nil
   118  	}
   119  
   120  	// trim the prefix before splitting, this way multi-words prefixes work
   121  	content := mc.Content[len(pf):]
   122  
   123  	if content == "" {
   124  		return nil // just the prefix only
   125  	}
   126  
   127  	// parse arguments
   128  	parts, parseErr := ctx.ParseArgs(content)
   129  	// We're not checking parse errors yet, as raw arguments may be able to
   130  	// ignore it.
   131  
   132  	if len(parts) == 0 {
   133  		return parseErr
   134  	}
   135  
   136  	// Find the command and subcommand.
   137  	arguments, cmd, sub, err := ctx.findCommand(parts)
   138  	if err != nil {
   139  		return errNoBreak(err)
   140  	}
   141  
   142  	// We don't run the subcommand's middlewares here, as the callCmd function
   143  	// already handles that.
   144  
   145  	// Run command middlewares.
   146  	if err := cmd.walkMiddlewares(value); err != nil {
   147  		return errNoBreak(err)
   148  	}
   149  
   150  	// Start converting
   151  	var argv []reflect.Value
   152  	var argc int
   153  
   154  	// the last argument in the list, not used until set
   155  	var last Argument
   156  
   157  	// Here's an edge case: when the handler takes no arguments, we allow that
   158  	// anyway, as they might've used the raw content.
   159  	if len(cmd.Arguments) == 0 {
   160  		goto Call
   161  	}
   162  
   163  	// Argument count check.
   164  	if argdelta := len(arguments) - len(cmd.Arguments); argdelta != 0 {
   165  		var err error // no err if nil
   166  
   167  		// If the function is variadic, then we can allow the last argument to
   168  		// be empty.
   169  		if cmd.Variadic {
   170  			argdelta++
   171  		}
   172  
   173  		switch {
   174  		// If there aren't enough arguments given.
   175  		case argdelta < 0:
   176  			err = ErrNotEnoughArgs
   177  
   178  		// If there are too many arguments, then check if the command supports
   179  		// variadic arguments. We already did a length check above.
   180  		case argdelta > 0 && !cmd.Variadic:
   181  			// If it's not variadic, then we can't accept it.
   182  			err = ErrTooManyArgs
   183  		}
   184  
   185  		if err != nil {
   186  			return &ErrInvalidUsage{
   187  				Prefix: pf,
   188  				Args:   parts,
   189  				Index:  len(parts) - 1,
   190  				Wrap:   err,
   191  				Ctx:    cmd,
   192  			}
   193  		}
   194  	}
   195  
   196  	// The last argument in the arguments slice.
   197  	last = cmd.Arguments[len(cmd.Arguments)-1]
   198  
   199  	// Allocate a new slice the length of function arguments.
   200  	argc = len(cmd.Arguments) - 1         // arg len without last
   201  	argv = make([]reflect.Value, 0, argc) // could be 0
   202  
   203  	// Parse all arguments except for the last one.
   204  	for i := 0; i < argc; i++ {
   205  		v, err := cmd.Arguments[i].fn(arguments[0])
   206  		if err != nil {
   207  			return &ErrInvalidUsage{
   208  				Prefix: pf,
   209  				Args:   parts,
   210  				Index:  len(parts) - len(arguments) + i,
   211  				Wrap:   err,
   212  				Ctx:    cmd,
   213  			}
   214  		}
   215  
   216  		// Pop arguments.
   217  		arguments = arguments[1:]
   218  		argv = append(argv, v)
   219  	}
   220  
   221  	// Is this last argument actually a variadic slice? If yes, then it
   222  	// should still have fn normally.
   223  	if last.fn != nil {
   224  		// Allocate a new slice to append into.
   225  		vars := make([]reflect.Value, 0, len(arguments))
   226  
   227  		// Parse the rest with variadic arguments. Go's reflect states that
   228  		// variadic parameters will automatically be copied, which is good.
   229  		for i := 0; len(arguments) > 0; i++ {
   230  			v, err := last.fn(arguments[0])
   231  			if err != nil {
   232  				return &ErrInvalidUsage{
   233  					Prefix: pf,
   234  					Args:   parts,
   235  					Index:  len(parts) - len(arguments) + i,
   236  					Wrap:   err,
   237  					Ctx:    cmd,
   238  				}
   239  			}
   240  
   241  			arguments = arguments[1:]
   242  			vars = append(vars, v)
   243  		}
   244  
   245  		argv = append(argv, vars...)
   246  
   247  	} else {
   248  		// Create a zero value instance of this:
   249  		v := reflect.New(last.rtype)
   250  		var err error // return error
   251  
   252  		switch {
   253  		// If the argument wants all arguments:
   254  		case last.manual != nil:
   255  			// Call the manual parse method:
   256  			_, err = callWith(last.manual.Func, v, reflect.ValueOf(arguments))
   257  
   258  		// If the argument wants all arguments in string:
   259  		case last.custom != nil:
   260  			// Ignore parser errors. This allows custom commands sliced away to
   261  			// have erroneous hanging quotes.
   262  			parseErr = nil
   263  
   264  			// Manual string seeking is a must here. This is because the string
   265  			// could contain multiple whitespaces, and the parser would not
   266  			// count them.
   267  			var seekTo = cmd.Command
   268  			// We can't rely on the plumbing behavior.
   269  			if sub.plumbed != nil {
   270  				seekTo = sub.Command
   271  			}
   272  
   273  			// Seek to the string.
   274  			var i = strings.Index(content, seekTo)
   275  			// Edge case if the subcommand is the same as the command.
   276  			if cmd.Command == sub.Command {
   277  				// Seek again past the command.
   278  				i = strings.Index(content[i+len(seekTo):], seekTo)
   279  			}
   280  
   281  			if i > -1 {
   282  				// Seek past the substring.
   283  				i += len(seekTo)
   284  
   285  				content = strings.TrimSpace(content[i:])
   286  			}
   287  
   288  			// Call the method with the raw unparsed command:
   289  			_, err = callWith(last.custom.Func, v, reflect.ValueOf(content))
   290  		}
   291  
   292  		// Check the returned error:
   293  		if err != nil {
   294  			return err
   295  		}
   296  
   297  		// Check if the argument wants a non-pointer:
   298  		if last.pointer {
   299  			v = v.Elem()
   300  		}
   301  
   302  		// Add the argument into argv.
   303  		argv = append(argv, v)
   304  	}
   305  
   306  	// Check for parsing errors after parsing arguments.
   307  	if parseErr != nil {
   308  		return parseErr
   309  	}
   310  
   311  Call:
   312  	// call the function and parse the error return value
   313  	v, err := cmd.call(value, argv...)
   314  	if err != nil {
   315  		return err
   316  	}
   317  
   318  	switch v := v.(type) {
   319  	case string:
   320  		v = sub.SanitizeMessage(v)
   321  		_, err = ctx.SendMessage(mc.ChannelID, v, nil)
   322  	case *discord.Embed:
   323  		_, err = ctx.SendMessage(mc.ChannelID, "", v)
   324  	case *api.SendMessageData:
   325  		if v.Content != "" {
   326  			v.Content = sub.SanitizeMessage(v.Content)
   327  		}
   328  		_, err = ctx.SendMessageComplex(mc.ChannelID, *v)
   329  	}
   330  
   331  	return err
   332  }
   333  
   334  // findCommand filters.
   335  func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) {
   336  	// Main command entrypoint cannot have plumb.
   337  	for _, c := range ctx.Commands {
   338  		if c.Command == parts[0] {
   339  			return parts[1:], c, ctx.Subcommand, nil
   340  		}
   341  		// Check for alias
   342  		for _, alias := range c.Aliases {
   343  			if alias == parts[0] {
   344  				return parts[1:], c, ctx.Subcommand, nil
   345  			}
   346  		}
   347  	}
   348  
   349  	// Can't find the command, look for subcommands if len(args) has a 2nd
   350  	// entry.
   351  	for _, s := range ctx.subcommands {
   352  		if s.Command != parts[0] {
   353  			continue
   354  		}
   355  
   356  		// Only actually plumb if we actually have a plumbed handler AND
   357  		//    1. We only have one command handler OR
   358  		//    2. We only have the subcommand name but no command.
   359  		if s.plumbed != nil && (len(s.Commands) == 1 || len(parts) <= 2) {
   360  			return parts[1:], s.plumbed, s, nil
   361  		}
   362  
   363  		if len(parts) >= 2 {
   364  			for _, c := range s.Commands {
   365  				if c.Command == parts[1] {
   366  					return parts[2:], c, s, nil
   367  				}
   368  				// Check for aliases
   369  				for _, alias := range c.Aliases {
   370  					if alias == parts[1] {
   371  						return parts[2:], c, s, nil
   372  					}
   373  				}
   374  			}
   375  		}
   376  
   377  		// If unknown command is disabled or the subcommand is hidden:
   378  		if ctx.SilentUnknown.Subcommand || s.Hidden {
   379  			return nil, nil, nil, Break
   380  		}
   381  
   382  		return nil, nil, nil, &ErrUnknownCommand{
   383  			Parts:  parts,
   384  			Subcmd: s,
   385  		}
   386  	}
   387  
   388  	if ctx.SilentUnknown.Command {
   389  		return nil, nil, nil, Break
   390  	}
   391  
   392  	return nil, nil, nil, &ErrUnknownCommand{
   393  		Parts:  parts,
   394  		Subcmd: ctx.Subcommand,
   395  	}
   396  }
   397  
   398  func errNoBreak(err error) error {
   399  	if errors.Is(err, Break) {
   400  		return nil
   401  	}
   402  	return err
   403  }