github.com/diamondburned/arikawa/v2@v2.1.0/bot/ctx_call.go (about)

     1  package bot
     2  
import (
	"reflect"
	"strings"
	"unicode"

	"github.com/diamondburned/arikawa/v2/api"
	"github.com/diamondburned/arikawa/v2/discord"
	"github.com/diamondburned/arikawa/v2/gateway"
	"github.com/diamondburned/arikawa/v2/utils/json/option"
	"github.com/pkg/errors"
)
    13  
    14  // Break is a non-fatal error that could be returned from middlewares to stop
    15  // the chain of execution.
    16  var Break = errors.New("break middleware chain, non-fatal")
    17  
    18  // filterEventType filters all commands and subcommands into a 2D slice,
    19  // structured so that a Break would only exit out the nested slice.
    20  func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
    21  	// Find the main context first.
    22  	callers = append(callers, ctx.eventCallers(evT))
    23  
    24  	for _, sub := range ctx.subcommands {
    25  		// Find subcommands second.
    26  		callers = append(callers, sub.eventCallers(evT))
    27  	}
    28  
    29  	return
    30  }
    31  
    32  func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
    33  	evV := reflect.ValueOf(ev)
    34  	evT := evV.Type()
    35  
    36  	var callers [][]caller
    37  
    38  	// Hit the cache
    39  	t, ok := ctx.typeCache.Load(evT)
    40  	if ok {
    41  		callers = t.([][]caller)
    42  	} else {
    43  		callers = ctx.filterEventType(evT)
    44  		ctx.typeCache.Store(evT, callers)
    45  	}
    46  
    47  	for _, subcallers := range callers {
    48  		for _, c := range subcallers {
    49  			_, err := c.call(evV)
    50  			if err != nil {
    51  				// Only count as an error if it's not Break.
    52  				if err = errNoBreak(err); err != nil {
    53  					bottomError = err
    54  				}
    55  
    56  				// Break the caller loop only for this subcommand.
    57  				break
    58  			}
    59  		}
    60  	}
    61  
    62  	var msc *gateway.MessageCreateEvent
    63  
    64  	// We call the messages later, since we want MessageCreate middlewares to
    65  	// run as well.
    66  	switch {
    67  	case evT == typeMessageCreate:
    68  		msc = ev.(*gateway.MessageCreateEvent)
    69  
    70  	case evT == typeMessageUpdate && ctx.EditableCommands:
    71  		up := ev.(*gateway.MessageUpdateEvent)
    72  		// Message updates could have empty contents when only their embeds are
    73  		// filled. We don't need that here.
    74  		if up.Content == "" {
    75  			return nil
    76  		}
    77  
    78  		// Query the updated message.
    79  		m, err := ctx.Cabinet.Message(up.ChannelID, up.ID)
    80  		if err != nil {
    81  			// It's probably safe to ignore this.
    82  			return nil
    83  		}
    84  
    85  		// Treat the message update as a message create event to avoid breaking
    86  		// changes.
    87  		msc = &gateway.MessageCreateEvent{Message: *m, Member: up.Member}
    88  
    89  		// Fill up member, if available.
    90  		if m.GuildID.IsValid() && up.Member == nil {
    91  			if mem, err := ctx.Cabinet.Member(m.GuildID, m.Author.ID); err == nil {
    92  				msc.Member = mem
    93  			}
    94  		}
    95  
    96  		// Update the reflect value as well.
    97  		evV = reflect.ValueOf(msc)
    98  
    99  	default:
   100  		// Unknown event, return.
   101  		return nil
   102  	}
   103  
   104  	// There's no need for an errNoBreak here, as the method already checked
   105  	// for that.
   106  	return ctx.callMessageCreate(msc, evV)
   107  }
   108  
   109  func (ctx *Context) callMessageCreate(
   110  	mc *gateway.MessageCreateEvent, value reflect.Value) error {
   111  
   112  	v, err := ctx.callMessageCreateNoReply(mc, value)
   113  	if err == nil && v == nil {
   114  		return nil
   115  	}
   116  
   117  	if err != nil && !ctx.ReplyError && ctx.ErrorReplier == nil {
   118  		return err
   119  	}
   120  
   121  	var data api.SendMessageData
   122  
   123  	if err != nil {
   124  		if ctx.ErrorReplier != nil {
   125  			data = ctx.ErrorReplier(err, mc)
   126  		} else {
   127  			data.Content = ctx.FormatError(err)
   128  		}
   129  	} else {
   130  		switch v := v.(type) {
   131  		case string:
   132  			data.Content = v
   133  		case *discord.Embed:
   134  			data.Embed = v
   135  		case *api.SendMessageData:
   136  			data = *v
   137  		default:
   138  			return nil
   139  		}
   140  	}
   141  
   142  	if data.Reference == nil {
   143  		data.Reference = &discord.MessageReference{MessageID: mc.ID}
   144  	}
   145  
   146  	if data.AllowedMentions == nil {
   147  		// Do not mention on reply by default. Only allow author mentions.
   148  		data.AllowedMentions = &api.AllowedMentions{
   149  			Users:       []discord.UserID{mc.Author.ID},
   150  			RepliedUser: option.False,
   151  		}
   152  	}
   153  
   154  	_, err = ctx.SendMessageComplex(mc.ChannelID, data)
   155  	return err
   156  }
   157  
   158  func (ctx *Context) callMessageCreateNoReply(
   159  	mc *gateway.MessageCreateEvent, value reflect.Value) (interface{}, error) {
   160  
   161  	// check if bot
   162  	if !ctx.AllowBot && mc.Author.Bot {
   163  		return nil, nil
   164  	}
   165  
   166  	// check if prefix
   167  	pf, ok := ctx.HasPrefix(mc)
   168  	if !ok {
   169  		return nil, nil
   170  	}
   171  
   172  	// trim the prefix before splitting, this way multi-words prefixes work
   173  	content := mc.Content[len(pf):]
   174  
   175  	if content == "" {
   176  		return nil, nil // just the prefix only
   177  	}
   178  
   179  	// parse arguments
   180  	parts, parseErr := ctx.ParseArgs(content)
   181  	// We're not checking parse errors yet, as raw arguments may be able to
   182  	// ignore it.
   183  
   184  	if len(parts) == 0 {
   185  		return nil, parseErr
   186  	}
   187  
   188  	// Find the command and subcommand.
   189  	commandCtx, err := ctx.findCommand(parts)
   190  	if err != nil {
   191  		return nil, errNoBreak(err)
   192  	}
   193  
   194  	var (
   195  		arguments = commandCtx.parts
   196  		cmd       = commandCtx.method
   197  		sub       = commandCtx.subcmd
   198  		plumbed   = commandCtx.plumbed
   199  	)
   200  
   201  	// We don't run the subcommand's middlewares here, as the callCmd function
   202  	// already handles that.
   203  
   204  	// Run command middlewares.
   205  	if err := cmd.walkMiddlewares(value); err != nil {
   206  		return nil, errNoBreak(err)
   207  	}
   208  
   209  	// Start converting
   210  	var argv []reflect.Value
   211  	var argc int
   212  
   213  	// the last argument in the list, not used until set
   214  	var last Argument
   215  
   216  	// Here's an edge case: when the handler takes no arguments, we allow that
   217  	// anyway, as they might've used the raw content.
   218  	if len(cmd.Arguments) == 0 {
   219  		return cmd.call(value, argv...)
   220  	}
   221  
   222  	// Argument count check.
   223  	if argdelta := len(arguments) - len(cmd.Arguments); argdelta != 0 {
   224  		var err error // no err if nil
   225  
   226  		// If the function is variadic, then we can allow the last argument to
   227  		// be empty.
   228  		if cmd.Variadic {
   229  			argdelta++
   230  		}
   231  
   232  		switch {
   233  		// If there aren't enough arguments given.
   234  		case argdelta < 0:
   235  			err = ErrNotEnoughArgs
   236  
   237  		// If there are too many arguments, then check if the command supports
   238  		// variadic arguments. We already did a length check above.
   239  		case argdelta > 0 && !cmd.Variadic:
   240  			// If it's not variadic, then we can't accept it.
   241  			err = ErrTooManyArgs
   242  		}
   243  
   244  		if err != nil {
   245  			return nil, &ErrInvalidUsage{
   246  				Prefix: pf,
   247  				Args:   parts,
   248  				Index:  len(parts) - 1,
   249  				Wrap:   err,
   250  				Ctx:    cmd,
   251  			}
   252  		}
   253  	}
   254  
   255  	// The last argument in the arguments slice.
   256  	last = cmd.Arguments[len(cmd.Arguments)-1]
   257  
   258  	// Allocate a new slice the length of function arguments.
   259  	argc = len(cmd.Arguments) - 1         // arg len without last
   260  	argv = make([]reflect.Value, 0, argc) // could be 0
   261  
   262  	// Parse all arguments except for the last one.
   263  	for i := 0; i < argc; i++ {
   264  		v, err := cmd.Arguments[i].fn(arguments[0])
   265  		if err != nil {
   266  			return nil, &ErrInvalidUsage{
   267  				Prefix: pf,
   268  				Args:   parts,
   269  				Index:  len(parts) - len(arguments) + i,
   270  				Wrap:   err,
   271  				Ctx:    cmd,
   272  			}
   273  		}
   274  
   275  		// Pop arguments.
   276  		arguments = arguments[1:]
   277  		argv = append(argv, v)
   278  	}
   279  
   280  	// Is this last argument actually a variadic slice? If yes, then it
   281  	// should still have fn normally.
   282  	if last.fn != nil {
   283  		// Allocate a new slice to append into.
   284  		vars := make([]reflect.Value, 0, len(arguments))
   285  
   286  		// Parse the rest with variadic arguments. Go's reflect states that
   287  		// variadic parameters will automatically be copied, which is good.
   288  		for i := 0; len(arguments) > 0; i++ {
   289  			v, err := last.fn(arguments[0])
   290  			if err != nil {
   291  				return nil, &ErrInvalidUsage{
   292  					Prefix: pf,
   293  					Args:   parts,
   294  					Index:  len(parts) - len(arguments) + i,
   295  					Wrap:   err,
   296  					Ctx:    cmd,
   297  				}
   298  			}
   299  
   300  			arguments = arguments[1:]
   301  			vars = append(vars, v)
   302  		}
   303  
   304  		argv = append(argv, vars...)
   305  
   306  	} else {
   307  		// Create a zero value instance of this:
   308  		v := reflect.New(last.rtype)
   309  		var err error // return nil, error
   310  
   311  		switch {
   312  		// If the argument wants all arguments:
   313  		case last.manual != nil:
   314  			// Call the manual parse method:
   315  			err = last.manual(v.Interface().(ManualParser), arguments)
   316  
   317  		// If the argument wants all arguments in string:
   318  		case last.custom != nil:
   319  			// Ignore parser errors. This allows custom commands sliced away to
   320  			// have erroneous hanging quotes.
   321  			parseErr = nil
   322  
   323  			content = trimPrefixStringAndSlice(content, sub.Command, sub.Aliases)
   324  
   325  			// If the current command is not the plumbed command, then we can
   326  			// keep trimming. We have to check for this, as a plumbed subcommand
   327  			// may return nil, other non-plumbed commands.
   328  			if !plumbed {
   329  				content = trimPrefixStringAndSlice(content, cmd.Command, cmd.Aliases)
   330  			}
   331  
   332  			// Call the method with the raw unparsed command:
   333  			err = last.custom(v.Interface().(CustomParser), content)
   334  		}
   335  
   336  		// Check the returned error:
   337  		if err != nil {
   338  			return nil, err
   339  		}
   340  
   341  		// Check if the argument wants a non-pointer:
   342  		if last.pointer {
   343  			v = v.Elem()
   344  		}
   345  
   346  		// Add the argument into argv.
   347  		argv = append(argv, v)
   348  	}
   349  
   350  	// Check for parsing errors after parsing arguments.
   351  	if parseErr != nil {
   352  		return nil, parseErr
   353  	}
   354  
   355  	return cmd.call(value, argv...)
   356  }
   357  
   358  // commandContext contains related command values to call one. It is returned
   359  // from findCommand.
   360  type commandContext struct {
   361  	parts   []string
   362  	plumbed bool
   363  	method  *MethodContext
   364  	subcmd  *Subcommand
   365  }
   366  
   367  var emptyCommand = commandContext{}
   368  
   369  // findCommand filters.
   370  func (ctx *Context) findCommand(parts []string) (commandContext, error) {
   371  	// Main command entrypoint cannot have plumb.
   372  	for _, c := range ctx.Commands {
   373  		if searchStringAndSlice(parts[0], c.Command, c.Aliases) {
   374  			return commandContext{parts[1:], false, c, ctx.Subcommand}, nil
   375  		}
   376  	}
   377  
   378  	// Can't find the command, look for subcommands if len(args) has a 2nd
   379  	// entry.
   380  	for _, s := range ctx.subcommands {
   381  		if !searchStringAndSlice(parts[0], s.Command, s.Aliases) {
   382  			continue
   383  		}
   384  
   385  		// The new plumbing behavior allows other commands to co-exist with a
   386  		// plumbed command. Those commands will override the second argument,
   387  		// similarly to a non-plumbed command.
   388  
   389  		if len(parts) >= 2 {
   390  			for _, c := range s.Commands {
   391  				if searchStringAndSlice(parts[1], c.Command, c.Aliases) {
   392  					return commandContext{parts[2:], false, c, s}, nil
   393  				}
   394  			}
   395  		}
   396  
   397  		if s.IsPlumbed() {
   398  			return commandContext{parts[1:], true, s.plumbed, s}, nil
   399  		}
   400  
   401  		// If unknown command is disabled or the subcommand is hidden:
   402  		if ctx.SilentUnknown.Subcommand || s.Hidden {
   403  			return emptyCommand, Break
   404  		}
   405  
   406  		return emptyCommand, newErrUnknownCommand(s, parts)
   407  	}
   408  
   409  	if ctx.SilentUnknown.Command {
   410  		return emptyCommand, Break
   411  	}
   412  
   413  	return emptyCommand, newErrUnknownCommand(ctx.Subcommand, parts)
   414  }
   415  
   416  // searchStringAndSlice searches if str is equal to isString or any of the given
   417  // otherStrings. It is used for alias matching.
   418  func searchStringAndSlice(str string, isString string, otherStrings []string) bool {
   419  	if str == isString {
   420  		return true
   421  	}
   422  
   423  	for _, other := range otherStrings {
   424  		if other == str {
   425  			return true
   426  		}
   427  	}
   428  
   429  	return false
   430  }
   431  
   432  // trimPrefixStringAndSlice behaves similarly to searchStringAndSlice, but it
   433  // trims the prefix and the surrounding spaces after a match.
   434  func trimPrefixStringAndSlice(str string, prefix string, prefixes []string) string {
   435  	if strings.HasPrefix(str, prefix) {
   436  		return strings.TrimSpace(str[len(prefix):])
   437  	}
   438  
   439  	for _, prefix := range prefixes {
   440  		if strings.HasPrefix(str, prefix) {
   441  			return strings.TrimSpace(str[len(prefix):])
   442  		}
   443  	}
   444  
   445  	return str
   446  }
   447  
   448  func errNoBreak(err error) error {
   449  	if errors.Is(err, Break) {
   450  		return nil
   451  	}
   452  	return err
   453  }