github.com/diamondburned/arikawa/v2@v2.1.0/bot/extras/middlewares/middlewares_test.go (about)

     1  package middlewares
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/diamondburned/arikawa/v2/bot"
     8  	"github.com/diamondburned/arikawa/v2/discord"
     9  	"github.com/diamondburned/arikawa/v2/gateway"
    10  	"github.com/diamondburned/arikawa/v2/session"
    11  	"github.com/diamondburned/arikawa/v2/state"
    12  	"github.com/diamondburned/arikawa/v2/state/store"
    13  )
    14  
    15  func TestAdminOnly(t *testing.T) {
    16  	var ctx = &bot.Context{
    17  		State: &state.State{
    18  			Session: &session.Session{
    19  				Gateway: &gateway.Gateway{
    20  					Identifier: &gateway.Identifier{
    21  						IdentifyData: gateway.IdentifyData{
    22  							Intents: gateway.IntentGuilds | gateway.IntentGuildMembers,
    23  						},
    24  					},
    25  				},
    26  			},
    27  			Cabinet: mockCabinet(),
    28  		},
    29  	}
    30  	var middleware = AdminOnly(ctx)
    31  
    32  	t.Run("allow message", func(t *testing.T) {
    33  		var msg = &gateway.MessageCreateEvent{
    34  			Message: discord.Message{
    35  				ID:        1,
    36  				ChannelID: 69420,
    37  				Author:    discord.User{ID: 69420},
    38  			},
    39  		}
    40  		expectNil(t, middleware(msg))
    41  	})
    42  
    43  	t.Run("deny message", func(t *testing.T) {
    44  		var msg = &gateway.MessageCreateEvent{
    45  			Message: discord.Message{
    46  				ID:        2,
    47  				ChannelID: 1337,
    48  				Author:    discord.User{ID: 1337},
    49  			},
    50  		}
    51  		expectBreak(t, middleware(msg))
    52  		var pin = &gateway.ChannelPinsUpdateEvent{
    53  			ChannelID: 120,
    54  		}
    55  		expectBreak(t, middleware(pin))
    56  		var tpg = &gateway.TypingStartEvent{}
    57  		expectBreak(t, middleware(tpg))
    58  	})
    59  }
    60  
    61  func TestGuildOnly(t *testing.T) {
    62  	var ctx = &bot.Context{
    63  		State: &state.State{
    64  			Session: &session.Session{
    65  				Gateway: &gateway.Gateway{
    66  					Identifier: &gateway.Identifier{
    67  						IdentifyData: gateway.IdentifyData{
    68  							Intents: gateway.IntentGuilds,
    69  						},
    70  					},
    71  				},
    72  			},
    73  			Cabinet: mockCabinet(),
    74  		},
    75  	}
    76  	var middleware = GuildOnly(ctx)
    77  
    78  	t.Run("allow message with GuildID", func(t *testing.T) {
    79  		var msg = &gateway.MessageCreateEvent{
    80  			Message: discord.Message{
    81  				ID:      3,
    82  				GuildID: 1337,
    83  			},
    84  		}
    85  		expectNil(t, middleware(msg))
    86  	})
    87  
    88  	t.Run("allow message with ChannelID", func(t *testing.T) {
    89  		var msg = &gateway.MessageCreateEvent{
    90  			Message: discord.Message{
    91  				ID:        3,
    92  				ChannelID: 69420,
    93  			},
    94  		}
    95  		expectNil(t, middleware(msg))
    96  	})
    97  
    98  	t.Run("deny message", func(t *testing.T) {
    99  		var msg = &gateway.MessageCreateEvent{
   100  			Message: discord.Message{
   101  				ID:        1,
   102  				ChannelID: 12,
   103  			},
   104  		}
   105  		expectBreak(t, middleware(msg))
   106  
   107  		var msg2 = &gateway.MessageCreateEvent{}
   108  		expectBreak(t, middleware(msg2))
   109  	})
   110  }
   111  
   112  func expectNil(t *testing.T, err error) {
   113  	t.Helper()
   114  	if err != nil {
   115  		t.Fatal("Unexpected error:", err)
   116  	}
   117  }
   118  
   119  func expectBreak(t *testing.T, err error) {
   120  	t.Helper()
   121  	if errors.Is(err, bot.Break) {
   122  		return
   123  	}
   124  	if err != nil {
   125  		t.Fatal("Unexpected error:", err)
   126  	}
   127  	t.Fatal("Expected error, got nothing.")
   128  }
   129  
   130  // BenchmarkGuildOnly runs a message through the GuildOnly middleware to
   131  // calculate the overhead of reflection.
   132  func BenchmarkGuildOnly(b *testing.B) {
   133  	var ctx = &bot.Context{
   134  		State: &state.State{
   135  			Cabinet: mockCabinet(),
   136  		},
   137  	}
   138  	var middleware = GuildOnly(ctx)
   139  	var msg = &gateway.MessageCreateEvent{
   140  		Message: discord.Message{
   141  			ID:      3,
   142  			GuildID: 1337,
   143  		},
   144  	}
   145  
   146  	b.ResetTimer()
   147  
   148  	for i := 0; i < b.N; i++ {
   149  		if err := middleware(msg); err != nil {
   150  			b.Fatal("Unexpected error:", err)
   151  		}
   152  	}
   153  }
   154  
   155  // BenchmarkAdminOnly runs a message through the GuildOnly middleware to
   156  // calculate the overhead of reflection.
   157  func BenchmarkAdminOnly(b *testing.B) {
   158  	var ctx = &bot.Context{
   159  		State: &state.State{
   160  			Cabinet: mockCabinet(),
   161  		},
   162  	}
   163  	var middleware = AdminOnly(ctx)
   164  	var msg = &gateway.MessageCreateEvent{
   165  		Message: discord.Message{
   166  			ID:        1,
   167  			ChannelID: 1337,
   168  			Author:    discord.User{ID: 69420},
   169  		},
   170  	}
   171  
   172  	b.ResetTimer()
   173  
   174  	for i := 0; i < b.N; i++ {
   175  		if err := middleware(msg); err != nil {
   176  			b.Fatal("Unexpected error:", err)
   177  		}
   178  	}
   179  }
   180  
   181  type mockStore struct {
   182  	store.NoopStore
   183  }
   184  
   185  func mockCabinet() store.Cabinet {
   186  	c := store.NoopCabinet
   187  	c.GuildStore = &mockStore{}
   188  	c.MemberStore = &mockStore{}
   189  	c.ChannelStore = &mockStore{}
   190  
   191  	return c
   192  }
   193  
   194  func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
   195  	return &discord.Guild{
   196  		ID: id,
   197  		Roles: []discord.Role{{
   198  			ID:          69420,
   199  			Permissions: discord.PermissionAdministrator,
   200  		}},
   201  	}, nil
   202  }
   203  
   204  func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) {
   205  	return &discord.Member{
   206  		User:    discord.User{ID: userID},
   207  		RoleIDs: []discord.RoleID{discord.RoleID(userID)},
   208  	}, nil
   209  }
   210  
   211  // Channel returns a channel with a guildID for #69420.
   212  func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
   213  	if id == 69420 {
   214  		return &discord.Channel{
   215  			ID:      id,
   216  			GuildID: 1337,
   217  		}, nil
   218  	}
   219  
   220  	return &discord.Channel{
   221  		ID: id,
   222  	}, nil
   223  }