github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/cryptionhandler_test.go (about)

     1  package handler
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"io"
     7  	"log"
     8  	"math/rand"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"testing"
    12  
    13  	"github.com/lingyao2333/mo-zero/core/codec"
    14  	"github.com/stretchr/testify/assert"
    15  )
    16  
    17  const (
    18  	reqText  = "ping"
    19  	respText = "pong"
    20  )
    21  
    22  var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
    23  
    24  func init() {
    25  	log.SetOutput(io.Discard)
    26  }
    27  
    28  func TestCryptionHandlerGet(t *testing.T) {
    29  	req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
    30  	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    31  		_, err := w.Write([]byte(respText))
    32  		w.Header().Set("X-Test", "test")
    33  		assert.Nil(t, err)
    34  	}))
    35  	recorder := httptest.NewRecorder()
    36  	handler.ServeHTTP(recorder, req)
    37  
    38  	expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
    39  	assert.Nil(t, err)
    40  	assert.Equal(t, http.StatusOK, recorder.Code)
    41  	assert.Equal(t, "test", recorder.Header().Get("X-Test"))
    42  	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
    43  }
    44  
    45  func TestCryptionHandlerPost(t *testing.T) {
    46  	var buf bytes.Buffer
    47  	enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
    48  	assert.Nil(t, err)
    49  	buf.WriteString(base64.StdEncoding.EncodeToString(enc))
    50  
    51  	req := httptest.NewRequest(http.MethodPost, "/any", &buf)
    52  	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    53  		body, err := io.ReadAll(r.Body)
    54  		assert.Nil(t, err)
    55  		assert.Equal(t, reqText, string(body))
    56  
    57  		w.Write([]byte(respText))
    58  	}))
    59  	recorder := httptest.NewRecorder()
    60  	handler.ServeHTTP(recorder, req)
    61  
    62  	expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
    63  	assert.Nil(t, err)
    64  	assert.Equal(t, http.StatusOK, recorder.Code)
    65  	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
    66  }
    67  
    68  func TestCryptionHandlerPostBadEncryption(t *testing.T) {
    69  	var buf bytes.Buffer
    70  	enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
    71  	assert.Nil(t, err)
    72  	buf.Write(enc)
    73  
    74  	req := httptest.NewRequest(http.MethodPost, "/any", &buf)
    75  	handler := CryptionHandler(aesKey)(nil)
    76  	recorder := httptest.NewRecorder()
    77  	handler.ServeHTTP(recorder, req)
    78  
    79  	assert.Equal(t, http.StatusBadRequest, recorder.Code)
    80  }
    81  
    82  func TestCryptionHandlerWriteHeader(t *testing.T) {
    83  	req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
    84  	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    85  		w.WriteHeader(http.StatusServiceUnavailable)
    86  	}))
    87  	recorder := httptest.NewRecorder()
    88  	handler.ServeHTTP(recorder, req)
    89  	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
    90  }
    91  
    92  func TestCryptionHandlerFlush(t *testing.T) {
    93  	req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
    94  	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    95  		w.Write([]byte(respText))
    96  		flusher, ok := w.(http.Flusher)
    97  		assert.True(t, ok)
    98  		flusher.Flush()
    99  	}))
   100  	recorder := httptest.NewRecorder()
   101  	handler.ServeHTTP(recorder, req)
   102  
   103  	expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
   104  	assert.Nil(t, err)
   105  	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
   106  }
   107  
   108  func TestCryptionHandler_Hijack(t *testing.T) {
   109  	resp := httptest.NewRecorder()
   110  	writer := newCryptionResponseWriter(resp)
   111  	assert.NotPanics(t, func() {
   112  		writer.Hijack()
   113  	})
   114  
   115  	writer = newCryptionResponseWriter(mockedHijackable{resp})
   116  	assert.NotPanics(t, func() {
   117  		writer.Hijack()
   118  	})
   119  }
   120  
   121  func TestCryptionHandler_ContentTooLong(t *testing.T) {
   122  	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   123  	}))
   124  	svr := httptest.NewServer(handler)
   125  	defer svr.Close()
   126  
   127  	body := make([]byte, maxBytes+1)
   128  	rand.Read(body)
   129  	req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
   130  	assert.Nil(t, err)
   131  	resp, err := http.DefaultClient.Do(req)
   132  	assert.Nil(t, err)
   133  	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   134  }