github.com/lingyao2333/mo-zero@v1.4.1/rest/httpx/responses_test.go (about) 1 package httpx 2 3 import ( 4 "errors" 5 "net/http" 6 "strings" 7 "testing" 8 9 "github.com/lingyao2333/mo-zero/core/logx" 10 "github.com/stretchr/testify/assert" 11 "google.golang.org/grpc/codes" 12 "google.golang.org/grpc/status" 13 ) 14 15 type message struct { 16 Name string `json:"name"` 17 } 18 19 func init() { 20 logx.Disable() 21 } 22 23 func TestError(t *testing.T) { 24 const ( 25 body = "foo" 26 wrappedBody = `"foo"` 27 ) 28 29 tests := []struct { 30 name string 31 input string 32 errorHandler func(error) (int, interface{}) 33 expectHasBody bool 34 expectBody string 35 expectCode int 36 }{ 37 { 38 name: "default error handler", 39 input: body, 40 expectHasBody: true, 41 expectBody: body, 42 expectCode: http.StatusBadRequest, 43 }, 44 { 45 name: "customized error handler return string", 46 input: body, 47 errorHandler: func(err error) (int, interface{}) { 48 return http.StatusForbidden, err.Error() 49 }, 50 expectHasBody: true, 51 expectBody: wrappedBody, 52 expectCode: http.StatusForbidden, 53 }, 54 { 55 name: "customized error handler return error", 56 input: body, 57 errorHandler: func(err error) (int, interface{}) { 58 return http.StatusForbidden, err 59 }, 60 expectHasBody: true, 61 expectBody: body, 62 expectCode: http.StatusForbidden, 63 }, 64 { 65 name: "customized error handler return nil", 66 input: body, 67 errorHandler: func(err error) (int, interface{}) { 68 return http.StatusForbidden, nil 69 }, 70 expectHasBody: false, 71 expectBody: "", 72 expectCode: http.StatusForbidden, 73 }, 74 } 75 76 for _, test := range tests { 77 t.Run(test.name, func(t *testing.T) { 78 w := tracedResponseWriter{ 79 headers: make(map[string][]string), 80 } 81 if test.errorHandler != nil { 82 lock.RLock() 83 prev := errorHandler 84 lock.RUnlock() 85 SetErrorHandler(test.errorHandler) 86 defer func() { 87 lock.Lock() 88 errorHandler = prev 89 lock.Unlock() 90 }() 91 } 92 Error(&w, errors.New(test.input)) 93 assert.Equal(t, test.expectCode, w.code) 94 assert.Equal(t, test.expectHasBody, w.hasBody) 95 assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String())) 96 }) 97 } 98 } 99 100 func TestErrorWithGrpcError(t *testing.T) { 101 w := tracedResponseWriter{ 102 headers: make(map[string][]string), 103 } 104 Error(&w, status.Error(codes.Unavailable, "foo")) 105 assert.Equal(t, http.StatusServiceUnavailable, w.code) 106 assert.True(t, w.hasBody) 107 assert.True(t, strings.Contains(w.builder.String(), "foo")) 108 } 109 110 func TestErrorWithHandler(t *testing.T) { 111 w := tracedResponseWriter{ 112 headers: make(map[string][]string), 113 } 114 Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) { 115 http.Error(w, err.Error(), 499) 116 }) 117 assert.Equal(t, 499, w.code) 118 assert.True(t, w.hasBody) 119 assert.Equal(t, "foo", strings.TrimSpace(w.builder.String())) 120 } 121 122 func TestOk(t *testing.T) { 123 w := tracedResponseWriter{ 124 headers: make(map[string][]string), 125 } 126 Ok(&w) 127 assert.Equal(t, http.StatusOK, w.code) 128 } 129 130 func TestOkJson(t *testing.T) { 131 w := tracedResponseWriter{ 132 headers: make(map[string][]string), 133 } 134 msg := message{Name: "anyone"} 135 OkJson(&w, msg) 136 assert.Equal(t, http.StatusOK, w.code) 137 assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) 138 } 139 140 func TestWriteJsonTimeout(t *testing.T) { 141 // only log it and ignore 142 w := tracedResponseWriter{ 143 headers: make(map[string][]string), 144 err: http.ErrHandlerTimeout, 145 } 146 msg := message{Name: "anyone"} 147 WriteJson(&w, http.StatusOK, msg) 148 assert.Equal(t, http.StatusOK, w.code) 149 } 150 151 func TestWriteJsonError(t *testing.T) { 152 // only log it and ignore 153 w := tracedResponseWriter{ 154 headers: make(map[string][]string), 155 err: errors.New("foo"), 156 } 157 msg := message{Name: "anyone"} 158 WriteJson(&w, http.StatusOK, msg) 159 assert.Equal(t, http.StatusOK, w.code) 160 } 161 162 func TestWriteJsonLessWritten(t *testing.T) { 163 w := tracedResponseWriter{ 164 headers: make(map[string][]string), 165 lessWritten: true, 166 } 167 msg := message{Name: "anyone"} 168 WriteJson(&w, http.StatusOK, msg) 169 assert.Equal(t, http.StatusOK, w.code) 170 } 171 172 func TestWriteJsonMarshalFailed(t *testing.T) { 173 w := tracedResponseWriter{ 174 headers: make(map[string][]string), 175 } 176 WriteJson(&w, http.StatusOK, map[string]interface{}{ 177 "Data": complex(0, 0), 178 }) 179 assert.Equal(t, http.StatusInternalServerError, w.code) 180 } 181 182 type tracedResponseWriter struct { 183 headers map[string][]string 184 builder strings.Builder 185 hasBody bool 186 code int 187 lessWritten bool 188 wroteHeader bool 189 err error 190 } 191 192 func (w *tracedResponseWriter) Header() http.Header { 193 return w.headers 194 } 195 196 func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) { 197 if w.err != nil { 198 return 0, w.err 199 } 200 201 n, err = w.builder.Write(bytes) 202 if w.lessWritten { 203 n-- 204 } 205 w.hasBody = true 206 207 return 208 } 209 210 func (w *tracedResponseWriter) WriteHeader(code int) { 211 if w.wroteHeader { 212 return 213 } 214 w.wroteHeader = true 215 w.code = code 216 }