github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/cryptionhandler_test.go (about) 1 package handler 2 3 import ( 4 "bytes" 5 "encoding/base64" 6 "io" 7 "log" 8 "math/rand" 9 "net/http" 10 "net/http/httptest" 11 "testing" 12 13 "github.com/lingyao2333/mo-zero/core/codec" 14 "github.com/stretchr/testify/assert" 15 ) 16 17 const ( 18 reqText = "ping" 19 respText = "pong" 20 ) 21 22 var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`) 23 24 func init() { 25 log.SetOutput(io.Discard) 26 } 27 28 func TestCryptionHandlerGet(t *testing.T) { 29 req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody) 30 handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 31 _, err := w.Write([]byte(respText)) 32 w.Header().Set("X-Test", "test") 33 assert.Nil(t, err) 34 })) 35 recorder := httptest.NewRecorder() 36 handler.ServeHTTP(recorder, req) 37 38 expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) 39 assert.Nil(t, err) 40 assert.Equal(t, http.StatusOK, recorder.Code) 41 assert.Equal(t, "test", recorder.Header().Get("X-Test")) 42 assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) 43 } 44 45 func TestCryptionHandlerPost(t *testing.T) { 46 var buf bytes.Buffer 47 enc, err := codec.EcbEncrypt(aesKey, []byte(reqText)) 48 assert.Nil(t, err) 49 buf.WriteString(base64.StdEncoding.EncodeToString(enc)) 50 51 req := httptest.NewRequest(http.MethodPost, "/any", &buf) 52 handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 body, err := io.ReadAll(r.Body) 54 assert.Nil(t, err) 55 assert.Equal(t, reqText, string(body)) 56 57 w.Write([]byte(respText)) 58 })) 59 recorder := httptest.NewRecorder() 60 handler.ServeHTTP(recorder, req) 61 62 expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) 63 assert.Nil(t, err) 64 assert.Equal(t, http.StatusOK, recorder.Code) 65 assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) 66 } 67 68 func TestCryptionHandlerPostBadEncryption(t *testing.T) { 69 var buf bytes.Buffer 70 enc, err := codec.EcbEncrypt(aesKey, []byte(reqText)) 71 assert.Nil(t, err) 72 buf.Write(enc) 73 74 req := httptest.NewRequest(http.MethodPost, "/any", &buf) 75 handler := CryptionHandler(aesKey)(nil) 76 recorder := httptest.NewRecorder() 77 handler.ServeHTTP(recorder, req) 78 79 assert.Equal(t, http.StatusBadRequest, recorder.Code) 80 } 81 82 func TestCryptionHandlerWriteHeader(t *testing.T) { 83 req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody) 84 handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 85 w.WriteHeader(http.StatusServiceUnavailable) 86 })) 87 recorder := httptest.NewRecorder() 88 handler.ServeHTTP(recorder, req) 89 assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) 90 } 91 92 func TestCryptionHandlerFlush(t *testing.T) { 93 req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody) 94 handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 95 w.Write([]byte(respText)) 96 flusher, ok := w.(http.Flusher) 97 assert.True(t, ok) 98 flusher.Flush() 99 })) 100 recorder := httptest.NewRecorder() 101 handler.ServeHTTP(recorder, req) 102 103 expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) 104 assert.Nil(t, err) 105 assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) 106 } 107 108 func TestCryptionHandler_Hijack(t *testing.T) { 109 resp := httptest.NewRecorder() 110 writer := newCryptionResponseWriter(resp) 111 assert.NotPanics(t, func() { 112 writer.Hijack() 113 }) 114 115 writer = newCryptionResponseWriter(mockedHijackable{resp}) 116 assert.NotPanics(t, func() { 117 writer.Hijack() 118 }) 119 } 120 121 func TestCryptionHandler_ContentTooLong(t *testing.T) { 122 handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 123 })) 124 svr := httptest.NewServer(handler) 125 defer svr.Close() 126 127 body := make([]byte, maxBytes+1) 128 rand.Read(body) 129 req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body)) 130 assert.Nil(t, err) 131 resp, err := http.DefaultClient.Do(req) 132 assert.Nil(t, err) 133 assert.Equal(t, http.StatusBadRequest, resp.StatusCode) 134 }