github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/authhandler_test.go (about)

     1  package handler
     2  
     3  import (
     4  	"bufio"
     5  	"net"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/golang-jwt/jwt/v4"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestAuthHandlerFailed(t *testing.T) {
    16  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    17  	handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback(
    18  		func(w http.ResponseWriter, r *http.Request, err error) {
    19  			assert.NotNil(t, err)
    20  			w.Header().Set("X-Test", err.Error())
    21  			w.WriteHeader(http.StatusUnauthorized)
    22  			_, err = w.Write([]byte("content"))
    23  			assert.Nil(t, err)
    24  		}))(
    25  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    26  			w.WriteHeader(http.StatusOK)
    27  		}))
    28  
    29  	resp := httptest.NewRecorder()
    30  	handler.ServeHTTP(resp, req)
    31  	assert.Equal(t, http.StatusUnauthorized, resp.Code)
    32  }
    33  
    34  func TestAuthHandler(t *testing.T) {
    35  	const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
    36  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    37  	token, err := buildToken(key, map[string]interface{}{
    38  		"key": "value",
    39  	}, 3600)
    40  	assert.Nil(t, err)
    41  	req.Header.Set("Authorization", "Bearer "+token)
    42  	handler := Authorize(key)(
    43  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  			w.Header().Set("X-Test", "test")
    45  			_, err := w.Write([]byte("content"))
    46  			assert.Nil(t, err)
    47  
    48  			flusher, ok := w.(http.Flusher)
    49  			assert.True(t, ok)
    50  			flusher.Flush()
    51  		}))
    52  
    53  	resp := httptest.NewRecorder()
    54  	handler.ServeHTTP(resp, req)
    55  	assert.Equal(t, http.StatusOK, resp.Code)
    56  	assert.Equal(t, "test", resp.Header().Get("X-Test"))
    57  	assert.Equal(t, "content", resp.Body.String())
    58  }
    59  
    60  func TestAuthHandlerWithPrevSecret(t *testing.T) {
    61  	const (
    62  		key     = "14F17379-EB8F-411B-8F12-6929002DCA76"
    63  		prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
    64  	)
    65  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    66  	token, err := buildToken(key, map[string]interface{}{
    67  		"key": "value",
    68  	}, 3600)
    69  	assert.Nil(t, err)
    70  	req.Header.Set("Authorization", "Bearer "+token)
    71  	handler := Authorize(key, WithPrevSecret(prevKey))(
    72  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    73  			w.Header().Set("X-Test", "test")
    74  			_, err := w.Write([]byte("content"))
    75  			assert.Nil(t, err)
    76  		}))
    77  
    78  	resp := httptest.NewRecorder()
    79  	handler.ServeHTTP(resp, req)
    80  	assert.Equal(t, http.StatusOK, resp.Code)
    81  	assert.Equal(t, "test", resp.Header().Get("X-Test"))
    82  	assert.Equal(t, "content", resp.Body.String())
    83  }
    84  
    85  func TestAuthHandler_NilError(t *testing.T) {
    86  	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
    87  	resp := httptest.NewRecorder()
    88  	assert.NotPanics(t, func() {
    89  		unauthorized(resp, req, nil, nil)
    90  	})
    91  }
    92  
    93  func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
    94  	now := time.Now().Unix()
    95  	claims := make(jwt.MapClaims)
    96  	claims["exp"] = now + seconds
    97  	claims["iat"] = now
    98  	for k, v := range payloads {
    99  		claims[k] = v
   100  	}
   101  
   102  	token := jwt.New(jwt.SigningMethodHS256)
   103  	token.Claims = claims
   104  
   105  	return token.SignedString([]byte(secretKey))
   106  }
   107  
   108  type mockedHijackable struct {
   109  	*httptest.ResponseRecorder
   110  }
   111  
   112  func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   113  	return nil, nil, nil
   114  }