github.com/diamondburned/arikawa@v1.3.14/bot/extras/middlewares/middlewares_test.go (about)

     1  package middlewares
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/diamondburned/arikawa/bot"
     8  	"github.com/diamondburned/arikawa/discord"
     9  	"github.com/diamondburned/arikawa/gateway"
    10  	"github.com/diamondburned/arikawa/state"
    11  )
    12  
    13  func TestAdminOnly(t *testing.T) {
    14  	var ctx = &bot.Context{
    15  		State: &state.State{
    16  			Store: &mockStore{},
    17  		},
    18  	}
    19  	var middleware = AdminOnly(ctx)
    20  
    21  	t.Run("allow message", func(t *testing.T) {
    22  		var msg = &gateway.MessageCreateEvent{
    23  			Message: discord.Message{
    24  				ID:        1,
    25  				ChannelID: 1337,
    26  				Author:    discord.User{ID: 69420},
    27  			},
    28  		}
    29  		expectNil(t, middleware(msg))
    30  	})
    31  
    32  	t.Run("deny message", func(t *testing.T) {
    33  		var msg = &gateway.MessageCreateEvent{
    34  			Message: discord.Message{
    35  				ID:        2,
    36  				ChannelID: 1337,
    37  				Author:    discord.User{ID: 1337},
    38  			},
    39  		}
    40  		expectBreak(t, middleware(msg))
    41  		var pin = &gateway.ChannelPinsUpdateEvent{
    42  			ChannelID: 120,
    43  		}
    44  		expectBreak(t, middleware(pin))
    45  		var tpg = &gateway.TypingStartEvent{}
    46  		expectBreak(t, middleware(tpg))
    47  	})
    48  }
    49  
    50  func TestGuildOnly(t *testing.T) {
    51  	var ctx = &bot.Context{
    52  		State: &state.State{
    53  			Store: &mockStore{},
    54  		},
    55  	}
    56  	var middleware = GuildOnly(ctx)
    57  
    58  	t.Run("allow message with GuildID", func(t *testing.T) {
    59  		var msg = &gateway.MessageCreateEvent{
    60  			Message: discord.Message{
    61  				ID:      3,
    62  				GuildID: 1337,
    63  			},
    64  		}
    65  		expectNil(t, middleware(msg))
    66  	})
    67  
    68  	t.Run("allow message with ChannelID", func(t *testing.T) {
    69  		var msg = &gateway.MessageCreateEvent{
    70  			Message: discord.Message{
    71  				ID:        3,
    72  				ChannelID: 69420,
    73  			},
    74  		}
    75  		expectNil(t, middleware(msg))
    76  	})
    77  
    78  	t.Run("deny message", func(t *testing.T) {
    79  		var msg = &gateway.MessageCreateEvent{
    80  			Message: discord.Message{
    81  				ID:        1,
    82  				ChannelID: 12,
    83  			},
    84  		}
    85  		expectBreak(t, middleware(msg))
    86  
    87  		var msg2 = &gateway.MessageCreateEvent{}
    88  		expectBreak(t, middleware(msg2))
    89  	})
    90  }
    91  
    92  func expectNil(t *testing.T, err error) {
    93  	t.Helper()
    94  	if err != nil {
    95  		t.Fatal("Unexpected error:", err)
    96  	}
    97  }
    98  
    99  func expectBreak(t *testing.T, err error) {
   100  	t.Helper()
   101  	if errors.Is(err, bot.Break) {
   102  		return
   103  	}
   104  	if err != nil {
   105  		t.Fatal("Unexpected error:", err)
   106  	}
   107  	t.Fatal("Expected error, got nothing.")
   108  }
   109  
   110  // BenchmarkGuildOnly runs a message through the GuildOnly middleware to
   111  // calculate the overhead of reflection.
   112  func BenchmarkGuildOnly(b *testing.B) {
   113  	var ctx = &bot.Context{
   114  		State: &state.State{
   115  			Store: &mockStore{},
   116  		},
   117  	}
   118  	var middleware = GuildOnly(ctx)
   119  	var msg = &gateway.MessageCreateEvent{
   120  		Message: discord.Message{
   121  			ID:      3,
   122  			GuildID: 1337,
   123  		},
   124  	}
   125  
   126  	b.ResetTimer()
   127  
   128  	for i := 0; i < b.N; i++ {
   129  		if err := middleware(msg); err != nil {
   130  			b.Fatal("Unexpected error:", err)
   131  		}
   132  	}
   133  }
   134  
   135  // BenchmarkAdminOnly runs a message through the GuildOnly middleware to
   136  // calculate the overhead of reflection.
   137  func BenchmarkAdminOnly(b *testing.B) {
   138  	var ctx = &bot.Context{
   139  		State: &state.State{
   140  			Store: &mockStore{},
   141  		},
   142  	}
   143  	var middleware = AdminOnly(ctx)
   144  	var msg = &gateway.MessageCreateEvent{
   145  		Message: discord.Message{
   146  			ID:        1,
   147  			ChannelID: 1337,
   148  			Author:    discord.User{ID: 69420},
   149  		},
   150  	}
   151  
   152  	b.ResetTimer()
   153  
   154  	for i := 0; i < b.N; i++ {
   155  		if err := middleware(msg); err != nil {
   156  			b.Fatal("Unexpected error:", err)
   157  		}
   158  	}
   159  }
   160  
   161  type mockStore struct {
   162  	state.NoopStore
   163  }
   164  
   165  func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
   166  	return &discord.Guild{
   167  		ID: id,
   168  		Roles: []discord.Role{{
   169  			ID:          69420,
   170  			Permissions: discord.PermissionAdministrator,
   171  		}},
   172  	}, nil
   173  }
   174  
   175  func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) {
   176  	return &discord.Member{
   177  		User:    discord.User{ID: userID},
   178  		RoleIDs: []discord.RoleID{discord.RoleID(userID)},
   179  	}, nil
   180  }
   181  
   182  // Channel returns a channel with a guildID for #69420.
   183  func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
   184  	if id == 69420 {
   185  		return &discord.Channel{
   186  			ID:      id,
   187  			GuildID: 1337,
   188  		}, nil
   189  	}
   190  
   191  	return &discord.Channel{
   192  		ID: id,
   193  	}, nil
   194  }