github.com/diamondburned/arikawa@v1.3.14/state/store_default.go (about)

     1  package state
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/diamondburned/arikawa/discord"
     7  )
     8  
// TODO: make an ExpiryStore

// DefaultStore is an in-memory Store built from plain Go maps and slices.
// Construct it with NewDefaultStore, which allocates the maps through Reset;
// the zero value has nil maps, so any write to it would panic.
type DefaultStore struct {
	// Embedded options. MaxMessages() reads them without taking mut.
	DefaultStoreOptions

	// self is the current user, written by MyselfSet and read by Me.
	self discord.User

	// privates holds every channel that ChannelSet received without a valid
	// guild ID (DMs and, presumably, group DMs), keyed by channel ID.
	privates map[discord.ChannelID]discord.Channel
	guilds   map[discord.GuildID]discord.Guild

	// Per-guild lists; single-item lookups scan them linearly.
	// messages is per channel, ordered newest first, and capped at
	// MaxMessages entries by MessageSet.
	roles       map[discord.GuildID][]discord.Role
	emojis      map[discord.GuildID][]discord.Emoji
	channels    map[discord.GuildID][]discord.Channel
	presences   map[discord.GuildID][]discord.Presence
	voiceStates map[discord.GuildID][]discord.VoiceState
	messages    map[discord.ChannelID][]discord.Message

	// Special case, optimized for guilds with many members: keyed by guild
	// ID and then user ID so lookups don't scan a slice.
	members map[discord.GuildID]map[discord.UserID]discord.Member

	// mut guards all cached data above (the embedded options excepted).
	mut sync.RWMutex
}
    32  
// DefaultStoreOptions configures a DefaultStore.
type DefaultStoreOptions struct {
	// MaxMessages caps how many messages MessageSet keeps per channel.
	// NewDefaultStore substitutes 50 only when it is passed nil options;
	// NOTE(review): an explicitly zero value is kept as-is, not defaulted.
	MaxMessages uint // default 50
}

// Compile-time assertion that *DefaultStore implements Store.
var _ Store = (*DefaultStore)(nil)
    38  
    39  func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore {
    40  	if opts == nil {
    41  		opts = &DefaultStoreOptions{
    42  			MaxMessages: 50,
    43  		}
    44  	}
    45  
    46  	ds := &DefaultStore{DefaultStoreOptions: *opts}
    47  	ds.Reset()
    48  
    49  	return ds
    50  }
    51  
    52  func (s *DefaultStore) Reset() error {
    53  	s.mut.Lock()
    54  	defer s.mut.Unlock()
    55  
    56  	s.self = discord.User{}
    57  
    58  	s.privates = map[discord.ChannelID]discord.Channel{}
    59  	s.guilds = map[discord.GuildID]discord.Guild{}
    60  
    61  	s.roles = map[discord.GuildID][]discord.Role{}
    62  	s.emojis = map[discord.GuildID][]discord.Emoji{}
    63  	s.channels = map[discord.GuildID][]discord.Channel{}
    64  	s.presences = map[discord.GuildID][]discord.Presence{}
    65  	s.voiceStates = map[discord.GuildID][]discord.VoiceState{}
    66  	s.messages = map[discord.ChannelID][]discord.Message{}
    67  
    68  	s.members = map[discord.GuildID]map[discord.UserID]discord.Member{}
    69  
    70  	return nil
    71  }
    72  
    73  ////
    74  
    75  func (s *DefaultStore) Me() (*discord.User, error) {
    76  	s.mut.RLock()
    77  	defer s.mut.RUnlock()
    78  
    79  	if !s.self.ID.IsValid() {
    80  		return nil, ErrStoreNotFound
    81  	}
    82  
    83  	return &s.self, nil
    84  }
    85  
    86  func (s *DefaultStore) MyselfSet(me discord.User) error {
    87  	s.mut.Lock()
    88  	s.self = me
    89  	s.mut.Unlock()
    90  
    91  	return nil
    92  }
    93  
    94  ////
    95  
    96  func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
    97  	s.mut.RLock()
    98  	defer s.mut.RUnlock()
    99  
   100  	if ch, ok := s.privates[id]; ok {
   101  		// implicit copy
   102  		return &ch, nil
   103  	}
   104  
   105  	for _, chs := range s.channels {
   106  		for _, ch := range chs {
   107  			if ch.ID == id {
   108  				return &ch, nil
   109  			}
   110  		}
   111  	}
   112  
   113  	return nil, ErrStoreNotFound
   114  }
   115  
   116  func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
   117  	s.mut.RLock()
   118  	defer s.mut.RUnlock()
   119  
   120  	chs, ok := s.channels[guildID]
   121  	if !ok {
   122  		return nil, ErrStoreNotFound
   123  	}
   124  
   125  	return append([]discord.Channel{}, chs...), nil
   126  }
   127  
   128  // CreatePrivateChannel searches in the cache for a private channel. It makes no
   129  // API calls.
   130  func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
   131  	s.mut.RLock()
   132  	defer s.mut.RUnlock()
   133  
   134  	// slow way
   135  	for _, ch := range s.privates {
   136  		if ch.Type != discord.DirectMessage || len(ch.DMRecipients) == 0 {
   137  			continue
   138  		}
   139  		if ch.DMRecipients[0].ID == recipient {
   140  			// Return an implicit copy made by range.
   141  			return &ch, nil
   142  		}
   143  	}
   144  	return nil, ErrStoreNotFound
   145  }
   146  
   147  // PrivateChannels returns a list of Direct Message channels randomly ordered.
   148  func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) {
   149  	s.mut.RLock()
   150  	defer s.mut.RUnlock()
   151  
   152  	var chs = make([]discord.Channel, 0, len(s.privates))
   153  	for i := range s.privates {
   154  		chs = append(chs, s.privates[i])
   155  	}
   156  
   157  	return chs, nil
   158  }
   159  
   160  func (s *DefaultStore) ChannelSet(channel discord.Channel) error {
   161  	s.mut.Lock()
   162  	defer s.mut.Unlock()
   163  
   164  	if !channel.GuildID.IsValid() {
   165  		s.privates[channel.ID] = channel
   166  
   167  	} else {
   168  		chs := s.channels[channel.GuildID]
   169  
   170  		for i, ch := range chs {
   171  			if ch.ID == channel.ID {
   172  				// Also from discordgo.
   173  				if channel.Permissions == nil {
   174  					channel.Permissions = ch.Permissions
   175  				}
   176  
   177  				// Found, just edit
   178  				chs[i] = channel
   179  
   180  				return nil
   181  			}
   182  		}
   183  
   184  		chs = append(chs, channel)
   185  		s.channels[channel.GuildID] = chs
   186  	}
   187  
   188  	return nil
   189  }
   190  
   191  func (s *DefaultStore) ChannelRemove(channel discord.Channel) error {
   192  	s.mut.Lock()
   193  	defer s.mut.Unlock()
   194  
   195  	chs, ok := s.channels[channel.GuildID]
   196  	if !ok {
   197  		return ErrStoreNotFound
   198  	}
   199  
   200  	for i, ch := range chs {
   201  		if ch.ID == channel.ID {
   202  			// Fast unordered delete.
   203  			chs[i] = chs[len(chs)-1]
   204  			chs = chs[:len(chs)-1]
   205  
   206  			s.channels[channel.GuildID] = chs
   207  			return nil
   208  		}
   209  	}
   210  
   211  	return ErrStoreNotFound
   212  }
   213  
   214  ////
   215  
   216  func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
   217  	s.mut.RLock()
   218  	defer s.mut.RUnlock()
   219  
   220  	emojis, ok := s.emojis[guildID]
   221  	if !ok {
   222  		return nil, ErrStoreNotFound
   223  	}
   224  
   225  	for _, emoji := range emojis {
   226  		if emoji.ID == emojiID {
   227  			// Emoji is an implicit copy, so we could do this safely.
   228  			return &emoji, nil
   229  		}
   230  	}
   231  
   232  	return nil, ErrStoreNotFound
   233  }
   234  
   235  func (s *DefaultStore) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
   236  	s.mut.RLock()
   237  	defer s.mut.RUnlock()
   238  
   239  	emojis, ok := s.emojis[guildID]
   240  	if !ok {
   241  		return nil, ErrStoreNotFound
   242  	}
   243  
   244  	return append([]discord.Emoji{}, emojis...), nil
   245  }
   246  
   247  func (s *DefaultStore) EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error {
   248  	s.mut.Lock()
   249  	defer s.mut.Unlock()
   250  
   251  	// A nil slice is acceptable, as we'll make a new slice later on and set it.
   252  	s.emojis[guildID] = emojis
   253  
   254  	return nil
   255  }
   256  
   257  ////
   258  
   259  func (s *DefaultStore) Guild(id discord.GuildID) (*discord.Guild, error) {
   260  	s.mut.RLock()
   261  	defer s.mut.RUnlock()
   262  
   263  	ch, ok := s.guilds[id]
   264  	if !ok {
   265  		return nil, ErrStoreNotFound
   266  	}
   267  
   268  	// implicit copy
   269  	return &ch, nil
   270  }
   271  
   272  func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
   273  	s.mut.RLock()
   274  	defer s.mut.RUnlock()
   275  
   276  	if len(s.guilds) == 0 {
   277  		return nil, ErrStoreNotFound
   278  	}
   279  
   280  	var gs = make([]discord.Guild, 0, len(s.guilds))
   281  	for _, g := range s.guilds {
   282  		gs = append(gs, g)
   283  	}
   284  
   285  	return gs, nil
   286  }
   287  
   288  func (s *DefaultStore) GuildSet(guild discord.Guild) error {
   289  	s.mut.Lock()
   290  	defer s.mut.Unlock()
   291  
   292  	s.guilds[guild.ID] = guild
   293  	return nil
   294  }
   295  
   296  func (s *DefaultStore) GuildRemove(id discord.GuildID) error {
   297  	s.mut.Lock()
   298  	defer s.mut.Unlock()
   299  
   300  	if _, ok := s.guilds[id]; !ok {
   301  		return ErrStoreNotFound
   302  	}
   303  
   304  	delete(s.guilds, id)
   305  	return nil
   306  }
   307  
   308  ////
   309  
   310  func (s *DefaultStore) Member(
   311  	guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
   312  
   313  	s.mut.RLock()
   314  	defer s.mut.RUnlock()
   315  
   316  	ms, ok := s.members[guildID]
   317  	if !ok {
   318  		return nil, ErrStoreNotFound
   319  	}
   320  
   321  	m, ok := ms[userID]
   322  	if ok {
   323  		return &m, nil
   324  	}
   325  
   326  	return nil, ErrStoreNotFound
   327  }
   328  
   329  func (s *DefaultStore) Members(guildID discord.GuildID) ([]discord.Member, error) {
   330  	s.mut.RLock()
   331  	defer s.mut.RUnlock()
   332  
   333  	ms, ok := s.members[guildID]
   334  	if !ok {
   335  		return nil, ErrStoreNotFound
   336  	}
   337  
   338  	var members = make([]discord.Member, 0, len(ms))
   339  	for _, m := range ms {
   340  		members = append(members, m)
   341  	}
   342  
   343  	return members, nil
   344  }
   345  
   346  func (s *DefaultStore) MemberSet(guildID discord.GuildID, member discord.Member) error {
   347  	s.mut.Lock()
   348  	defer s.mut.Unlock()
   349  
   350  	ms, ok := s.members[guildID]
   351  	if !ok {
   352  		ms = make(map[discord.UserID]discord.Member, 1)
   353  	}
   354  
   355  	ms[member.User.ID] = member
   356  	s.members[guildID] = ms
   357  
   358  	return nil
   359  }
   360  
   361  func (s *DefaultStore) MemberRemove(guildID discord.GuildID, userID discord.UserID) error {
   362  	s.mut.Lock()
   363  	defer s.mut.Unlock()
   364  
   365  	ms, ok := s.members[guildID]
   366  	if !ok {
   367  		return ErrStoreNotFound
   368  	}
   369  
   370  	if _, ok := ms[userID]; !ok {
   371  		return ErrStoreNotFound
   372  	}
   373  
   374  	delete(ms, userID)
   375  	return nil
   376  }
   377  
   378  ////
   379  
   380  func (s *DefaultStore) Message(
   381  	channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
   382  
   383  	s.mut.RLock()
   384  	defer s.mut.RUnlock()
   385  
   386  	ms, ok := s.messages[channelID]
   387  	if !ok {
   388  		return nil, ErrStoreNotFound
   389  	}
   390  
   391  	for _, m := range ms {
   392  		if m.ID == messageID {
   393  			return &m, nil
   394  		}
   395  	}
   396  
   397  	return nil, ErrStoreNotFound
   398  }
   399  
   400  func (s *DefaultStore) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
   401  	s.mut.RLock()
   402  	defer s.mut.RUnlock()
   403  
   404  	ms, ok := s.messages[channelID]
   405  	if !ok {
   406  		return nil, ErrStoreNotFound
   407  	}
   408  
   409  	return append([]discord.Message{}, ms...), nil
   410  }
   411  
   412  func (s *DefaultStore) MaxMessages() int {
   413  	return int(s.DefaultStoreOptions.MaxMessages)
   414  }
   415  
   416  func (s *DefaultStore) MessageSet(message discord.Message) error {
   417  	s.mut.Lock()
   418  	defer s.mut.Unlock()
   419  
   420  	ms, ok := s.messages[message.ChannelID]
   421  	if !ok {
   422  		ms = make([]discord.Message, 0, s.MaxMessages()+1)
   423  	}
   424  
   425  	// Check if we already have the message.
   426  	for i, m := range ms {
   427  		if m.ID == message.ID {
   428  			DiffMessage(message, &m)
   429  			ms[i] = m
   430  			return nil
   431  		}
   432  	}
   433  
   434  	// Order: latest to earliest, similar to the API.
   435  
   436  	var end = len(ms)
   437  	if max := s.MaxMessages(); end >= max {
   438  		// If the end (length) is larger than the maximum amount, then cap it.
   439  		end = max
   440  	} else {
   441  		// Else, append an empty message to the end.
   442  		ms = append(ms, discord.Message{})
   443  		// Increment to update the length.
   444  		end++
   445  	}
   446  
   447  	// Copy hack to prepend. This copies the 0th-(end-1)th entries to
   448  	// 1st-endth.
   449  	copy(ms[1:end], ms[0:end-1])
   450  	// Then, set the 0th entry.
   451  	ms[0] = message
   452  
   453  	s.messages[message.ChannelID] = ms
   454  	return nil
   455  }
   456  
   457  func (s *DefaultStore) MessageRemove(
   458  	channelID discord.ChannelID, messageID discord.MessageID) error {
   459  
   460  	s.mut.Lock()
   461  	defer s.mut.Unlock()
   462  
   463  	ms, ok := s.messages[channelID]
   464  	if !ok {
   465  		return ErrStoreNotFound
   466  	}
   467  
   468  	for i, m := range ms {
   469  		if m.ID == messageID {
   470  			ms = append(ms[:i], ms[i+1:]...)
   471  			s.messages[channelID] = ms
   472  			return nil
   473  		}
   474  	}
   475  
   476  	return ErrStoreNotFound
   477  }
   478  
   479  ////
   480  
   481  func (s *DefaultStore) Presence(
   482  	guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
   483  
   484  	s.mut.RLock()
   485  	defer s.mut.RUnlock()
   486  
   487  	ps, ok := s.presences[guildID]
   488  	if !ok {
   489  		return nil, ErrStoreNotFound
   490  	}
   491  
   492  	for _, p := range ps {
   493  		if p.User.ID == userID {
   494  			return &p, nil
   495  		}
   496  	}
   497  
   498  	return nil, ErrStoreNotFound
   499  }
   500  
   501  func (s *DefaultStore) Presences(guildID discord.GuildID) ([]discord.Presence, error) {
   502  	s.mut.RLock()
   503  	defer s.mut.RUnlock()
   504  
   505  	ps, ok := s.presences[guildID]
   506  	if !ok {
   507  		return nil, ErrStoreNotFound
   508  	}
   509  
   510  	return append([]discord.Presence{}, ps...), nil
   511  }
   512  
   513  func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence discord.Presence) error {
   514  	s.mut.Lock()
   515  	defer s.mut.Unlock()
   516  
   517  	ps, _ := s.presences[guildID]
   518  
   519  	for i, p := range ps {
   520  		if p.User.ID == presence.User.ID {
   521  			// Change the backing array.
   522  			ps[i] = presence
   523  			return nil
   524  		}
   525  	}
   526  
   527  	ps = append(ps, presence)
   528  	s.presences[guildID] = ps
   529  	return nil
   530  }
   531  
   532  func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.UserID) error {
   533  	s.mut.Lock()
   534  	defer s.mut.Unlock()
   535  
   536  	ps, ok := s.presences[guildID]
   537  	if !ok {
   538  		return ErrStoreNotFound
   539  	}
   540  
   541  	for i, p := range ps {
   542  		if p.User.ID == userID {
   543  			ps[i] = ps[len(ps)-1]
   544  			ps = ps[:len(ps)-1]
   545  
   546  			s.presences[guildID] = ps
   547  			return nil
   548  		}
   549  	}
   550  
   551  	return ErrStoreNotFound
   552  }
   553  
   554  ////
   555  
   556  func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
   557  	s.mut.RLock()
   558  	defer s.mut.RUnlock()
   559  
   560  	rs, ok := s.roles[guildID]
   561  	if !ok {
   562  		return nil, ErrStoreNotFound
   563  	}
   564  
   565  	for _, r := range rs {
   566  		if r.ID == roleID {
   567  			return &r, nil
   568  		}
   569  	}
   570  
   571  	return nil, ErrStoreNotFound
   572  }
   573  
   574  func (s *DefaultStore) Roles(guildID discord.GuildID) ([]discord.Role, error) {
   575  	s.mut.RLock()
   576  	defer s.mut.RUnlock()
   577  
   578  	rs, ok := s.roles[guildID]
   579  	if !ok {
   580  		return nil, ErrStoreNotFound
   581  	}
   582  
   583  	return append([]discord.Role{}, rs...), nil
   584  }
   585  
   586  func (s *DefaultStore) RoleSet(guildID discord.GuildID, role discord.Role) error {
   587  	s.mut.Lock()
   588  	defer s.mut.Unlock()
   589  
   590  	// A nil slice is fine, since we can just append the role.
   591  	rs, _ := s.roles[guildID]
   592  
   593  	for i, r := range rs {
   594  		if r.ID == role.ID {
   595  			// This changes the backing array, so we don't need to reset the
   596  			// slice.
   597  			rs[i] = role
   598  			return nil
   599  		}
   600  	}
   601  
   602  	rs = append(rs, role)
   603  	s.roles[guildID] = rs
   604  	return nil
   605  }
   606  
   607  func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error {
   608  	s.mut.Lock()
   609  	defer s.mut.Unlock()
   610  
   611  	rs, ok := s.roles[guildID]
   612  	if !ok {
   613  		return ErrStoreNotFound
   614  	}
   615  
   616  	for i, r := range rs {
   617  		if r.ID == roleID {
   618  			// Fast delete.
   619  			rs[i] = rs[len(rs)-1]
   620  			rs = rs[:len(rs)-1]
   621  
   622  			s.roles[guildID] = rs
   623  			return nil
   624  		}
   625  	}
   626  
   627  	return ErrStoreNotFound
   628  }
   629  
   630  ////
   631  
   632  func (s *DefaultStore) VoiceState(
   633  	guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) {
   634  
   635  	s.mut.RLock()
   636  	defer s.mut.RUnlock()
   637  
   638  	states, ok := s.voiceStates[guildID]
   639  	if !ok {
   640  		return nil, ErrStoreNotFound
   641  	}
   642  
   643  	for _, vs := range states {
   644  		if vs.UserID == userID {
   645  			return &vs, nil
   646  		}
   647  	}
   648  
   649  	return nil, ErrStoreNotFound
   650  }
   651  
   652  func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) {
   653  	s.mut.RLock()
   654  	defer s.mut.RUnlock()
   655  
   656  	states, ok := s.voiceStates[guildID]
   657  	if !ok {
   658  		return nil, ErrStoreNotFound
   659  	}
   660  
   661  	return append([]discord.VoiceState{}, states...), nil
   662  }
   663  
   664  func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error {
   665  	s.mut.Lock()
   666  	defer s.mut.Unlock()
   667  
   668  	states, _ := s.voiceStates[guildID]
   669  
   670  	for i, vs := range states {
   671  		if vs.UserID == voiceState.UserID {
   672  			// change the backing array
   673  			states[i] = voiceState
   674  			return nil
   675  		}
   676  	}
   677  
   678  	states = append(states, voiceState)
   679  	s.voiceStates[guildID] = states
   680  	return nil
   681  }
   682  
   683  func (s *DefaultStore) VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error {
   684  	s.mut.Lock()
   685  	defer s.mut.Unlock()
   686  
   687  	states, ok := s.voiceStates[guildID]
   688  	if !ok {
   689  		return ErrStoreNotFound
   690  	}
   691  
   692  	for i, vs := range states {
   693  		if vs.UserID == userID {
   694  			states = append(states[:i], states[i+1:]...)
   695  			s.voiceStates[guildID] = states
   696  
   697  			return nil
   698  		}
   699  	}
   700  
   701  	return ErrStoreNotFound
   702  }