github.com/diamondburned/arikawa/v2@v2.1.0/state/store/defaultstore/message.go (about)

     1  package defaultstore
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/diamondburned/arikawa/v2/discord"
     7  	"github.com/diamondburned/arikawa/v2/internal/moreatomic"
     8  	"github.com/diamondburned/arikawa/v2/state/store"
     9  )
    10  
    11  type Message struct {
    12  	channels moreatomic.Map
    13  	maxMsgs  int
    14  }
    15  
    16  var _ store.MessageStore = (*Message)(nil)
    17  
    18  type messages struct {
    19  	mut      sync.Mutex
    20  	messages []discord.Message
    21  }
    22  
    23  func NewMessage(maxMsgs int) *Message {
    24  	return &Message{
    25  		channels: *moreatomic.NewMap(func() interface{} {
    26  			return &messages{
    27  				messages: []discord.Message{}, // never use a nil slice
    28  			}
    29  		}),
    30  		maxMsgs: maxMsgs,
    31  	}
    32  }
    33  
    34  func (s *Message) Reset() error {
    35  	return s.channels.Reset()
    36  }
    37  
    38  func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*discord.Message, error) {
    39  	iv, ok := s.channels.Load(chID)
    40  	if !ok {
    41  		return nil, store.ErrNotFound
    42  	}
    43  
    44  	msgs := iv.(*messages)
    45  
    46  	msgs.mut.Lock()
    47  	defer msgs.mut.Unlock()
    48  
    49  	for _, m := range msgs.messages {
    50  		if m.ID == mID {
    51  			return &m, nil
    52  		}
    53  	}
    54  
    55  	return nil, store.ErrNotFound
    56  }
    57  
    58  func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
    59  	iv, ok := s.channels.Load(channelID)
    60  	if !ok {
    61  		return nil, store.ErrNotFound
    62  	}
    63  
    64  	msgs := iv.(*messages)
    65  
    66  	msgs.mut.Lock()
    67  	defer msgs.mut.Unlock()
    68  
    69  	return append([]discord.Message(nil), msgs.messages...), nil
    70  }
    71  
    72  func (s *Message) MaxMessages() int {
    73  	return s.maxMsgs
    74  }
    75  
    76  func (s *Message) MessageSet(message discord.Message) error {
    77  	iv, _ := s.channels.LoadOrStore(message.ChannelID)
    78  
    79  	msgs := iv.(*messages)
    80  
    81  	msgs.mut.Lock()
    82  	defer msgs.mut.Unlock()
    83  
    84  	for i, m := range msgs.messages {
    85  		if m.ID == message.ID {
    86  			DiffMessage(message, &m)
    87  			msgs.messages[i] = m
    88  			return nil
    89  		}
    90  	}
    91  
    92  	// Order: latest to earliest, similar to the API.
    93  
    94  	// Check if we already have the message. Try to derive the order otherwise.
    95  	var insertAt int
    96  
    97  	// Since we make the order guarantee ourselves, we can trust that we're
    98  	// iterating from latest to earliest.
    99  	for insertAt < len(msgs.messages) {
   100  		// Check if the new message is older. If it is, then we should insert it
   101  		// right after this message (or before this message in the list; i-1).
   102  		if message.ID > msgs.messages[insertAt].ID {
   103  			break
   104  		}
   105  
   106  		insertAt++
   107  	}
   108  
   109  	end := len(msgs.messages)
   110  	max := s.MaxMessages()
   111  
   112  	if end == max {
   113  		// If insertAt is larger than the length, then the message is older than
   114  		// every other messages we have. We have to discard this message here,
   115  		// since the store is already full.
   116  		if insertAt == end {
   117  			return nil
   118  		}
   119  
   120  		// If the end (length) is approaching the maximum amount, then cap it.
   121  		end = max
   122  	} else {
   123  		// Else, append an empty message to the end.
   124  		msgs.messages = append(msgs.messages, discord.Message{})
   125  		// Increment to update the length.
   126  		end++
   127  	}
   128  
   129  	// Shift the slice right-wards if the current item is not the last.
   130  	if start := insertAt + 1; start < end {
   131  		copy(msgs.messages[insertAt+1:], msgs.messages[insertAt:end-1])
   132  	}
   133  
   134  	// Then, set the nth entry.
   135  	msgs.messages[insertAt] = message
   136  
   137  	return nil
   138  }
   139  
   140  // DiffMessage fills non-empty fields from src to dst.
   141  func DiffMessage(src discord.Message, dst *discord.Message) {
   142  	// Thanks, Discord.
   143  	if src.Content != "" {
   144  		dst.Content = src.Content
   145  	}
   146  	if src.EditedTimestamp.IsValid() {
   147  		dst.EditedTimestamp = src.EditedTimestamp
   148  	}
   149  	if src.Mentions != nil {
   150  		dst.Mentions = src.Mentions
   151  	}
   152  	if src.Embeds != nil {
   153  		dst.Embeds = src.Embeds
   154  	}
   155  	if src.Attachments != nil {
   156  		dst.Attachments = src.Attachments
   157  	}
   158  	if src.Timestamp.IsValid() {
   159  		dst.Timestamp = src.Timestamp
   160  	}
   161  	if src.Author.ID.IsValid() {
   162  		dst.Author = src.Author
   163  	}
   164  	if src.Reactions != nil {
   165  		dst.Reactions = src.Reactions
   166  	}
   167  }
   168  
   169  func (s *Message) MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error {
   170  	iv, ok := s.channels.Load(channelID)
   171  	if !ok {
   172  		return nil
   173  	}
   174  
   175  	msgs := iv.(*messages)
   176  
   177  	msgs.mut.Lock()
   178  	defer msgs.mut.Unlock()
   179  
   180  	for i, m := range msgs.messages {
   181  		if m.ID == messageID {
   182  			msgs.messages = append(msgs.messages[:i], msgs.messages[i+1:]...)
   183  			return nil
   184  		}
   185  	}
   186  
   187  	return nil
   188  }