github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/contentsecurityhandler.go (about) 1 package handler 2 3 import ( 4 "net/http" 5 "time" 6 7 "github.com/lingyao2333/mo-zero/core/codec" 8 "github.com/lingyao2333/mo-zero/core/logx" 9 "github.com/lingyao2333/mo-zero/rest/httpx" 10 "github.com/lingyao2333/mo-zero/rest/internal/security" 11 ) 12 13 const contentSecurity = "X-Content-Security" 14 15 // UnsignedCallback defines the method of the unsigned callback. 16 type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) 17 18 // ContentSecurityHandler returns a middleware to verify content security. 19 func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, 20 strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler { 21 if len(callbacks) == 0 { 22 callbacks = append(callbacks, handleVerificationFailure) 23 } 24 25 return func(next http.Handler) http.Handler { 26 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 switch r.Method { 28 case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut: 29 header, err := security.ParseContentSecurity(decrypters, r) 30 if err != nil { 31 logx.Errorf("Signature parse failed, X-Content-Security: %s, error: %s", 32 r.Header.Get(contentSecurity), err.Error()) 33 executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks) 34 } else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass { 35 logx.Errorf("Signature verification failed, X-Content-Security: %s", 36 r.Header.Get(contentSecurity)) 37 executeCallbacks(w, r, next, strict, code, callbacks) 38 } else if r.ContentLength > 0 && header.Encrypted() { 39 CryptionHandler(header.Key)(next).ServeHTTP(w, r) 40 } else { 41 next.ServeHTTP(w, r) 42 } 43 default: 44 next.ServeHTTP(w, r) 45 } 46 }) 47 } 48 } 49 50 func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, 51 code int, callbacks []UnsignedCallback) { 52 for _, callback := range callbacks { 53 callback(w, r, next, strict, code) 54 } 55 } 56 57 func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { 58 if strict { 59 w.WriteHeader(http.StatusForbidden) 60 } else { 61 next.ServeHTTP(w, r) 62 } 63 }