github.com/diamondburned/arikawa@v1.3.14/state/state.go (about)

     1  // Package state provides interfaces for a local or remote state, as well as
     2  // abstractions around the REST API and Gateway events.
     3  package state
     4  
     5  import (
     6  	"context"
     7  	"sync"
     8  
     9  	"github.com/diamondburned/arikawa/discord"
    10  	"github.com/diamondburned/arikawa/gateway"
    11  	"github.com/diamondburned/arikawa/internal/moreatomic"
    12  	"github.com/diamondburned/arikawa/session"
    13  	"github.com/diamondburned/arikawa/utils/handler"
    14  
    15  	"github.com/pkg/errors"
    16  )
    17  
var (
	// MaxFetchMembers is the limit that State.Members passes to
	// Session.Members when the store cannot provide the member list.
	MaxFetchMembers uint = 1000
	// MaxFetchGuilds is the limit that State.Guilds passes to Session.Guilds
	// when the store cannot provide the guild list.
	MaxFetchGuilds  uint = 10
)
    22  
    23  // State is the cache to store events coming from Discord as well as data from
    24  // API calls.
    25  //
    26  // Store
    27  //
    28  // The state basically provides abstractions on top of the API and the state
    29  // storage (Store). The state storage is effectively a set of interfaces which
    30  // allow arbitrary backends to be implemented.
    31  //
    32  // The default storage backend is a typical in-memory structure consisting of
    33  // maps and slices. Custom backend implementations could embed this storage
    34  // backend as an in-memory fallback. A good example of this would be embedding
    35  // the default store for messages only, while handling everything else in Redis.
    36  //
    37  // The package also provides a no-op store (NoopStore) that implementations
    38  // could embed. This no-op store will always return an error, which makes the
    39  // state fetch information from the API. The setters are all no-ops, so the
    40  // fetched data won't be updated.
    41  //
    42  // Handler
    43  //
    44  // The state uses its own handler over session's to make all handlers run after
    45  // the state updates itself. A PreHandler is exposed in any case the user needs
    46  // the handlers to run before the state updates itself. Refer to that field's
    47  // documentation.
    48  //
    49  // The state also provides extra events and overrides to make up for Discord's
    50  // inconsistencies in data. The following are known instances of such.
    51  //
    52  // The Guild Create event is split up to make the state's Guild Available, Guild
    53  // Ready and Guild Join events. Refer to these events' documentations for more
    54  // information.
    55  //
    56  // The Message Create and Message Update events with the Member field provided
    57  // will have the User field copied from Author. This is because the User field
    58  // will be empty, while the Member structure expects it to be there.
type State struct {
	// Session is the embedded session. The State getters call it to fetch
	// data from the API whenever the Store returns an error.
	*session.Session
	// Store is the storage backend. It is consulted first by the State
	// getters and is updated with whatever they fetch from the API.
	Store

	// *: State doesn't actually keep track of pinned messages.

	// Ready is not updated by the state.
	Ready gateway.ReadyEvent

	// StateLog logs all errors that come from the state cache. This includes
	// not found errors. Defaults to a no-op, as state errors aren't that
	// important.
	StateLog func(error)

	// PreHandler is the manual hook that is executed before the State handler
	// is. This should only be used for low-level operations.
	// It's recommended to set Synchronous to true if you mutate the events.
	PreHandler *handler.Handler // default nil

	// Command handler with inherited methods. Ran after PreHandler. You should
	// most of the time use this instead of Session's, to avoid race conditions
	// with the State.
	*handler.Handler

	// List of channels with few messages, so it doesn't bother hitting the API
	// again. Guarded by fewMutex.
	fewMessages map[discord.ChannelID]struct{}
	fewMutex    *sync.Mutex

	// unavailableGuilds is a set of discord.GuildIDs of guilds that became
	// unavailable when already connected to the gateway, i.e. sent in a
	// GuildUnavailableEvent.
	unavailableGuilds *moreatomic.GuildIDSet
	// unreadyGuilds is a set of discord.GuildIDs of guilds that were
	// unavailable when connecting to the gateway, i.e. they had Unavailable
	// set to true during Ready.
	unreadyGuilds *moreatomic.GuildIDSet
}
    97  
    98  // New creates a new state.
    99  func New(token string) (*State, error) {
   100  	return NewWithStore(token, NewDefaultStore(nil))
   101  }
   102  
   103  // NewWithIntents creates a new state with the given gateway intents. For more
   104  // information, refer to gateway.Intents.
   105  func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
   106  	s, err := session.NewWithIntents(token, intents...)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	return NewFromSession(s, NewDefaultStore(nil))
   112  }
   113  
   114  func NewWithStore(token string, store Store) (*State, error) {
   115  	s, err := session.New(token)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	return NewFromSession(s, store)
   121  }
   122  
   123  // NewFromSession never returns an error. This API is kept for backwards
   124  // compatibility.
   125  func NewFromSession(s *session.Session, store Store) (*State, error) {
   126  	state := &State{
   127  		Session:           s,
   128  		Store:             store,
   129  		Handler:           handler.New(),
   130  		StateLog:          func(err error) {},
   131  		fewMessages:       map[discord.ChannelID]struct{}{},
   132  		fewMutex:          new(sync.Mutex),
   133  		unavailableGuilds: moreatomic.NewGuildIDSet(),
   134  		unreadyGuilds:     moreatomic.NewGuildIDSet(),
   135  	}
   136  	state.hookSession()
   137  	return state, nil
   138  }
   139  
   140  // WithContext returns a shallow copy of State with the context replaced in the
   141  // API client. All methods called on the State will use this given context. This
   142  // method is thread-safe.
   143  func (s *State) WithContext(ctx context.Context) *State {
   144  	copied := *s
   145  	copied.Client = copied.Client.WithContext(ctx)
   146  
   147  	return &copied
   148  }
   149  
   150  //// Helper methods
   151  
   152  func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {
   153  	if !message.GuildID.IsValid() {
   154  		return message.Author.Username
   155  	}
   156  
   157  	if message.Member != nil {
   158  		if message.Member.Nick != "" {
   159  			return message.Member.Nick
   160  		}
   161  		return message.Author.Username
   162  	}
   163  
   164  	n, err := s.MemberDisplayName(message.GuildID, message.Author.ID)
   165  	if err != nil {
   166  		return message.Author.Username
   167  	}
   168  
   169  	return n
   170  }
   171  
   172  func (s *State) MemberDisplayName(guildID discord.GuildID, userID discord.UserID) (string, error) {
   173  	member, err := s.Member(guildID, userID)
   174  	if err != nil {
   175  		return "", err
   176  	}
   177  
   178  	if member.Nick == "" {
   179  		return member.User.Username, nil
   180  	}
   181  
   182  	return member.Nick, nil
   183  }
   184  
   185  func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, error) {
   186  	if !message.GuildID.IsValid() { // this is a dm
   187  		return discord.DefaultMemberColor, nil
   188  	}
   189  
   190  	if message.Member != nil {
   191  		guild, err := s.Guild(message.GuildID)
   192  		if err != nil {
   193  			return 0, err
   194  		}
   195  		return discord.MemberColor(*guild, *message.Member), nil
   196  	}
   197  
   198  	return s.MemberColor(message.GuildID, message.Author.ID)
   199  }
   200  
   201  func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) {
   202  	var wg sync.WaitGroup
   203  
   204  	g, gerr := s.Store.Guild(guildID)
   205  	m, merr := s.Store.Member(guildID, userID)
   206  
   207  	switch {
   208  	case gerr != nil && merr != nil:
   209  		wg.Add(1)
   210  		go func() {
   211  			g, gerr = s.fetchGuild(guildID)
   212  			wg.Done()
   213  		}()
   214  
   215  		m, merr = s.fetchMember(guildID, userID)
   216  	case gerr != nil:
   217  		g, gerr = s.fetchGuild(guildID)
   218  	case merr != nil:
   219  		m, merr = s.fetchMember(guildID, userID)
   220  	}
   221  
   222  	wg.Wait()
   223  
   224  	if gerr != nil {
   225  		return 0, errors.Wrap(merr, "failed to get guild")
   226  	}
   227  	if merr != nil {
   228  		return 0, errors.Wrap(merr, "failed to get member")
   229  	}
   230  
   231  	return discord.MemberColor(*g, *m), nil
   232  }
   233  
   234  ////
   235  
   236  func (s *State) Permissions(
   237  	channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) {
   238  
   239  	ch, err := s.Channel(channelID)
   240  	if err != nil {
   241  		return 0, errors.Wrap(err, "failed to get channel")
   242  	}
   243  
   244  	var wg sync.WaitGroup
   245  
   246  	g, gerr := s.Store.Guild(ch.GuildID)
   247  	m, merr := s.Store.Member(ch.GuildID, userID)
   248  
   249  	switch {
   250  	case gerr != nil && merr != nil:
   251  		wg.Add(1)
   252  		go func() {
   253  			g, gerr = s.fetchGuild(ch.GuildID)
   254  			wg.Done()
   255  		}()
   256  
   257  		m, merr = s.fetchMember(ch.GuildID, userID)
   258  	case gerr != nil:
   259  		g, gerr = s.fetchGuild(ch.GuildID)
   260  	case merr != nil:
   261  		m, merr = s.fetchMember(ch.GuildID, userID)
   262  	}
   263  
   264  	wg.Wait()
   265  
   266  	if gerr != nil {
   267  		return 0, errors.Wrap(merr, "failed to get guild")
   268  	}
   269  	if merr != nil {
   270  		return 0, errors.Wrap(merr, "failed to get member")
   271  	}
   272  
   273  	return discord.CalcOverwrites(*g, *ch, *m), nil
   274  }
   275  
   276  ////
   277  
   278  func (s *State) Me() (*discord.User, error) {
   279  	u, err := s.Store.Me()
   280  	if err == nil {
   281  		return u, nil
   282  	}
   283  
   284  	u, err = s.Session.Me()
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	return u, s.Store.MyselfSet(*u)
   290  }
   291  
   292  ////
   293  
   294  func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) {
   295  	c, err := s.Store.Channel(id)
   296  	if err == nil {
   297  		return c, nil
   298  	}
   299  
   300  	c, err = s.Session.Channel(id)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	return c, s.Store.ChannelSet(*c)
   306  }
   307  
   308  func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
   309  	c, err := s.Store.Channels(guildID)
   310  	if err == nil {
   311  		return c, nil
   312  	}
   313  
   314  	c, err = s.Session.Channels(guildID)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	for _, ch := range c {
   320  		ch := ch
   321  
   322  		if err := s.Store.ChannelSet(ch); err != nil {
   323  			return nil, err
   324  		}
   325  	}
   326  
   327  	return c, nil
   328  }
   329  
   330  func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
   331  	c, err := s.Store.CreatePrivateChannel(recipient)
   332  	if err == nil {
   333  		return c, nil
   334  	}
   335  
   336  	c, err = s.Session.CreatePrivateChannel(recipient)
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  
   341  	return c, s.Store.ChannelSet(*c)
   342  }
   343  
   344  func (s *State) PrivateChannels() ([]discord.Channel, error) {
   345  	c, err := s.Store.PrivateChannels()
   346  	if err == nil {
   347  		return c, nil
   348  	}
   349  
   350  	c, err = s.Session.PrivateChannels()
   351  	if err != nil {
   352  		return nil, err
   353  	}
   354  
   355  	for _, ch := range c {
   356  		ch := ch
   357  
   358  		if err := s.Store.ChannelSet(ch); err != nil {
   359  			return nil, err
   360  		}
   361  	}
   362  
   363  	return c, nil
   364  }
   365  
   366  ////
   367  
   368  func (s *State) Emoji(
   369  	guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
   370  
   371  	e, err := s.Store.Emoji(guildID, emojiID)
   372  	if err == nil {
   373  		return e, nil
   374  	}
   375  
   376  	es, err := s.Session.Emojis(guildID)
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	if err := s.Store.EmojiSet(guildID, es); err != nil {
   382  		return nil, err
   383  	}
   384  
   385  	for _, e := range es {
   386  		if e.ID == emojiID {
   387  			return &e, nil
   388  		}
   389  	}
   390  
   391  	return nil, ErrStoreNotFound
   392  }
   393  
   394  func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
   395  	e, err := s.Store.Emojis(guildID)
   396  	if err == nil {
   397  		return e, nil
   398  	}
   399  
   400  	es, err := s.Session.Emojis(guildID)
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	return es, s.Store.EmojiSet(guildID, es)
   406  }
   407  
   408  ////
   409  
   410  func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
   411  	c, err := s.Store.Guild(id)
   412  	if err == nil {
   413  		return c, nil
   414  	}
   415  
   416  	return s.fetchGuild(id)
   417  }
   418  
   419  // Guilds will only fill a maximum of 100 guilds from the API.
   420  func (s *State) Guilds() ([]discord.Guild, error) {
   421  	c, err := s.Store.Guilds()
   422  	if err == nil {
   423  		return c, nil
   424  	}
   425  
   426  	c, err = s.Session.Guilds(MaxFetchGuilds)
   427  	if err != nil {
   428  		return nil, err
   429  	}
   430  
   431  	for _, ch := range c {
   432  		ch := ch
   433  
   434  		if err := s.Store.GuildSet(ch); err != nil {
   435  			return nil, err
   436  		}
   437  	}
   438  
   439  	return c, nil
   440  }
   441  
   442  ////
   443  
   444  func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
   445  	m, err := s.Store.Member(guildID, userID)
   446  	if err == nil {
   447  		return m, nil
   448  	}
   449  
   450  	return s.fetchMember(guildID, userID)
   451  }
   452  
   453  func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
   454  	ms, err := s.Store.Members(guildID)
   455  	if err == nil {
   456  		return ms, nil
   457  	}
   458  
   459  	ms, err = s.Session.Members(guildID, MaxFetchMembers)
   460  	if err != nil {
   461  		return nil, err
   462  	}
   463  
   464  	for _, m := range ms {
   465  		if err := s.Store.MemberSet(guildID, m); err != nil {
   466  			return nil, err
   467  		}
   468  	}
   469  
   470  	return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
   471  		GuildID:   []discord.GuildID{guildID},
   472  		Presences: true,
   473  	})
   474  }
   475  
   476  ////
   477  
   478  func (s *State) Message(
   479  	channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
   480  
   481  	m, err := s.Store.Message(channelID, messageID)
   482  	if err == nil {
   483  		return m, nil
   484  	}
   485  
   486  	var wg sync.WaitGroup
   487  
   488  	c, cerr := s.Store.Channel(channelID)
   489  	if cerr != nil {
   490  		wg.Add(1)
   491  		go func() {
   492  			c, cerr = s.Session.Channel(channelID)
   493  			if cerr == nil {
   494  				cerr = s.Store.ChannelSet(*c)
   495  			}
   496  
   497  			wg.Done()
   498  		}()
   499  	}
   500  
   501  	m, err = s.Session.Message(channelID, messageID)
   502  	if err != nil {
   503  		return nil, errors.Wrap(err, "unable to fetch message")
   504  	}
   505  
   506  	wg.Wait()
   507  
   508  	if cerr != nil {
   509  		return nil, errors.Wrap(cerr, "unable to fetch channel")
   510  	}
   511  
   512  	m.ChannelID = c.ID
   513  	m.GuildID = c.GuildID
   514  
   515  	return m, s.Store.MessageSet(*m)
   516  }
   517  
   518  // Messages fetches maximum 100 messages from the API, if it has to. There is no
   519  // limit if it's from the State storage.
   520  func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
   521  	// TODO: Think of a design that doesn't rely on MaxMessages().
   522  	var maxMsgs = s.MaxMessages()
   523  
   524  	ms, err := s.Store.Messages(channelID)
   525  	if err == nil {
   526  		// If the state already has as many messages as it can, skip the API.
   527  		if maxMsgs <= len(ms) {
   528  			return ms, nil
   529  		}
   530  
   531  		// Is the channel tiny?
   532  		s.fewMutex.Lock()
   533  		if _, ok := s.fewMessages[channelID]; ok {
   534  			s.fewMutex.Unlock()
   535  			return ms, nil
   536  		}
   537  
   538  		// No, fetch from the state.
   539  		s.fewMutex.Unlock()
   540  	}
   541  
   542  	ms, err = s.Session.Messages(channelID, uint(maxMsgs))
   543  	if err != nil {
   544  		return nil, err
   545  	}
   546  
   547  	// New messages fetched weirdly does not have GuildID filled. We'll try and
   548  	// get it for consistency with incoming message creates.
   549  	var guildID discord.GuildID
   550  
   551  	// A bit too convoluted, but whatever.
   552  	c, err := s.Channel(channelID)
   553  	if err == nil {
   554  		// If it's 0, it's 0 anyway. We don't need a check here.
   555  		guildID = c.GuildID
   556  	}
   557  
   558  	// Iterate in reverse, since the store is expected to prepend the latest
   559  	// messages.
   560  	for i := len(ms) - 1; i >= 0; i-- {
   561  		// Set the guild ID, fine if it's 0 (it's already 0 anyway).
   562  		ms[i].GuildID = guildID
   563  
   564  		if err := s.Store.MessageSet(ms[i]); err != nil {
   565  			return nil, err
   566  		}
   567  	}
   568  
   569  	if len(ms) < maxMsgs {
   570  		// Tiny channel, store this.
   571  		s.fewMutex.Lock()
   572  		s.fewMessages[channelID] = struct{}{}
   573  		s.fewMutex.Unlock()
   574  
   575  		return ms, nil
   576  	}
   577  
   578  	// Since the latest messages are at the end and we already know the maxMsgs,
   579  	// we could slice this right away.
   580  	return ms[:maxMsgs], nil
   581  }
   582  
   583  ////
   584  
   585  // Presence checks the state for user presences. If no guildID is given, it will
   586  // look for the presence in all guilds.
   587  func (s *State) Presence(
   588  	guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
   589  
   590  	p, err := s.Store.Presence(guildID, userID)
   591  	if err == nil {
   592  		return p, nil
   593  	}
   594  
   595  	// If there's no guild ID, look in all guilds
   596  	if !guildID.IsValid() {
   597  		g, err := s.Guilds()
   598  		if err != nil {
   599  			return nil, err
   600  		}
   601  
   602  		for _, g := range g {
   603  			if p, err := s.Store.Presence(g.ID, userID); err == nil {
   604  				return p, nil
   605  			}
   606  		}
   607  	}
   608  
   609  	return nil, err
   610  }
   611  
   612  ////
   613  
   614  func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
   615  	r, err := s.Store.Role(guildID, roleID)
   616  	if err == nil {
   617  		return r, nil
   618  	}
   619  
   620  	rs, err := s.Session.Roles(guildID)
   621  	if err != nil {
   622  		return nil, err
   623  	}
   624  
   625  	var role *discord.Role
   626  
   627  	for _, r := range rs {
   628  		r := r
   629  
   630  		if r.ID == roleID {
   631  			role = &r
   632  		}
   633  
   634  		if err := s.RoleSet(guildID, r); err != nil {
   635  			return role, err
   636  		}
   637  	}
   638  
   639  	if role == nil {
   640  		return nil, ErrStoreNotFound
   641  	}
   642  
   643  	return role, nil
   644  }
   645  
   646  func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
   647  	rs, err := s.Store.Roles(guildID)
   648  	if err == nil {
   649  		return rs, nil
   650  	}
   651  
   652  	rs, err = s.Session.Roles(guildID)
   653  	if err != nil {
   654  		return nil, err
   655  	}
   656  
   657  	for _, r := range rs {
   658  		r := r
   659  
   660  		if err := s.RoleSet(guildID, r); err != nil {
   661  			return rs, err
   662  		}
   663  	}
   664  
   665  	return rs, nil
   666  }
   667  
   668  func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
   669  	g, err = s.Session.Guild(id)
   670  	if err == nil {
   671  		err = s.Store.GuildSet(*g)
   672  	}
   673  
   674  	return
   675  }
   676  
   677  func (s *State) fetchMember(
   678  	guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
   679  
   680  	m, err = s.Session.Member(guildID, userID)
   681  	if err == nil {
   682  		err = s.Store.MemberSet(guildID, *m)
   683  	}
   684  
   685  	return
   686  }