github.com/starshine-sys/bcr@v0.21.0/ctx_buttons.go (about)

     1  package bcr
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/diamondburned/arikawa/v3/api"
    11  	"github.com/diamondburned/arikawa/v3/discord"
    12  	"github.com/diamondburned/arikawa/v3/gateway"
    13  )
    14  
// ConfirmData is the data for ctx.ConfirmButton()
//
// At least one of Message or Embeds must be set; otherwise ConfirmButton
// returns immediately without sending anything. Zero-valued optional fields
// below are replaced with the documented defaults.
type ConfirmData struct {
	// Message is the text content of the prompt message.
	Message string
	// Embeds are sent along with the prompt message.
	Embeds  []discord.Embed

	// Label of the confirm button. Defaults to "Confirm".
	// Its lowercased form is also accepted as a text reply meaning "yes".
	YesPrompt string
	// Style of the confirm button. Defaults to a primary button (used when zero).
	YesStyle discord.ButtonStyle
	// Label of the cancel button. Defaults to "Cancel".
	// Its lowercased form is also accepted as a text reply meaning "no".
	NoPrompt string
	// Style of the cancel button. Defaults to a secondary button (used when zero).
	NoStyle discord.ButtonStyle

	// How long ConfirmButton waits for an answer. Defaults to one minute.
	Timeout time.Duration
}
    32  
    33  // ConfirmButton confirms a prompt with buttons or "yes"/"no" messages.
    34  func (ctx *Context) ConfirmButton(userID discord.UserID, data ConfirmData) (yes, timeout bool) {
    35  	if data.Message == "" && len(data.Embeds) == 0 {
    36  		return
    37  	}
    38  
    39  	if data.YesPrompt == "" {
    40  		data.YesPrompt = "Confirm"
    41  	}
    42  	if data.YesStyle == 0 {
    43  		data.YesStyle = discord.PrimaryButton
    44  	}
    45  	if data.NoPrompt == "" {
    46  		data.NoPrompt = "Cancel"
    47  	}
    48  	if data.NoStyle == 0 {
    49  		data.NoStyle = discord.SecondaryButton
    50  	}
    51  	if data.Timeout == 0 {
    52  		data.Timeout = time.Minute
    53  	}
    54  
    55  	con, cancel := context.WithTimeout(context.Background(), data.Timeout)
    56  	defer cancel()
    57  
    58  	msg, err := ctx.State.SendMessageComplex(ctx.Message.ChannelID, api.SendMessageData{
    59  		Content: data.Message,
    60  		Embeds:  data.Embeds,
    61  
    62  		Components: []discord.Component{
    63  			discord.ActionRowComponent{
    64  				Components: []discord.Component{
    65  					discord.ButtonComponent{
    66  						Label:    data.YesPrompt,
    67  						Style:    data.YesStyle,
    68  						CustomID: "yes",
    69  					},
    70  					discord.ButtonComponent{
    71  						Label:    data.NoPrompt,
    72  						Style:    data.NoStyle,
    73  						CustomID: "no",
    74  					},
    75  				},
    76  			},
    77  		},
    78  	})
    79  	if err != nil {
    80  		return
    81  	}
    82  
    83  	v := ctx.State.WaitFor(con, func(ev interface{}) bool {
    84  		v, ok := ev.(*gateway.InteractionCreateEvent)
    85  		if ok {
    86  			if v.Message == nil || (v.Member == nil && v.User == nil) {
    87  				return false
    88  			}
    89  
    90  			if v.Message.ID != msg.ID {
    91  				return false
    92  			}
    93  
    94  			var uID discord.UserID
    95  			if v.Member != nil {
    96  				uID = v.Member.User.ID
    97  			} else {
    98  				uID = v.User.ID
    99  			}
   100  
   101  			if uID != userID {
   102  				return false
   103  			}
   104  
   105  			if v.Data.CustomID == "" {
   106  				return false
   107  			}
   108  
   109  			yes = v.Data.CustomID == "yes"
   110  			timeout = false
   111  			return true
   112  		}
   113  
   114  		m, ok := ev.(*gateway.MessageCreateEvent)
   115  		if ok {
   116  			if m.ChannelID != msg.ChannelID || m.Author.ID != ctx.Author.ID {
   117  				return false
   118  			}
   119  
   120  			switch strings.ToLower(m.Content) {
   121  			case "yes", "y", strings.ToLower(data.YesPrompt):
   122  				yes = true
   123  				timeout = false
   124  				return true
   125  			case "no", "n", strings.ToLower(data.NoPrompt):
   126  				yes = false
   127  				timeout = false
   128  				return true
   129  			default:
   130  				return false
   131  			}
   132  		}
   133  
   134  		return false
   135  	})
   136  
   137  	upd := &[]discord.Component{
   138  		discord.ActionRowComponent{
   139  			Components: []discord.Component{
   140  				discord.ButtonComponent{
   141  					Label:    data.YesPrompt,
   142  					Style:    data.YesStyle,
   143  					CustomID: "yes",
   144  					Disabled: true,
   145  				},
   146  				discord.ButtonComponent{
   147  					Label:    data.NoPrompt,
   148  					Style:    data.NoStyle,
   149  					CustomID: "no",
   150  					Disabled: true,
   151  				},
   152  			},
   153  		},
   154  	}
   155  
   156  	ctx.State.EditMessageComplex(msg.ChannelID, msg.ID, api.EditMessageData{
   157  		Components: upd,
   158  	})
   159  
   160  	if v == nil {
   161  		return false, true
   162  	}
   163  
   164  	if ev, ok := v.(*gateway.InteractionCreateEvent); ok {
   165  		ctx.State.RespondInteraction(ev.ID, ev.Token, api.InteractionResponse{
   166  			Type: api.UpdateMessage,
   167  			Data: &api.InteractionResponseData{
   168  				Components: upd,
   169  			},
   170  		})
   171  	}
   172  
   173  	return
   174  }
   175  
   176  type buttonKey struct {
   177  	msg      discord.MessageID
   178  	user     discord.UserID
   179  	customID string
   180  }
   181  
   182  type buttonInfo struct {
   183  	ctx    *Context
   184  	fn     func(*Context, *gateway.InteractionCreateEvent)
   185  	delete bool
   186  }
   187  
   188  // ButtonRemoveFunc is returned by AddButtonHandler
   189  type ButtonRemoveFunc func()
   190  
   191  // AddButtonHandler adds a handler for the given message ID, user ID, and custom ID
   192  func (ctx *Context) AddButtonHandler(
   193  	msg discord.MessageID,
   194  	user discord.UserID,
   195  	customID string,
   196  	del bool,
   197  	fn func(*Context, *gateway.InteractionCreateEvent),
   198  ) ButtonRemoveFunc {
   199  	ctx.Router.buttonMu.Lock()
   200  	defer ctx.Router.buttonMu.Unlock()
   201  
   202  	ctx.Router.buttons[buttonKey{msg, user, customID}] = buttonInfo{ctx, fn, del}
   203  
   204  	return func() {
   205  		ctx.Router.buttonMu.Lock()
   206  		delete(ctx.Router.buttons, buttonKey{msg, user, customID})
   207  		ctx.Router.buttonMu.Unlock()
   208  	}
   209  }
   210  
   211  // ButtonHandler handles buttons added by ctx.AddButtonHandler
   212  func (r *Router) ButtonHandler(ev *gateway.InteractionCreateEvent) {
   213  	if ev.Type != gateway.ButtonInteraction {
   214  		return
   215  	}
   216  
   217  	if ev.Message == nil ||
   218  		(ev.Member == nil && ev.User == nil) ||
   219  		ev.Data == nil {
   220  		return
   221  	}
   222  	if ev.Data.CustomID == "" {
   223  		return
   224  	}
   225  
   226  	var user discord.UserID
   227  	if ev.Member != nil {
   228  		user = ev.Member.User.ID
   229  	} else {
   230  		user = ev.User.ID
   231  	}
   232  
   233  	r.buttonMu.RLock()
   234  	info, ok := r.buttons[buttonKey{ev.Message.ID, user, ev.Data.CustomID}]
   235  	r.buttonMu.RUnlock()
   236  
   237  	if !ok {
   238  		r.slashButton(ev, user)
   239  		return
   240  	}
   241  
   242  	info.fn(info.ctx, ev)
   243  
   244  	if info.delete {
   245  		r.buttonMu.Lock()
   246  		delete(r.buttons, buttonKey{ev.Message.ID, user, ev.Data.CustomID})
   247  		r.buttonMu.Unlock()
   248  	}
   249  }
   250  
   251  func (r *Router) slashButton(ev *gateway.InteractionCreateEvent, user discord.UserID) {
   252  	r.slashButtonMu.RLock()
   253  	info, ok := r.slashButtons[buttonKey{ev.Message.ID, user, ev.Data.CustomID}]
   254  	r.slashButtonMu.RUnlock()
   255  
   256  	if !ok {
   257  		return
   258  	}
   259  
   260  	info.fn(info.ctx, ev)
   261  
   262  	if info.delete {
   263  		r.slashButtonMu.Lock()
   264  		delete(r.slashButtons, buttonKey{ev.Message.ID, user, ev.Data.CustomID})
   265  		r.slashButtonMu.Unlock()
   266  	}
   267  }
   268  
   269  // ButtonPages is like PagedEmbed but uses buttons instead of reactions.
   270  func (ctx *Context) ButtonPages(embeds []discord.Embed, timeout time.Duration) (msg *discord.Message, rmFunc func(), err error) {
   271  	return ctx.ButtonPagesWithComponents(embeds, timeout, nil)
   272  }
   273  
   274  // ButtonPagesWithComponents is like ButtonPages but adds the given components before the buttons used for pagination.
   275  func (ctx *Context) ButtonPagesWithComponents(embeds []discord.Embed, timeout time.Duration, components []discord.Component) (msg *discord.Message, rmFunc func(), err error) {
   276  	rmFunc = func() {}
   277  
   278  	if len(embeds) == 0 {
   279  		return nil, func() {}, errors.New("no embeds")
   280  	}
   281  
   282  	if len(embeds) == 1 {
   283  		msg, err = ctx.State.SendEmbeds(ctx.Message.ChannelID, embeds[0])
   284  		return
   285  	}
   286  
   287  	components = append(components, []discord.Component{discord.ActionRowComponent{
   288  		Components: []discord.Component{
   289  			discord.ButtonComponent{
   290  				Emoji: &discord.ButtonEmoji{
   291  					Name: "⏪",
   292  				},
   293  				Style:    discord.SecondaryButton,
   294  				CustomID: "first",
   295  			},
   296  			discord.ButtonComponent{
   297  				Emoji: &discord.ButtonEmoji{
   298  					Name: "⬅️",
   299  				},
   300  				Style:    discord.SecondaryButton,
   301  				CustomID: "prev",
   302  			},
   303  			discord.ButtonComponent{
   304  				Emoji: &discord.ButtonEmoji{
   305  					Name: "➡️",
   306  				},
   307  				Style:    discord.SecondaryButton,
   308  				CustomID: "next",
   309  			},
   310  			discord.ButtonComponent{
   311  				Emoji: &discord.ButtonEmoji{
   312  					Name: "⏩",
   313  				},
   314  				Style:    discord.SecondaryButton,
   315  				CustomID: "last",
   316  			},
   317  			discord.ButtonComponent{
   318  				Emoji: &discord.ButtonEmoji{
   319  					Name: "❌",
   320  				},
   321  				Style:    discord.SecondaryButton,
   322  				CustomID: "cross",
   323  			},
   324  		},
   325  	}}...)
   326  
   327  	msg, err = ctx.State.SendMessageComplex(ctx.Message.ChannelID, api.SendMessageData{
   328  		Embeds:     []discord.Embed{embeds[0]},
   329  		Components: components,
   330  	})
   331  	if err != nil {
   332  		return
   333  	}
   334  
   335  	page := 0
   336  
   337  	prev := ctx.AddButtonHandler(msg.ID, ctx.Author.ID, "prev", false, func(ctx *Context, ev *gateway.InteractionCreateEvent) {
   338  		if page == 0 {
   339  			page = len(embeds) - 1
   340  		} else {
   341  			page--
   342  		}
   343  
   344  		ctx.State.RespondInteraction(ev.ID, ev.Token, api.InteractionResponse{
   345  			Type: api.UpdateMessage,
   346  			Data: &api.InteractionResponseData{
   347  				Embeds: &[]discord.Embed{embeds[page]},
   348  			},
   349  		})
   350  	})
   351  
   352  	next := ctx.AddButtonHandler(msg.ID, ctx.Author.ID, "next", false, func(ctx *Context, ev *gateway.InteractionCreateEvent) {
   353  		if page >= len(embeds)-1 {
   354  			page = 0
   355  		} else {
   356  			page++
   357  		}
   358  
   359  		ctx.State.RespondInteraction(ev.ID, ev.Token, api.InteractionResponse{
   360  			Type: api.UpdateMessage,
   361  			Data: &api.InteractionResponseData{
   362  				Embeds: &[]discord.Embed{embeds[page]},
   363  			},
   364  		})
   365  	})
   366  
   367  	first := ctx.AddButtonHandler(msg.ID, ctx.Author.ID, "first", false, func(ctx *Context, ev *gateway.InteractionCreateEvent) {
   368  		page = 0
   369  
   370  		ctx.State.RespondInteraction(ev.ID, ev.Token, api.InteractionResponse{
   371  			Type: api.UpdateMessage,
   372  			Data: &api.InteractionResponseData{
   373  				Embeds: &[]discord.Embed{embeds[page]},
   374  			},
   375  		})
   376  	})
   377  
   378  	last := ctx.AddButtonHandler(msg.ID, ctx.Author.ID, "last", false, func(ctx *Context, ev *gateway.InteractionCreateEvent) {
   379  		page = len(embeds) - 1
   380  
   381  		ctx.State.RespondInteraction(ev.ID, ev.Token, api.InteractionResponse{
   382  			Type: api.UpdateMessage,
   383  			Data: &api.InteractionResponseData{
   384  				Embeds: &[]discord.Embed{embeds[page]},
   385  			},
   386  		})
   387  	})
   388  
   389  	var o sync.Once
   390  
   391  	cross := ctx.AddButtonHandler(msg.ID, ctx.Author.ID, "cross", false, func(ctx *Context, ev *gateway.InteractionCreateEvent) {
   392  		ctx.State.EditMessageComplex(msg.ChannelID, msg.ID, api.EditMessageData{
   393  			Components: &[]discord.Component{},
   394  		})
   395  	})
   396  
   397  	rmFunc = func() {
   398  		o.Do(func() {
   399  			ctx.State.EditMessageComplex(msg.ChannelID, msg.ID, api.EditMessageData{
   400  				Components: &[]discord.Component{},
   401  			})
   402  
   403  			prev()
   404  			next()
   405  			first()
   406  			last()
   407  			cross()
   408  		})
   409  	}
   410  
   411  	time.AfterFunc(timeout, rmFunc)
   412  	return msg, rmFunc, err
   413  }