github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/maxconnshandler_test.go (about)

     1  package handler
     2  
     3  import (
     4  	"io"
     5  	"log"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"sync"
     9  	"testing"
    10  
    11  	"github.com/lingyao2333/mo-zero/core/lang"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  const conns = 4
    16  
    17  func init() {
    18  	log.SetOutput(io.Discard)
    19  }
    20  
    21  func TestMaxConnsHandler(t *testing.T) {
    22  	var waitGroup sync.WaitGroup
    23  	waitGroup.Add(conns)
    24  	done := make(chan lang.PlaceholderType)
    25  	defer close(done)
    26  
    27  	maxConns := MaxConns(conns)
    28  	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    29  		waitGroup.Done()
    30  		<-done
    31  	}))
    32  
    33  	for i := 0; i < conns; i++ {
    34  		go func() {
    35  			req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    36  			handler.ServeHTTP(httptest.NewRecorder(), req)
    37  		}()
    38  	}
    39  
    40  	waitGroup.Wait()
    41  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    42  	resp := httptest.NewRecorder()
    43  	handler.ServeHTTP(resp, req)
    44  	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
    45  }
    46  
    47  func TestWithoutMaxConnsHandler(t *testing.T) {
    48  	const (
    49  		key   = "block"
    50  		value = "1"
    51  	)
    52  	var waitGroup sync.WaitGroup
    53  	waitGroup.Add(conns)
    54  	done := make(chan lang.PlaceholderType)
    55  	defer close(done)
    56  
    57  	maxConns := MaxConns(0)
    58  	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    59  		val := r.Header.Get(key)
    60  		if val == value {
    61  			waitGroup.Done()
    62  			<-done
    63  		}
    64  	}))
    65  
    66  	for i := 0; i < conns; i++ {
    67  		go func() {
    68  			req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    69  			req.Header.Set(key, value)
    70  			handler.ServeHTTP(httptest.NewRecorder(), req)
    71  		}()
    72  	}
    73  
    74  	waitGroup.Wait()
    75  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    76  	resp := httptest.NewRecorder()
    77  	handler.ServeHTTP(resp, req)
    78  	assert.Equal(t, http.StatusOK, resp.Code)
    79  }