github.com/diamondburned/arikawa/v2@v2.1.0/state/store/defaultstore/channel.go (about)

     1  package defaultstore
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  
     7  	"github.com/diamondburned/arikawa/v2/discord"
     8  	"github.com/diamondburned/arikawa/v2/state/store"
     9  )
    10  
    11  type Channel struct {
    12  	mut sync.RWMutex
    13  
    14  	// Channel references must be protected under the same mutex.
    15  
    16  	privates   map[discord.UserID]*discord.Channel
    17  	privateChs []*discord.Channel
    18  
    19  	channels map[discord.ChannelID]*discord.Channel
    20  	guildChs map[discord.GuildID][]*discord.Channel
    21  }
    22  
    23  var _ store.ChannelStore = (*Channel)(nil)
    24  
    25  func NewChannel() *Channel {
    26  	return &Channel{
    27  		privates: map[discord.UserID]*discord.Channel{},
    28  		channels: map[discord.ChannelID]*discord.Channel{},
    29  		guildChs: map[discord.GuildID][]*discord.Channel{},
    30  	}
    31  }
    32  
    33  func (s *Channel) Reset() error {
    34  	s.mut.Lock()
    35  	defer s.mut.Unlock()
    36  
    37  	s.privates = map[discord.UserID]*discord.Channel{}
    38  	s.channels = map[discord.ChannelID]*discord.Channel{}
    39  	s.guildChs = map[discord.GuildID][]*discord.Channel{}
    40  
    41  	return nil
    42  }
    43  
    44  func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) {
    45  	s.mut.RLock()
    46  	defer s.mut.RUnlock()
    47  
    48  	ch, ok := s.channels[id]
    49  	if !ok {
    50  		return nil, store.ErrNotFound
    51  	}
    52  
    53  	cpy := *ch
    54  	return &cpy, nil
    55  }
    56  
    57  func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
    58  	s.mut.RLock()
    59  	defer s.mut.RUnlock()
    60  
    61  	ch, ok := s.privates[recipient]
    62  	if !ok {
    63  		return nil, store.ErrNotFound
    64  	}
    65  
    66  	cpy := *ch
    67  	return &cpy, nil
    68  }
    69  
    70  // Channels returns a list of Guild channels randomly ordered.
    71  func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
    72  	s.mut.RLock()
    73  	defer s.mut.RUnlock()
    74  
    75  	chRefs, ok := s.guildChs[guildID]
    76  	if !ok {
    77  		return nil, store.ErrNotFound
    78  	}
    79  
    80  	// Reading chRefs is also covered by the global mutex.
    81  
    82  	var channels = make([]discord.Channel, len(chRefs))
    83  	for i, chRef := range chRefs {
    84  		channels[i] = *chRef
    85  	}
    86  
    87  	return channels, nil
    88  }
    89  
    90  // PrivateChannels returns a list of Direct Message channels randomly ordered.
    91  func (s *Channel) PrivateChannels() ([]discord.Channel, error) {
    92  	s.mut.RLock()
    93  	defer s.mut.RUnlock()
    94  
    95  	if len(s.privateChs) == 0 {
    96  		return nil, store.ErrNotFound
    97  	}
    98  
    99  	var channels = make([]discord.Channel, len(s.privateChs))
   100  	for i, ch := range s.privateChs {
   101  		channels[i] = *ch
   102  	}
   103  
   104  	return channels, nil
   105  }
   106  
   107  // ChannelSet sets the Direct Message or Guild channel into the state.
   108  func (s *Channel) ChannelSet(channel discord.Channel) error {
   109  	s.mut.Lock()
   110  	defer s.mut.Unlock()
   111  
   112  	// Update the reference if we can.
   113  	if ch, ok := s.channels[channel.ID]; ok {
   114  		*ch = channel
   115  		return nil
   116  	}
   117  
   118  	switch channel.Type {
   119  	case discord.DirectMessage:
   120  		// Safety bound check.
   121  		if len(channel.DMRecipients) != 1 {
   122  			return errors.New("DirectMessage channel does not have 1 recipient")
   123  		}
   124  		s.privates[channel.DMRecipients[0].ID] = &channel
   125  		fallthrough
   126  	case discord.GroupDM:
   127  		s.privateChs = append(s.privateChs, &channel)
   128  		s.channels[channel.ID] = &channel
   129  		return nil
   130  	}
   131  
   132  	// Ensure that if the channel is not a DM or group DM channel, then it must
   133  	// have a valid guild ID.
   134  	if !channel.GuildID.IsValid() {
   135  		return errors.New("invalid guildID for guild channel")
   136  	}
   137  
   138  	s.channels[channel.ID] = &channel
   139  
   140  	channels, _ := s.guildChs[channel.GuildID]
   141  	channels = append(channels, &channel)
   142  	s.guildChs[channel.GuildID] = channels
   143  
   144  	return nil
   145  }
   146  
   147  func (s *Channel) ChannelRemove(channel discord.Channel) error {
   148  	s.mut.Lock()
   149  	defer s.mut.Unlock()
   150  
   151  	// Wipe the channel off the channel ID index.
   152  	delete(s.channels, channel.ID)
   153  
   154  	// Wipe the channel off the DM recipient index, if available.
   155  	switch channel.Type {
   156  	case discord.DirectMessage:
   157  		// Safety bound check.
   158  		if len(channel.DMRecipients) != 1 {
   159  			return errors.New("DirectMessage channel does not have 1 recipient")
   160  		}
   161  		delete(s.privates, channel.DMRecipients[0].ID)
   162  		fallthrough
   163  	case discord.GroupDM:
   164  		for i, priv := range s.privateChs {
   165  			if priv.ID == channel.ID {
   166  				s.privateChs = removeChannel(s.privateChs, i)
   167  				break
   168  			}
   169  		}
   170  		return nil
   171  	}
   172  
   173  	// Wipe the channel off the guilds index, if available.
   174  	channels, ok := s.guildChs[channel.GuildID]
   175  	if !ok {
   176  		return nil
   177  	}
   178  
   179  	for i, ch := range channels {
   180  		if ch.ID == channel.ID {
   181  			s.guildChs[channel.GuildID] = removeChannel(channels, i)
   182  			break
   183  		}
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  // removeChannel removes the given channel with the index from the given
   190  // channels slice in an unordered fashion.
   191  func removeChannel(channels []*discord.Channel, i int) []*discord.Channel {
   192  	// Fast unordered delete. Not sure if there's a benefit in doing
   193  	// this over using a map, but I guess the memory usage is less and
   194  	// there's no copying.
   195  
   196  	// Move the last channel to the current channel, set the last
   197  	// channel there to a nil value to unreference its children, then
   198  	// slice the last channel off.
   199  	channels[i] = channels[len(channels)-1]
   200  	channels[len(channels)-1] = nil
   201  	channels = channels[:len(channels)-1]
   202  
   203  	return channels
   204  }