// Package state provides interfaces for a local or remote state, as well as
// abstractions around the REST API and Gateway events.
package state

import (
	"context"
	"sync"

	"github.com/diamondburned/arikawa/discord"
	"github.com/diamondburned/arikawa/gateway"
	"github.com/diamondburned/arikawa/internal/moreatomic"
	"github.com/diamondburned/arikawa/session"
	"github.com/diamondburned/arikawa/utils/handler"

	"github.com/pkg/errors"
)

var (
	// MaxFetchMembers is the limit passed to Session.Members when
	// State.Members has to fetch a guild's members because the store could
	// not provide them.
	MaxFetchMembers uint = 1000
	// MaxFetchGuilds is the limit passed to Session.Guilds when State.Guilds
	// has to fetch guilds because the store could not provide them.
	MaxFetchGuilds uint = 10
)

// State is the cache to store events coming from Discord as well as data from
// API calls.
//
// Store
//
// The state provides abstractions on top of the API and the state storage
// (Store). The state storage is effectively a set of interfaces which allow
// arbitrary backends to be implemented.
//
// The default storage backend is a typical in-memory structure consisting of
// maps and slices. Custom backend implementations could embed this storage
// backend as an in-memory fallback. A good example of this would be embedding
// the default store for messages only, while handling everything else in Redis.
//
// The package also provides a no-op store (NoopStore) that implementations
// could embed. This no-op store will always return an error, which makes the
// state fetch information from the API. The setters are all no-ops, so the
// fetched data won't be cached.
//
// Handler
//
// The state uses its own handler instead of the session's, so that all
// handlers run after the state has updated itself. A PreHandler is exposed in
// case the user needs handlers to run before the state updates itself. Refer
// to that field's documentation.
//
// The state also provides extra events and overrides to make up for Discord's
// inconsistencies in data. Known instances are listed below.
//
// The Guild Create event is split up into the state's Guild Available, Guild
// Ready and Guild Join events. Refer to these events' documentation for more
// information.
//
// The Message Create and Message Update events with the Member field provided
// will have the User field copied from Author. This is because the User field
// will be empty, while the Member structure expects it to be there.
type State struct {
	*session.Session
	Store

	// Note: State doesn't actually keep track of pinned messages.

	// Ready is not updated by the state.
	Ready gateway.ReadyEvent

	// StateLog logs all errors that come from the state cache. This includes
	// not found errors. Defaults to a no-op, as state errors aren't that
	// important.
	StateLog func(error)

	// PreHandler is the manual hook that is executed before the State handler
	// is. This should only be used for low-level operations.
	// It's recommended to set Synchronous to true if you mutate the events.
	PreHandler *handler.Handler // default nil

	// Handler is the state's own handler, whose methods are promoted onto
	// State. It runs after PreHandler. Most of the time you should use this
	// instead of Session's, to avoid race conditions with the State.
	*handler.Handler

	// fewMessages is the set of channels for which an API fetch returned
	// fewer messages than MaxMessages, so Messages doesn't bother hitting the
	// API again for them. It is guarded by fewMutex.
	fewMessages map[discord.ChannelID]struct{}
	fewMutex    *sync.Mutex

	// unavailableGuilds is a set of discord.GuildIDs of guilds that became
	// unavailable when already connected to the gateway, i.e. sent in a
	// GuildUnavailableEvent.
	unavailableGuilds *moreatomic.GuildIDSet
	// unreadyGuilds is a set of discord.GuildIDs of guilds that were
	// unavailable when connecting to the gateway, i.e. they had Unavailable
	// set to true during Ready.
	unreadyGuilds *moreatomic.GuildIDSet
}

// New creates a new state backed by the default store (NewDefaultStore).
99 func New(token string) (*State, error) { 100 return NewWithStore(token, NewDefaultStore(nil)) 101 } 102 103 // NewWithIntents creates a new state with the given gateway intents. For more 104 // information, refer to gateway.Intents. 105 func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { 106 s, err := session.NewWithIntents(token, intents...) 107 if err != nil { 108 return nil, err 109 } 110 111 return NewFromSession(s, NewDefaultStore(nil)) 112 } 113 114 func NewWithStore(token string, store Store) (*State, error) { 115 s, err := session.New(token) 116 if err != nil { 117 return nil, err 118 } 119 120 return NewFromSession(s, store) 121 } 122 123 // NewFromSession never returns an error. This API is kept for backwards 124 // compatibility. 125 func NewFromSession(s *session.Session, store Store) (*State, error) { 126 state := &State{ 127 Session: s, 128 Store: store, 129 Handler: handler.New(), 130 StateLog: func(err error) {}, 131 fewMessages: map[discord.ChannelID]struct{}{}, 132 fewMutex: new(sync.Mutex), 133 unavailableGuilds: moreatomic.NewGuildIDSet(), 134 unreadyGuilds: moreatomic.NewGuildIDSet(), 135 } 136 state.hookSession() 137 return state, nil 138 } 139 140 // WithContext returns a shallow copy of State with the context replaced in the 141 // API client. All methods called on the State will use this given context. This 142 // method is thread-safe. 
143 func (s *State) WithContext(ctx context.Context) *State { 144 copied := *s 145 copied.Client = copied.Client.WithContext(ctx) 146 147 return &copied 148 } 149 150 //// Helper methods 151 152 func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string { 153 if !message.GuildID.IsValid() { 154 return message.Author.Username 155 } 156 157 if message.Member != nil { 158 if message.Member.Nick != "" { 159 return message.Member.Nick 160 } 161 return message.Author.Username 162 } 163 164 n, err := s.MemberDisplayName(message.GuildID, message.Author.ID) 165 if err != nil { 166 return message.Author.Username 167 } 168 169 return n 170 } 171 172 func (s *State) MemberDisplayName(guildID discord.GuildID, userID discord.UserID) (string, error) { 173 member, err := s.Member(guildID, userID) 174 if err != nil { 175 return "", err 176 } 177 178 if member.Nick == "" { 179 return member.User.Username, nil 180 } 181 182 return member.Nick, nil 183 } 184 185 func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, error) { 186 if !message.GuildID.IsValid() { // this is a dm 187 return discord.DefaultMemberColor, nil 188 } 189 190 if message.Member != nil { 191 guild, err := s.Guild(message.GuildID) 192 if err != nil { 193 return 0, err 194 } 195 return discord.MemberColor(*guild, *message.Member), nil 196 } 197 198 return s.MemberColor(message.GuildID, message.Author.ID) 199 } 200 201 func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) { 202 var wg sync.WaitGroup 203 204 g, gerr := s.Store.Guild(guildID) 205 m, merr := s.Store.Member(guildID, userID) 206 207 switch { 208 case gerr != nil && merr != nil: 209 wg.Add(1) 210 go func() { 211 g, gerr = s.fetchGuild(guildID) 212 wg.Done() 213 }() 214 215 m, merr = s.fetchMember(guildID, userID) 216 case gerr != nil: 217 g, gerr = s.fetchGuild(guildID) 218 case merr != nil: 219 m, merr = s.fetchMember(guildID, userID) 220 } 221 222 wg.Wait() 223 224 
if gerr != nil { 225 return 0, errors.Wrap(merr, "failed to get guild") 226 } 227 if merr != nil { 228 return 0, errors.Wrap(merr, "failed to get member") 229 } 230 231 return discord.MemberColor(*g, *m), nil 232 } 233 234 //// 235 236 func (s *State) Permissions( 237 channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) { 238 239 ch, err := s.Channel(channelID) 240 if err != nil { 241 return 0, errors.Wrap(err, "failed to get channel") 242 } 243 244 var wg sync.WaitGroup 245 246 g, gerr := s.Store.Guild(ch.GuildID) 247 m, merr := s.Store.Member(ch.GuildID, userID) 248 249 switch { 250 case gerr != nil && merr != nil: 251 wg.Add(1) 252 go func() { 253 g, gerr = s.fetchGuild(ch.GuildID) 254 wg.Done() 255 }() 256 257 m, merr = s.fetchMember(ch.GuildID, userID) 258 case gerr != nil: 259 g, gerr = s.fetchGuild(ch.GuildID) 260 case merr != nil: 261 m, merr = s.fetchMember(ch.GuildID, userID) 262 } 263 264 wg.Wait() 265 266 if gerr != nil { 267 return 0, errors.Wrap(merr, "failed to get guild") 268 } 269 if merr != nil { 270 return 0, errors.Wrap(merr, "failed to get member") 271 } 272 273 return discord.CalcOverwrites(*g, *ch, *m), nil 274 } 275 276 //// 277 278 func (s *State) Me() (*discord.User, error) { 279 u, err := s.Store.Me() 280 if err == nil { 281 return u, nil 282 } 283 284 u, err = s.Session.Me() 285 if err != nil { 286 return nil, err 287 } 288 289 return u, s.Store.MyselfSet(*u) 290 } 291 292 //// 293 294 func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) { 295 c, err := s.Store.Channel(id) 296 if err == nil { 297 return c, nil 298 } 299 300 c, err = s.Session.Channel(id) 301 if err != nil { 302 return nil, err 303 } 304 305 return c, s.Store.ChannelSet(*c) 306 } 307 308 func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) { 309 c, err := s.Store.Channels(guildID) 310 if err == nil { 311 return c, nil 312 } 313 314 c, err = s.Session.Channels(guildID) 315 if err != nil { 316 return 
nil, err 317 } 318 319 for _, ch := range c { 320 ch := ch 321 322 if err := s.Store.ChannelSet(ch); err != nil { 323 return nil, err 324 } 325 } 326 327 return c, nil 328 } 329 330 func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { 331 c, err := s.Store.CreatePrivateChannel(recipient) 332 if err == nil { 333 return c, nil 334 } 335 336 c, err = s.Session.CreatePrivateChannel(recipient) 337 if err != nil { 338 return nil, err 339 } 340 341 return c, s.Store.ChannelSet(*c) 342 } 343 344 func (s *State) PrivateChannels() ([]discord.Channel, error) { 345 c, err := s.Store.PrivateChannels() 346 if err == nil { 347 return c, nil 348 } 349 350 c, err = s.Session.PrivateChannels() 351 if err != nil { 352 return nil, err 353 } 354 355 for _, ch := range c { 356 ch := ch 357 358 if err := s.Store.ChannelSet(ch); err != nil { 359 return nil, err 360 } 361 } 362 363 return c, nil 364 } 365 366 //// 367 368 func (s *State) Emoji( 369 guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) { 370 371 e, err := s.Store.Emoji(guildID, emojiID) 372 if err == nil { 373 return e, nil 374 } 375 376 es, err := s.Session.Emojis(guildID) 377 if err != nil { 378 return nil, err 379 } 380 381 if err := s.Store.EmojiSet(guildID, es); err != nil { 382 return nil, err 383 } 384 385 for _, e := range es { 386 if e.ID == emojiID { 387 return &e, nil 388 } 389 } 390 391 return nil, ErrStoreNotFound 392 } 393 394 func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) { 395 e, err := s.Store.Emojis(guildID) 396 if err == nil { 397 return e, nil 398 } 399 400 es, err := s.Session.Emojis(guildID) 401 if err != nil { 402 return nil, err 403 } 404 405 return es, s.Store.EmojiSet(guildID, es) 406 } 407 408 //// 409 410 func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { 411 c, err := s.Store.Guild(id) 412 if err == nil { 413 return c, nil 414 } 415 416 return s.fetchGuild(id) 417 } 418 419 // Guilds will only 
fill a maximum of 100 guilds from the API. 420 func (s *State) Guilds() ([]discord.Guild, error) { 421 c, err := s.Store.Guilds() 422 if err == nil { 423 return c, nil 424 } 425 426 c, err = s.Session.Guilds(MaxFetchGuilds) 427 if err != nil { 428 return nil, err 429 } 430 431 for _, ch := range c { 432 ch := ch 433 434 if err := s.Store.GuildSet(ch); err != nil { 435 return nil, err 436 } 437 } 438 439 return c, nil 440 } 441 442 //// 443 444 func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { 445 m, err := s.Store.Member(guildID, userID) 446 if err == nil { 447 return m, nil 448 } 449 450 return s.fetchMember(guildID, userID) 451 } 452 453 func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) { 454 ms, err := s.Store.Members(guildID) 455 if err == nil { 456 return ms, nil 457 } 458 459 ms, err = s.Session.Members(guildID, MaxFetchMembers) 460 if err != nil { 461 return nil, err 462 } 463 464 for _, m := range ms { 465 if err := s.Store.MemberSet(guildID, m); err != nil { 466 return nil, err 467 } 468 } 469 470 return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{ 471 GuildID: []discord.GuildID{guildID}, 472 Presences: true, 473 }) 474 } 475 476 //// 477 478 func (s *State) Message( 479 channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { 480 481 m, err := s.Store.Message(channelID, messageID) 482 if err == nil { 483 return m, nil 484 } 485 486 var wg sync.WaitGroup 487 488 c, cerr := s.Store.Channel(channelID) 489 if cerr != nil { 490 wg.Add(1) 491 go func() { 492 c, cerr = s.Session.Channel(channelID) 493 if cerr == nil { 494 cerr = s.Store.ChannelSet(*c) 495 } 496 497 wg.Done() 498 }() 499 } 500 501 m, err = s.Session.Message(channelID, messageID) 502 if err != nil { 503 return nil, errors.Wrap(err, "unable to fetch message") 504 } 505 506 wg.Wait() 507 508 if cerr != nil { 509 return nil, errors.Wrap(cerr, "unable to fetch channel") 510 } 511 
512 m.ChannelID = c.ID 513 m.GuildID = c.GuildID 514 515 return m, s.Store.MessageSet(*m) 516 } 517 518 // Messages fetches maximum 100 messages from the API, if it has to. There is no 519 // limit if it's from the State storage. 520 func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) { 521 // TODO: Think of a design that doesn't rely on MaxMessages(). 522 var maxMsgs = s.MaxMessages() 523 524 ms, err := s.Store.Messages(channelID) 525 if err == nil { 526 // If the state already has as many messages as it can, skip the API. 527 if maxMsgs <= len(ms) { 528 return ms, nil 529 } 530 531 // Is the channel tiny? 532 s.fewMutex.Lock() 533 if _, ok := s.fewMessages[channelID]; ok { 534 s.fewMutex.Unlock() 535 return ms, nil 536 } 537 538 // No, fetch from the state. 539 s.fewMutex.Unlock() 540 } 541 542 ms, err = s.Session.Messages(channelID, uint(maxMsgs)) 543 if err != nil { 544 return nil, err 545 } 546 547 // New messages fetched weirdly does not have GuildID filled. We'll try and 548 // get it for consistency with incoming message creates. 549 var guildID discord.GuildID 550 551 // A bit too convoluted, but whatever. 552 c, err := s.Channel(channelID) 553 if err == nil { 554 // If it's 0, it's 0 anyway. We don't need a check here. 555 guildID = c.GuildID 556 } 557 558 // Iterate in reverse, since the store is expected to prepend the latest 559 // messages. 560 for i := len(ms) - 1; i >= 0; i-- { 561 // Set the guild ID, fine if it's 0 (it's already 0 anyway). 562 ms[i].GuildID = guildID 563 564 if err := s.Store.MessageSet(ms[i]); err != nil { 565 return nil, err 566 } 567 } 568 569 if len(ms) < maxMsgs { 570 // Tiny channel, store this. 571 s.fewMutex.Lock() 572 s.fewMessages[channelID] = struct{}{} 573 s.fewMutex.Unlock() 574 575 return ms, nil 576 } 577 578 // Since the latest messages are at the end and we already know the maxMsgs, 579 // we could slice this right away. 
580 return ms[:maxMsgs], nil 581 } 582 583 //// 584 585 // Presence checks the state for user presences. If no guildID is given, it will 586 // look for the presence in all guilds. 587 func (s *State) Presence( 588 guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { 589 590 p, err := s.Store.Presence(guildID, userID) 591 if err == nil { 592 return p, nil 593 } 594 595 // If there's no guild ID, look in all guilds 596 if !guildID.IsValid() { 597 g, err := s.Guilds() 598 if err != nil { 599 return nil, err 600 } 601 602 for _, g := range g { 603 if p, err := s.Store.Presence(g.ID, userID); err == nil { 604 return p, nil 605 } 606 } 607 } 608 609 return nil, err 610 } 611 612 //// 613 614 func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) { 615 r, err := s.Store.Role(guildID, roleID) 616 if err == nil { 617 return r, nil 618 } 619 620 rs, err := s.Session.Roles(guildID) 621 if err != nil { 622 return nil, err 623 } 624 625 var role *discord.Role 626 627 for _, r := range rs { 628 r := r 629 630 if r.ID == roleID { 631 role = &r 632 } 633 634 if err := s.RoleSet(guildID, r); err != nil { 635 return role, err 636 } 637 } 638 639 if role == nil { 640 return nil, ErrStoreNotFound 641 } 642 643 return role, nil 644 } 645 646 func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { 647 rs, err := s.Store.Roles(guildID) 648 if err == nil { 649 return rs, nil 650 } 651 652 rs, err = s.Session.Roles(guildID) 653 if err != nil { 654 return nil, err 655 } 656 657 for _, r := range rs { 658 r := r 659 660 if err := s.RoleSet(guildID, r); err != nil { 661 return rs, err 662 } 663 } 664 665 return rs, nil 666 } 667 668 func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { 669 g, err = s.Session.Guild(id) 670 if err == nil { 671 err = s.Store.GuildSet(*g) 672 } 673 674 return 675 } 676 677 func (s *State) fetchMember( 678 guildID discord.GuildID, userID discord.UserID) (m 
*discord.Member, err error) { 679 680 m, err = s.Session.Member(guildID, userID) 681 if err == nil { 682 err = s.Store.MemberSet(guildID, *m) 683 } 684 685 return 686 }