github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/contentsecurityhandler_test.go (about) 1 package handler 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/base64" 7 "fmt" 8 "io" 9 "log" 10 "net/http" 11 "net/http/httptest" 12 "net/url" 13 "os" 14 "strconv" 15 "strings" 16 "testing" 17 "time" 18 19 "github.com/lingyao2333/mo-zero/core/codec" 20 "github.com/lingyao2333/mo-zero/rest/httpx" 21 "github.com/stretchr/testify/assert" 22 ) 23 24 const timeDiff = time.Hour * 2 * 24 25 26 var ( 27 fingerprint = "12345" 28 pubKey = []byte(`-----BEGIN PUBLIC KEY----- 29 MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE 30 eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH 31 miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR 32 my47YlhspwszKdRP+wIDAQAB 33 -----END PUBLIC KEY-----`) 34 priKey = []byte(`-----BEGIN RSA PRIVATE KEY----- 35 MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i 36 1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/ 37 r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB 38 AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH 39 Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY 40 J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0 41 Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP 42 cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO 43 ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR 44 3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV 45 MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l 46 Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc 47 moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ= 48 -----END RSA PRIVATE KEY-----`) 49 key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") 50 ) 51 52 type requestSettings struct { 53 method string 54 url string 55 body io.Reader 56 strict bool 57 crypt bool 58 requestUri string 59 timestamp int64 60 fingerprint string 61 missHeader bool 62 signature string 63 } 64 65 func init() { 66 log.SetOutput(io.Discard) 67 } 68 69 func TestContentSecurityHandler(t *testing.T) { 70 tests := []struct { 71 method string 72 url string 73 body string 74 strict bool 75 crypt bool 76 requestUri string 77 timestamp int64 78 fingerprint string 79 missHeader bool 80 signature string 81 statusCode int 82 }{ 83 { 84 method: http.MethodGet, 85 url: "http://localhost/a/b?c=d&e=f", 86 strict: true, 87 crypt: false, 88 }, 89 { 90 method: http.MethodPost, 91 url: "http://localhost/a/b?c=d&e=f", 92 body: "hello", 93 strict: true, 94 crypt: false, 95 }, 96 { 97 method: http.MethodGet, 98 url: "http://localhost/a/b?c=d&e=f", 99 strict: true, 100 crypt: true, 101 }, 102 { 103 method: http.MethodPost, 104 url: "http://localhost/a/b?c=d&e=f", 105 body: "hello", 106 strict: true, 107 crypt: true, 108 }, 109 { 110 method: http.MethodGet, 111 url: "http://localhost/a/b?c=d&e=f", 112 strict: true, 113 crypt: true, 114 timestamp: time.Now().Add(timeDiff).Unix(), 115 statusCode: http.StatusForbidden, 116 }, 117 { 118 method: http.MethodPost, 119 url: "http://localhost/a/b?c=d&e=f", 120 body: "hello", 121 strict: true, 122 crypt: true, 123 timestamp: time.Now().Add(-timeDiff).Unix(), 124 statusCode: http.StatusForbidden, 125 }, 126 { 127 method: http.MethodPost, 128 url: "http://remotehost/", 129 body: "hello", 130 strict: true, 131 crypt: true, 132 requestUri: "http://localhost/a/b?c=d&e=f", 133 }, 134 { 135 method: http.MethodPost, 136 url: "http://localhost/a/b?c=d&e=f", 137 body: "hello", 138 strict: false, 139 crypt: true, 140 fingerprint: "badone", 141 }, 142 { 143 method: http.MethodPost, 144 url: "http://localhost/a/b?c=d&e=f", 145 body: "hello", 146 strict: true, 147 crypt: true, 148 timestamp: time.Now().Add(-timeDiff).Unix(), 149 fingerprint: "badone", 150 statusCode: http.StatusForbidden, 151 }, 152 { 153 method: http.MethodPost, 154 url: "http://localhost/a/b?c=d&e=f", 155 body: "hello", 156 strict: true, 157 crypt: true, 158 missHeader: true, 159 statusCode: http.StatusForbidden, 160 }, 161 { 162 method: http.MethodHead, 163 url: "http://localhost/a/b?c=d&e=f", 164 strict: true, 165 crypt: false, 166 }, 167 { 168 method: http.MethodGet, 169 url: "http://localhost/a/b?c=d&e=f", 170 strict: true, 171 crypt: false, 172 signature: "badone", 173 statusCode: http.StatusForbidden, 174 }, 175 } 176 177 for _, test := range tests { 178 t.Run(test.url, func(t *testing.T) { 179 if test.statusCode == 0 { 180 test.statusCode = http.StatusOK 181 } 182 if len(test.fingerprint) == 0 { 183 test.fingerprint = fingerprint 184 } 185 if test.timestamp == 0 { 186 test.timestamp = time.Now().Unix() 187 } 188 189 func() { 190 keyFile, err := createTempFile(priKey) 191 defer os.Remove(keyFile) 192 193 assert.Nil(t, err) 194 decrypter, err := codec.NewRsaDecrypter(keyFile) 195 assert.Nil(t, err) 196 contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{ 197 fingerprint: decrypter, 198 }, time.Hour, test.strict) 199 handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 200 })) 201 202 var reader io.Reader 203 if len(test.body) > 0 { 204 reader = strings.NewReader(test.body) 205 } 206 setting := requestSettings{ 207 method: test.method, 208 url: test.url, 209 body: reader, 210 strict: test.strict, 211 crypt: test.crypt, 212 requestUri: test.requestUri, 213 timestamp: test.timestamp, 214 fingerprint: test.fingerprint, 215 missHeader: test.missHeader, 216 signature: test.signature, 217 } 218 req, err := buildRequest(setting) 219 assert.Nil(t, err) 220 resp := httptest.NewRecorder() 221 handler.ServeHTTP(resp, req) 222 assert.Equal(t, test.statusCode, resp.Code) 223 }() 224 }) 225 } 226 } 227 228 func TestContentSecurityHandler_UnsignedCallback(t *testing.T) { 229 keyFile, err := createTempFile(priKey) 230 defer os.Remove(keyFile) 231 232 assert.Nil(t, err) 233 decrypter, err := codec.NewRsaDecrypter(keyFile) 234 assert.Nil(t, err) 235 contentSecurityHandler := ContentSecurityHandler( 236 map[string]codec.RsaDecrypter{ 237 fingerprint: decrypter, 238 }, 239 time.Hour, 240 true, 241 func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { 242 w.WriteHeader(http.StatusOK) 243 }) 244 handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 245 246 setting := requestSettings{ 247 method: http.MethodGet, 248 url: "http://localhost/a/b?c=d&e=f", 249 signature: "badone", 250 } 251 req, err := buildRequest(setting) 252 assert.Nil(t, err) 253 resp := httptest.NewRecorder() 254 handler.ServeHTTP(resp, req) 255 assert.Equal(t, http.StatusOK, resp.Code) 256 } 257 258 func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) { 259 keyFile, err := createTempFile(priKey) 260 defer os.Remove(keyFile) 261 262 assert.Nil(t, err) 263 decrypter, err := codec.NewRsaDecrypter(keyFile) 264 assert.Nil(t, err) 265 contentSecurityHandler := ContentSecurityHandler( 266 map[string]codec.RsaDecrypter{ 267 fingerprint: decrypter, 268 }, 269 time.Hour, 270 true, 271 func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { 272 assert.Equal(t, httpx.CodeSignatureWrongTime, code) 273 w.WriteHeader(http.StatusOK) 274 }) 275 handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 276 277 reader := strings.NewReader("hello") 278 setting := requestSettings{ 279 method: http.MethodPost, 280 url: "http://localhost/a/b?c=d&e=f", 281 body: reader, 282 strict: true, 283 crypt: true, 284 timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(), 285 fingerprint: fingerprint, 286 } 287 req, err := buildRequest(setting) 288 assert.Nil(t, err) 289 resp := httptest.NewRecorder() 290 handler.ServeHTTP(resp, req) 291 assert.Equal(t, http.StatusOK, resp.Code) 292 } 293 294 func buildRequest(rs requestSettings) (*http.Request, error) { 295 var bodyStr string 296 var err error 297 298 if rs.crypt && rs.body != nil { 299 var buf bytes.Buffer 300 io.Copy(&buf, rs.body) 301 bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes()) 302 if err != nil { 303 return nil, err 304 } 305 bodyStr = base64.StdEncoding.EncodeToString(bodyBytes) 306 } 307 308 r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr)) 309 if len(rs.signature) == 0 { 310 sha := sha256.New() 311 sha.Write([]byte(bodyStr)) 312 bodySign := fmt.Sprintf("%x", sha.Sum(nil)) 313 var path string 314 var query string 315 if len(rs.requestUri) > 0 { 316 u, err := url.Parse(rs.requestUri) 317 if err != nil { 318 return nil, err 319 } 320 321 path = u.Path 322 query = u.RawQuery 323 } else { 324 path = r.URL.Path 325 query = r.URL.RawQuery 326 } 327 contentOfSign := strings.Join([]string{ 328 strconv.FormatInt(rs.timestamp, 10), 329 rs.method, 330 path, 331 query, 332 bodySign, 333 }, "\n") 334 rs.signature = codec.HmacBase64([]byte(key), contentOfSign) 335 } 336 337 var mode string 338 if rs.crypt { 339 mode = "1" 340 } else { 341 mode = "0" 342 } 343 content := strings.Join([]string{ 344 "version=v1", 345 "type=" + mode, 346 fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)), 347 "time=" + strconv.FormatInt(rs.timestamp, 10), 348 }, "; ") 349 350 encrypter, err := codec.NewRsaEncrypter([]byte(pubKey)) 351 if err != nil { 352 log.Fatal(err) 353 } 354 355 output, err := encrypter.Encrypt([]byte(content)) 356 if err != nil { 357 log.Fatal(err) 358 } 359 360 encryptedContent := base64.StdEncoding.EncodeToString(output) 361 if !rs.missHeader { 362 r.Header.Set(httpx.ContentSecurity, strings.Join([]string{ 363 fmt.Sprintf("key=%s", rs.fingerprint), 364 "secret=" + encryptedContent, 365 "signature=" + rs.signature, 366 }, "; ")) 367 } 368 if len(rs.requestUri) > 0 { 369 r.Header.Set("X-Request-Uri", rs.requestUri) 370 } 371 372 return r, nil 373 } 374 375 func createTempFile(body []byte) (string, error) { 376 tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp") 377 if err != nil { 378 return "", err 379 } 380 381 tmpFile.Close() 382 err = os.WriteFile(tmpFile.Name(), body, os.ModePerm) 383 if err != nil { 384 return "", err 385 } 386 387 return tmpFile.Name(), nil 388 }