github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/security/contentsecurity.go (about)

     1  package security
     2  
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/lingyao2333/mo-zero/core/codec"
	"github.com/lingyao2333/mo-zero/core/iox"
	"github.com/lingyao2333/mo-zero/core/logx"
	"github.com/lingyao2333/mo-zero/rest/httpx"
)
    20  
    21  const (
    22  	requestUriHeader = "X-Request-Uri"
    23  	signatureField   = "signature"
    24  	timeField        = "time"
    25  )
    26  
    27  var (
    28  	// ErrInvalidContentType is an error that indicates invalid content type.
    29  	ErrInvalidContentType = errors.New("invalid content type")
    30  	// ErrInvalidHeader is an error that indicates invalid X-Content-Security header.
    31  	ErrInvalidHeader = errors.New("invalid X-Content-Security header")
    32  	// ErrInvalidKey is an error that indicates invalid key.
    33  	ErrInvalidKey = errors.New("invalid key")
    34  	// ErrInvalidPublicKey is an error that indicates invalid public key.
    35  	ErrInvalidPublicKey = errors.New("invalid public key")
    36  	// ErrInvalidSecret is an error that indicates invalid secret.
    37  	ErrInvalidSecret = errors.New("invalid secret")
    38  )
    39  
    40  // A ContentSecurityHeader is a content security header.
    41  type ContentSecurityHeader struct {
    42  	Key         []byte
    43  	Timestamp   string
    44  	ContentType int
    45  	Signature   string
    46  }
    47  
    48  // Encrypted checks if it's a crypted request.
    49  func (h *ContentSecurityHeader) Encrypted() bool {
    50  	return h.ContentType == httpx.CryptionType
    51  }
    52  
    53  // ParseContentSecurity parses content security settings in give r.
    54  func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
    55  	*ContentSecurityHeader, error) {
    56  	contentSecurity := r.Header.Get(httpx.ContentSecurity)
    57  	attrs := httpx.ParseHeader(contentSecurity)
    58  	fingerprint := attrs[httpx.KeyField]
    59  	secret := attrs[httpx.SecretField]
    60  	signature := attrs[signatureField]
    61  
    62  	if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
    63  		return nil, ErrInvalidHeader
    64  	}
    65  
    66  	decrypter, ok := decrypters[fingerprint]
    67  	if !ok {
    68  		return nil, ErrInvalidPublicKey
    69  	}
    70  
    71  	decryptedSecret, err := decrypter.DecryptBase64(secret)
    72  	if err != nil {
    73  		return nil, ErrInvalidSecret
    74  	}
    75  
    76  	attrs = httpx.ParseHeader(string(decryptedSecret))
    77  	base64Key := attrs[httpx.KeyField]
    78  	timestamp := attrs[timeField]
    79  	contentType := attrs[httpx.TypeField]
    80  
    81  	key, err := base64.StdEncoding.DecodeString(base64Key)
    82  	if err != nil {
    83  		return nil, ErrInvalidKey
    84  	}
    85  
    86  	cType, err := strconv.Atoi(contentType)
    87  	if err != nil {
    88  		return nil, ErrInvalidContentType
    89  	}
    90  
    91  	return &ContentSecurityHeader{
    92  		Key:         key,
    93  		Timestamp:   timestamp,
    94  		ContentType: cType,
    95  		Signature:   signature,
    96  	}, nil
    97  }
    98  
    99  // VerifySignature verifies the signature in given r.
   100  func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
   101  	seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
   102  	if err != nil {
   103  		return httpx.CodeSignatureInvalidHeader
   104  	}
   105  
   106  	now := time.Now().Unix()
   107  	toleranceSeconds := int64(tolerance.Seconds())
   108  	if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
   109  		return httpx.CodeSignatureWrongTime
   110  	}
   111  
   112  	reqPath, reqQuery := getPathQuery(r)
   113  	signContent := strings.Join([]string{
   114  		securityHeader.Timestamp,
   115  		r.Method,
   116  		reqPath,
   117  		reqQuery,
   118  		computeBodySignature(r),
   119  	}, "\n")
   120  	actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
   121  
   122  	if securityHeader.Signature == actualSignature {
   123  		return httpx.CodeSignaturePass
   124  	}
   125  
   126  	logx.Infof("signature different, expect: %s, actual: %s",
   127  		securityHeader.Signature, actualSignature)
   128  
   129  	return httpx.CodeSignatureInvalidToken
   130  }
   131  
   132  func computeBodySignature(r *http.Request) string {
   133  	var dup io.ReadCloser
   134  	r.Body, dup = iox.DupReadCloser(r.Body)
   135  	sha := sha256.New()
   136  	io.Copy(sha, r.Body)
   137  	r.Body = dup
   138  	return fmt.Sprintf("%x", sha.Sum(nil))
   139  }
   140  
   141  func getPathQuery(r *http.Request) (string, string) {
   142  	requestUri := r.Header.Get(requestUriHeader)
   143  	if len(requestUri) == 0 {
   144  		return r.URL.Path, r.URL.RawQuery
   145  	}
   146  
   147  	uri, err := url.Parse(requestUri)
   148  	if err != nil {
   149  		return r.URL.Path, r.URL.RawQuery
   150  	}
   151  
   152  	return uri.Path, uri.RawQuery
   153  }