github.com/diamondburned/arikawa/v2@v2.1.0/state/state.go (about) 1 // Package state provides interfaces for a local or remote state, as well as 2 // abstractions around the REST API and Gateway events. 3 package state 4 5 import ( 6 "context" 7 "sync" 8 9 "github.com/diamondburned/arikawa/v2/discord" 10 "github.com/diamondburned/arikawa/v2/gateway" 11 "github.com/diamondburned/arikawa/v2/session" 12 "github.com/diamondburned/arikawa/v2/state/store" 13 "github.com/diamondburned/arikawa/v2/state/store/defaultstore" 14 "github.com/diamondburned/arikawa/v2/utils/handler" 15 16 "github.com/pkg/errors" 17 ) 18 19 var ( 20 MaxFetchMembers uint = 1000 21 MaxFetchGuilds uint = 100 22 ) 23 24 // State is the cache to store events coming from Discord as well as data from 25 // API calls. 26 // 27 // Store 28 // 29 // The state basically provides abstractions on top of the API and the state 30 // storage (Store). The state storage is effectively a set of interfaces which 31 // allow arbitrary backends to be implemented. 32 // 33 // The default storage backend is a typical in-memory structure consisting of 34 // maps and slices. Custom backend implementations could embed this storage 35 // backend as an in-memory fallback. A good example of this would be embedding 36 // the default store for messages only, while handling everything else in Redis. 37 // 38 // The package also provides a no-op store (NoopStore) that implementations 39 // could embed. This no-op store will always return an error, which makes the 40 // state fetch information from the API. The setters are all no-ops, so the 41 // fetched data won't be updated. 42 // 43 // Handler 44 // 45 // The state uses its own handler over session's to make all handlers run after 46 // the state updates itself. A PreHandler is exposed in any case the user needs 47 // the handlers to run before the state updates itself. Refer to that field's 48 // documentation. 49 // 50 // The state also provides extra events and overrides to make up for Discord's 51 // inconsistencies in data. The following are known instances of such. 52 // 53 // The Guild Create event is split up to make the state's Guild Available, Guild 54 // Ready and Guild Join events. Refer to these events' documentations for more 55 // information. 56 // 57 // The Message Create and Message Update events with the Member field provided 58 // will have the User field copied from Author. This is because the User field 59 // will be empty, while the Member structure expects it to be there. 60 type State struct { 61 *session.Session 62 store.Cabinet 63 64 // *: State doesn't actually keep track of pinned messages. 65 66 readyMu *sync.Mutex 67 ready gateway.ReadyEvent 68 69 // StateLog logs all errors that come from the state cache. This includes 70 // not found errors. Defaults to a no-op, as state errors aren't that 71 // important. 72 StateLog func(error) 73 74 // PreHandler is the manual hook that is executed before the State handler 75 // is. This should only be used for low-level operations. 76 // It's recommended to set Synchronous to true if you mutate the events. 77 PreHandler *handler.Handler // default nil 78 79 // Command handler with inherited methods. Ran after PreHandler. You should 80 // most of the time use this instead of Session's, to avoid race conditions 81 // with the State. 82 *handler.Handler 83 84 // List of channels with few messages, so it doesn't bother hitting the API 85 // again. 86 fewMessages map[discord.ChannelID]struct{} 87 fewMutex *sync.Mutex 88 89 // unavailableGuilds is a set of discord.GuildIDs of guilds that became 90 // unavailable after connecting to the gateway, i.e. they were sent in a 91 // GuildUnavailableEvent. 92 unavailableGuilds map[discord.GuildID]struct{} 93 // unreadyGuilds is a set of discord.GuildIDs of the guilds received during 94 // the Ready event. After receiving guild create events for those guilds, 95 // they will be removed. 96 unreadyGuilds map[discord.GuildID]struct{} 97 guildMutex *sync.Mutex 98 } 99 100 // New creates a new state. 101 func New(token string) (*State, error) { 102 return NewWithStore(token, defaultstore.New()) 103 } 104 105 // NewWithIntents creates a new state with the given gateway intents. For more 106 // information, refer to gateway.Intents. 107 func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { 108 s, err := session.NewWithIntents(token, intents...) 109 if err != nil { 110 return nil, err 111 } 112 113 return NewFromSession(s, defaultstore.New()), nil 114 } 115 116 func NewWithStore(token string, cabinet store.Cabinet) (*State, error) { 117 s, err := session.New(token) 118 if err != nil { 119 return nil, err 120 } 121 122 return NewFromSession(s, cabinet), nil 123 } 124 125 // NewFromSession creates a new State from the passed Session and Cabinet. 126 func NewFromSession(s *session.Session, cabinet store.Cabinet) *State { 127 state := &State{ 128 Session: s, 129 Cabinet: cabinet, 130 Handler: handler.New(), 131 StateLog: func(err error) {}, 132 readyMu: new(sync.Mutex), 133 fewMessages: map[discord.ChannelID]struct{}{}, 134 fewMutex: new(sync.Mutex), 135 unavailableGuilds: make(map[discord.GuildID]struct{}), 136 unreadyGuilds: make(map[discord.GuildID]struct{}), 137 guildMutex: new(sync.Mutex), 138 } 139 state.hookSession() 140 return state 141 } 142 143 // WithContext returns a shallow copy of State with the context replaced in the 144 // API client. All methods called on the State will use this given context. This 145 // method is thread-safe. 146 func (s *State) WithContext(ctx context.Context) *State { 147 copied := *s 148 copied.Session = s.Session.WithContext(ctx) 149 150 return &copied 151 } 152 153 // Ready returns a copy of the Ready event. Although this function is safe to 154 // call concurrently, its values should still not be changed, as certain types 155 // like slices are not concurrent-safe. 156 // 157 // Note that if Ready events are not received yet, then the returned event will 158 // be a zero-value Ready instance. 159 func (s *State) Ready() gateway.ReadyEvent { 160 s.readyMu.Lock() 161 r := s.ready 162 s.readyMu.Unlock() 163 164 return r 165 } 166 167 //// Helper methods 168 169 func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string { 170 if !message.GuildID.IsValid() { 171 return message.Author.Username 172 } 173 174 if message.Member != nil { 175 if message.Member.Nick != "" { 176 return message.Member.Nick 177 } 178 return message.Author.Username 179 } 180 181 n, err := s.MemberDisplayName(message.GuildID, message.Author.ID) 182 if err != nil { 183 return message.Author.Username 184 } 185 186 return n 187 } 188 189 func (s *State) MemberDisplayName(guildID discord.GuildID, userID discord.UserID) (string, error) { 190 member, err := s.Member(guildID, userID) 191 if err != nil { 192 return "", err 193 } 194 195 if member.Nick == "" { 196 return member.User.Username, nil 197 } 198 199 return member.Nick, nil 200 } 201 202 func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, error) { 203 if !message.GuildID.IsValid() { // this is a dm 204 return discord.DefaultMemberColor, nil 205 } 206 207 if message.Member != nil { 208 guild, err := s.Guild(message.GuildID) 209 if err != nil { 210 return 0, err 211 } 212 return discord.MemberColor(*guild, *message.Member), nil 213 } 214 215 return s.MemberColor(message.GuildID, message.Author.ID) 216 } 217 218 func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) { 219 var wg sync.WaitGroup 220 221 var ( 222 g *discord.Guild 223 m *discord.Member 224 225 gerr = store.ErrNotFound 226 merr = store.ErrNotFound 227 ) 228 229 if s.Gateway.HasIntents(gateway.IntentGuilds) { 230 g, gerr = s.Cabinet.Guild(guildID) 231 } 232 233 if s.Gateway.HasIntents(gateway.IntentGuildMembers) { 234 m, merr = s.Cabinet.Member(guildID, userID) 235 } 236 237 switch { 238 case gerr != nil && merr != nil: 239 wg.Add(1) 240 go func() { 241 g, gerr = s.fetchGuild(guildID) 242 wg.Done() 243 }() 244 245 m, merr = s.fetchMember(guildID, userID) 246 case gerr != nil: 247 g, gerr = s.fetchGuild(guildID) 248 case merr != nil: 249 m, merr = s.fetchMember(guildID, userID) 250 } 251 252 wg.Wait() 253 254 if gerr != nil { 255 return 0, errors.Wrap(merr, "failed to get guild") 256 } 257 if merr != nil { 258 return 0, errors.Wrap(merr, "failed to get member") 259 } 260 261 return discord.MemberColor(*g, *m), nil 262 } 263 264 //// 265 266 // Permissions gets the user's permissions in the given channel. If the channel 267 // is not in any guild, then an error is returned. 268 func (s *State) Permissions( 269 channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) { 270 271 ch, err := s.Channel(channelID) 272 if err != nil { 273 return 0, errors.Wrap(err, "failed to get channel") 274 } 275 276 if !ch.GuildID.IsValid() { 277 return 0, errors.New("channel is not in a guild") 278 } 279 280 var wg sync.WaitGroup 281 282 var ( 283 g *discord.Guild 284 m *discord.Member 285 286 gerr = store.ErrNotFound 287 merr = store.ErrNotFound 288 ) 289 290 if s.Gateway.HasIntents(gateway.IntentGuilds) { 291 g, gerr = s.Cabinet.Guild(ch.GuildID) 292 } 293 294 if s.Gateway.HasIntents(gateway.IntentGuildMembers) { 295 m, merr = s.Cabinet.Member(ch.GuildID, userID) 296 } 297 298 switch { 299 case gerr != nil && merr != nil: 300 wg.Add(1) 301 go func() { 302 g, gerr = s.fetchGuild(ch.GuildID) 303 wg.Done() 304 }() 305 306 m, merr = s.fetchMember(ch.GuildID, userID) 307 case gerr != nil: 308 g, gerr = s.fetchGuild(ch.GuildID) 309 case merr != nil: 310 m, merr = s.fetchMember(ch.GuildID, userID) 311 } 312 313 wg.Wait() 314 315 if gerr != nil { 316 return 0, errors.Wrap(merr, "failed to get guild") 317 } 318 if merr != nil { 319 return 0, errors.Wrap(merr, "failed to get member") 320 } 321 322 return discord.CalcOverwrites(*g, *ch, *m), nil 323 } 324 325 //// 326 327 func (s *State) Me() (*discord.User, error) { 328 u, err := s.Cabinet.Me() 329 if err == nil { 330 return u, nil 331 } 332 333 u, err = s.Session.Me() 334 if err != nil { 335 return nil, err 336 } 337 338 return u, s.Cabinet.MyselfSet(*u) 339 } 340 341 //// 342 343 func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { 344 c, err = s.Cabinet.Channel(id) 345 if err == nil && s.tracksChannel(c) { 346 return 347 } 348 349 c, err = s.Session.Channel(id) 350 if err != nil { 351 return 352 } 353 354 if s.tracksChannel(c) { 355 err = s.Cabinet.ChannelSet(*c) 356 } 357 358 return 359 } 360 361 func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) { 362 if s.Gateway.HasIntents(gateway.IntentGuilds) { 363 cs, err = s.Cabinet.Channels(guildID) 364 if err == nil { 365 return 366 } 367 } 368 369 cs, err = s.Session.Channels(guildID) 370 if err != nil { 371 return 372 } 373 374 if s.Gateway.HasIntents(gateway.IntentGuilds) { 375 for _, c := range cs { 376 if err = s.Cabinet.ChannelSet(c); err != nil { 377 return 378 } 379 } 380 } 381 382 return 383 } 384 385 func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { 386 c, err := s.Cabinet.CreatePrivateChannel(recipient) 387 if err == nil { 388 return c, nil 389 } 390 391 c, err = s.Session.CreatePrivateChannel(recipient) 392 if err != nil { 393 return nil, err 394 } 395 396 return c, s.Cabinet.ChannelSet(*c) 397 } 398 399 // PrivateChannels gets the direct messages of the user. 400 // This is not supported for bots. 401 func (s *State) PrivateChannels() ([]discord.Channel, error) { 402 cs, err := s.Cabinet.PrivateChannels() 403 if err == nil { 404 return cs, nil 405 } 406 407 cs, err = s.Session.PrivateChannels() 408 if err != nil { 409 return nil, err 410 } 411 412 for _, c := range cs { 413 if err := s.Cabinet.ChannelSet(c); err != nil { 414 return nil, err 415 } 416 } 417 418 return cs, nil 419 } 420 421 //// 422 423 func (s *State) Emoji( 424 guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) { 425 426 if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { 427 e, err = s.Cabinet.Emoji(guildID, emojiID) 428 if err == nil { 429 return 430 } 431 } else { // Fast path 432 return s.Session.Emoji(guildID, emojiID) 433 } 434 435 es, err := s.Session.Emojis(guildID) 436 if err != nil { 437 return nil, err 438 } 439 440 if err = s.Cabinet.EmojiSet(guildID, es); err != nil { 441 return 442 } 443 444 for _, e := range es { 445 if e.ID == emojiID { 446 return &e, nil 447 } 448 } 449 450 return nil, store.ErrNotFound 451 } 452 453 func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) { 454 if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { 455 es, err = s.Cabinet.Emojis(guildID) 456 if err == nil { 457 return 458 } 459 } 460 461 es, err = s.Session.Emojis(guildID) 462 if err != nil { 463 return 464 } 465 466 if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { 467 err = s.Cabinet.EmojiSet(guildID, es) 468 } 469 470 return 471 } 472 473 //// 474 475 func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { 476 if s.Gateway.HasIntents(gateway.IntentGuilds) { 477 c, err := s.Cabinet.Guild(id) 478 if err == nil { 479 return c, nil 480 } 481 } 482 483 return s.fetchGuild(id) 484 } 485 486 // Guilds will only fill a maximum of 100 guilds from the API. 487 func (s *State) Guilds() (gs []discord.Guild, err error) { 488 if s.Gateway.HasIntents(gateway.IntentGuilds) { 489 gs, err = s.Cabinet.Guilds() 490 if err == nil { 491 return 492 } 493 } 494 495 gs, err = s.Session.Guilds(MaxFetchGuilds) 496 if err != nil { 497 return 498 } 499 500 if s.Gateway.HasIntents(gateway.IntentGuilds) { 501 for _, g := range gs { 502 if err = s.Cabinet.GuildSet(g); err != nil { 503 return 504 } 505 } 506 } 507 508 return 509 } 510 511 //// 512 513 func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { 514 if s.Gateway.HasIntents(gateway.IntentGuildMembers) { 515 m, err := s.Cabinet.Member(guildID, userID) 516 if err == nil { 517 return m, nil 518 } 519 } 520 521 return s.fetchMember(guildID, userID) 522 } 523 524 func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) { 525 if s.Gateway.HasIntents(gateway.IntentGuildMembers) { 526 ms, err = s.Cabinet.Members(guildID) 527 if err == nil { 528 return 529 } 530 } 531 532 ms, err = s.Session.Members(guildID, MaxFetchMembers) 533 if err != nil { 534 return 535 } 536 537 if s.Gateway.HasIntents(gateway.IntentGuildMembers) { 538 for _, m := range ms { 539 if err = s.Cabinet.MemberSet(guildID, m); err != nil { 540 return 541 } 542 } 543 } 544 545 return 546 } 547 548 //// 549 550 func (s *State) Message( 551 channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { 552 553 m, err := s.Cabinet.Message(channelID, messageID) 554 if err == nil && s.tracksMessage(m) { 555 return m, nil 556 } 557 558 var ( 559 wg sync.WaitGroup 560 561 c *discord.Channel 562 cerr = store.ErrNotFound 563 ) 564 565 c, cerr = s.Cabinet.Channel(channelID) 566 if cerr != nil || !s.tracksChannel(c) { 567 wg.Add(1) 568 go func() { 569 c, cerr = s.Session.Channel(channelID) 570 if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { 571 cerr = s.Cabinet.ChannelSet(*c) 572 } 573 574 wg.Done() 575 }() 576 } 577 578 m, err = s.Session.Message(channelID, messageID) 579 if err != nil { 580 return nil, errors.Wrap(err, "unable to fetch message") 581 } 582 583 wg.Wait() 584 585 if cerr != nil { 586 return nil, errors.Wrap(cerr, "unable to fetch channel") 587 } 588 589 m.ChannelID = c.ID 590 m.GuildID = c.GuildID 591 592 if s.tracksMessage(m) { 593 err = s.Cabinet.MessageSet(*m) 594 } 595 596 return m, err 597 } 598 599 // Messages fetches maximum 100 messages from the API, if it has to. There is 600 // no limit if it's from the State storage. 601 func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) { 602 // TODO: Think of a design that doesn't rely on MaxMessages(). 603 var maxMsgs = s.MaxMessages() 604 605 ms, err := s.Cabinet.Messages(channelID) 606 if err == nil && (len(ms) == 0 || s.tracksMessage(&ms[0])) { 607 // If the state already has as many messages as it can, skip the API. 608 if maxMsgs <= len(ms) { 609 return ms, nil 610 } 611 612 // Is the channel tiny? 613 s.fewMutex.Lock() 614 if _, ok := s.fewMessages[channelID]; ok { 615 s.fewMutex.Unlock() 616 return ms, nil 617 } 618 619 // No, fetch from the state. 620 s.fewMutex.Unlock() 621 } 622 623 ms, err = s.Session.Messages(channelID, uint(maxMsgs)) 624 if err != nil { 625 return nil, err 626 } 627 628 // New messages fetched weirdly does not have GuildID filled. We'll try and 629 // get it for consistency with incoming message creates. 630 var guildID discord.GuildID 631 632 // A bit too convoluted, but whatever. 633 c, err := s.Channel(channelID) 634 if err == nil { 635 // If it's 0, it's 0 anyway. We don't need a check here. 636 guildID = c.GuildID 637 } 638 639 if len(ms) > 0 && s.tracksMessage(&ms[0]) { 640 // Iterate in reverse, since the store is expected to prepend the latest 641 // messages. 642 for i := len(ms) - 1; i >= 0; i-- { 643 // Set the guild ID, fine if it's 0 (it's already 0 anyway). 644 ms[i].GuildID = guildID 645 646 if err := s.Cabinet.MessageSet(ms[i]); err != nil { 647 return nil, err 648 } 649 } 650 } 651 652 if len(ms) < maxMsgs { 653 // Tiny channel, store this. 654 s.fewMutex.Lock() 655 s.fewMessages[channelID] = struct{}{} 656 s.fewMutex.Unlock() 657 658 return ms, nil 659 } 660 661 // Since the latest messages are at the end and we already know the maxMsgs, 662 // we could slice this right away. 663 return ms[:maxMsgs], nil 664 } 665 666 //// 667 668 // Presence checks the state for user presences. If no guildID is given, it 669 // will look for the presence in all cached guilds. 670 func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*gateway.Presence, error) { 671 if !s.Gateway.HasIntents(gateway.IntentGuildPresences) { 672 return nil, store.ErrNotFound 673 } 674 675 // If there's no guild ID, look in all guilds 676 if !gID.IsValid() { 677 if !s.Gateway.HasIntents(gateway.IntentGuilds) { 678 return nil, store.ErrNotFound 679 } 680 681 g, err := s.Cabinet.Guilds() 682 if err != nil { 683 return nil, err 684 } 685 686 for _, g := range g { 687 if p, err := s.Cabinet.Presence(g.ID, uID); err == nil { 688 return p, nil 689 } 690 } 691 692 return nil, store.ErrNotFound 693 } 694 695 return s.Cabinet.Presence(gID, uID) 696 } 697 698 //// 699 700 func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) { 701 if s.Gateway.HasIntents(gateway.IntentGuilds) { 702 target, err = s.Cabinet.Role(guildID, roleID) 703 if err == nil { 704 return 705 } 706 } 707 708 rs, err := s.Session.Roles(guildID) 709 if err != nil { 710 return 711 } 712 713 for _, r := range rs { 714 if r.ID == roleID { 715 r := r // copy to prevent mem aliasing 716 target = &r 717 } 718 719 if s.Gateway.HasIntents(gateway.IntentGuilds) { 720 if err = s.RoleSet(guildID, r); err != nil { 721 return 722 } 723 } 724 } 725 726 if target == nil { 727 return nil, store.ErrNotFound 728 } 729 730 return 731 } 732 733 func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { 734 rs, err := s.Cabinet.Roles(guildID) 735 if err == nil { 736 return rs, nil 737 } 738 739 rs, err = s.Session.Roles(guildID) 740 if err != nil { 741 return nil, err 742 } 743 744 if s.Gateway.HasIntents(gateway.IntentGuilds) { 745 for _, r := range rs { 746 if err := s.RoleSet(guildID, r); err != nil { 747 return rs, err 748 } 749 } 750 } 751 752 return rs, nil 753 } 754 755 func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { 756 g, err = s.Session.Guild(id) 757 if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { 758 err = s.Cabinet.GuildSet(*g) 759 } 760 761 return 762 } 763 764 func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { 765 m, err = s.Session.Member(gID, uID) 766 if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { 767 err = s.Cabinet.MemberSet(gID, *m) 768 } 769 770 return 771 } 772 773 // tracksMessage reports whether the state would track the passed message and 774 // messages from the same channel. 775 func (s *State) tracksMessage(m *discord.Message) bool { 776 return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || 777 (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) 778 } 779 780 // tracksChannel reports whether the state would track the passed channel. 781 func (s *State) tracksChannel(c *discord.Channel) bool { 782 return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || 783 !c.GuildID.IsValid() 784 }