github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/limiter/limiter_test.go (about)

     1  package limiter_test
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/sohaha/zlsgo"
    12  	"github.com/sohaha/zlsgo/znet"
    13  	"github.com/sohaha/zlsgo/znet/limiter"
    14  )
    15  
    16  var (
    17  	one    sync.Once
    18  	engine *znet.Engine
    19  )
    20  
    21  func newServer() *znet.Engine {
    22  	one.Do(func() {
    23  		engine = znet.New("limiter_test")
    24  		engine.AddAddr("3787")
    25  	})
    26  	return engine
    27  }
    28  
    29  func TestNew(tt *testing.T) {
    30  	t := zlsgo.NewTest(tt)
    31  	r := newServer()
    32  
    33  	rule := limiter.NewRule()
    34  	rule.AddRule(time.Second, 3)
    35  	rule.AddRule(time.Second*2, 4, 5)
    36  	r.GET("/limiterCustomize", func(c *znet.Context) {
    37  		c.String(200, "ok")
    38  	}, func(c *znet.Context) {
    39  		if !rule.AllowVisitByIP(c.GetClientIP()) {
    40  			c.String(http.StatusTooManyRequests, "超过限制")
    41  			c.Abort()
    42  			return
    43  		}
    44  		c.Next()
    45  	})
    46  
    47  	r.GET("/limiterCustomizeUser", func(c *znet.Context) {
    48  		c.String(200, "ok")
    49  	}, func(c *znet.Context) {
    50  		if !rule.AllowVisit("username") {
    51  			c.Abort()
    52  			c.String(http.StatusTooManyRequests, "超过限制")
    53  			return
    54  		}
    55  		c.Next()
    56  	})
    57  
    58  	r.GET("/limiter", func(c *znet.Context) {
    59  		c.String(200, "ok")
    60  	}, limiter.New(3, func(c *znet.Context) {
    61  		c.String(http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests))
    62  	}))
    63  
    64  	r.GET("/limiter2", func(c *znet.Context) {
    65  		c.String(200, "ok")
    66  	}, limiter.New(3))
    67  
    68  	ii := run("/limiterCustomize", r)
    69  	t.EqualExit(int64(3), ii)
    70  
    71  	tt.Log(rule.Remaining("username"))
    72  	ii = run("/limiterCustomizeUser", r)
    73  	t.EqualExit(int64(3), ii)
    74  	t.EqualExit([]int{0, 1}, rule.Remaining("username"))
    75  	tt.Log(rule.GetOnline())
    76  
    77  	ii = run("/limiter", r)
    78  	t.EqualExit(int64(3), ii)
    79  
    80  	ii = run("/limiter2", r)
    81  	t.EqualExit(int64(3), ii)
    82  
    83  	time.Sleep(time.Second)
    84  
    85  	ii = run("/limiterCustomizeUser", r)
    86  	t.EqualExit(int64(1), ii)
    87  	t.EqualExit([]int{0, 0}, rule.Remaining("username"))
    88  
    89  	time.Sleep(time.Second * 2)
    90  
    91  	ii = run("/limiterCustomizeUser", r)
    92  	t.EqualExit(int64(3), ii)
    93  	t.EqualExit([]int{0, 1}, rule.Remaining("username"))
    94  }
    95  
    96  func run(url string, r *znet.Engine) int64 {
    97  	var wg sync.WaitGroup
    98  	var ii int64
    99  	for i := 0; i < 10; i++ {
   100  		wg.Add(1)
   101  		go func() {
   102  			w := httptest.NewRecorder()
   103  			req, _ := http.NewRequest("GET", url, nil)
   104  			req.Header.Set("X-Real-Ip", "192.168.1.1")
   105  			r.ServeHTTP(w, req)
   106  			if w.Code == http.StatusOK {
   107  				atomic.AddInt64(&ii, 1)
   108  			}
   109  			wg.Done()
   110  		}()
   111  	}
   112  	wg.Wait()
   113  	return ii
   114  }