github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/security/contentsecurity.go (about) 1 package security 2 3 import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/lingyao2333/mo-zero/core/codec" 16 "github.com/lingyao2333/mo-zero/core/iox" 17 "github.com/lingyao2333/mo-zero/core/logx" 18 "github.com/lingyao2333/mo-zero/rest/httpx" 19 ) 20 21 const ( 22 requestUriHeader = "X-Request-Uri" 23 signatureField = "signature" 24 timeField = "time" 25 ) 26 27 var ( 28 // ErrInvalidContentType is an error that indicates invalid content type. 29 ErrInvalidContentType = errors.New("invalid content type") 30 // ErrInvalidHeader is an error that indicates invalid X-Content-Security header. 31 ErrInvalidHeader = errors.New("invalid X-Content-Security header") 32 // ErrInvalidKey is an error that indicates invalid key. 33 ErrInvalidKey = errors.New("invalid key") 34 // ErrInvalidPublicKey is an error that indicates invalid public key. 35 ErrInvalidPublicKey = errors.New("invalid public key") 36 // ErrInvalidSecret is an error that indicates invalid secret. 37 ErrInvalidSecret = errors.New("invalid secret") 38 ) 39 40 // A ContentSecurityHeader is a content security header. 41 type ContentSecurityHeader struct { 42 Key []byte 43 Timestamp string 44 ContentType int 45 Signature string 46 } 47 48 // Encrypted checks if it's a crypted request. 49 func (h *ContentSecurityHeader) Encrypted() bool { 50 return h.ContentType == httpx.CryptionType 51 } 52 53 // ParseContentSecurity parses content security settings in give r. 54 func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) ( 55 *ContentSecurityHeader, error) { 56 contentSecurity := r.Header.Get(httpx.ContentSecurity) 57 attrs := httpx.ParseHeader(contentSecurity) 58 fingerprint := attrs[httpx.KeyField] 59 secret := attrs[httpx.SecretField] 60 signature := attrs[signatureField] 61 62 if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 { 63 return nil, ErrInvalidHeader 64 } 65 66 decrypter, ok := decrypters[fingerprint] 67 if !ok { 68 return nil, ErrInvalidPublicKey 69 } 70 71 decryptedSecret, err := decrypter.DecryptBase64(secret) 72 if err != nil { 73 return nil, ErrInvalidSecret 74 } 75 76 attrs = httpx.ParseHeader(string(decryptedSecret)) 77 base64Key := attrs[httpx.KeyField] 78 timestamp := attrs[timeField] 79 contentType := attrs[httpx.TypeField] 80 81 key, err := base64.StdEncoding.DecodeString(base64Key) 82 if err != nil { 83 return nil, ErrInvalidKey 84 } 85 86 cType, err := strconv.Atoi(contentType) 87 if err != nil { 88 return nil, ErrInvalidContentType 89 } 90 91 return &ContentSecurityHeader{ 92 Key: key, 93 Timestamp: timestamp, 94 ContentType: cType, 95 Signature: signature, 96 }, nil 97 } 98 99 // VerifySignature verifies the signature in given r. 100 func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int { 101 seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64) 102 if err != nil { 103 return httpx.CodeSignatureInvalidHeader 104 } 105 106 now := time.Now().Unix() 107 toleranceSeconds := int64(tolerance.Seconds()) 108 if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds { 109 return httpx.CodeSignatureWrongTime 110 } 111 112 reqPath, reqQuery := getPathQuery(r) 113 signContent := strings.Join([]string{ 114 securityHeader.Timestamp, 115 r.Method, 116 reqPath, 117 reqQuery, 118 computeBodySignature(r), 119 }, "\n") 120 actualSignature := codec.HmacBase64(securityHeader.Key, signContent) 121 122 if securityHeader.Signature == actualSignature { 123 return httpx.CodeSignaturePass 124 } 125 126 logx.Infof("signature different, expect: %s, actual: %s", 127 securityHeader.Signature, actualSignature) 128 129 return httpx.CodeSignatureInvalidToken 130 } 131 132 func computeBodySignature(r *http.Request) string { 133 var dup io.ReadCloser 134 r.Body, dup = iox.DupReadCloser(r.Body) 135 sha := sha256.New() 136 io.Copy(sha, r.Body) 137 r.Body = dup 138 return fmt.Sprintf("%x", sha.Sum(nil)) 139 } 140 141 func getPathQuery(r *http.Request) (string, string) { 142 requestUri := r.Header.Get(requestUriHeader) 143 if len(requestUri) == 0 { 144 return r.URL.Path, r.URL.RawQuery 145 } 146 147 uri, err := url.Parse(requestUri) 148 if err != nil { 149 return r.URL.Path, r.URL.RawQuery 150 } 151 152 return uri.Path, uri.RawQuery 153 }