github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sign/sign.go (about)

     1  package sign
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"net/http"
     9  	"net/url"
    10  	"sort"
    11  
    12  	"github.com/artisanhe/tools/courier/status_error"
    13  	"github.com/artisanhe/tools/courier/transport_http/transform"
    14  )
    15  
    16  var Sign = "sign"
    17  var RandString = "randString"
    18  var AccessKey = "AccessKey"
    19  
    20  type SecretExchanger func(key string) (string, error)
    21  
    22  type SignParams struct {
    23  	AccessKeyParam
    24  	// 签名
    25  	Sign string `json:"sign" validate:"@string[1,32]" in:"query"`
    26  	// 随机字符串
    27  	RandString string `json:"randString" validate:"@string[1,32]" in:"query"`
    28  }
    29  
    30  func getSign(req *http.Request, query url.Values, secretExchanger SecretExchanger) (sign []byte, origin []byte, err error) {
    31  	accessKey := req.Header.Get(AccessKey)
    32  	randString := query.Get(RandString)
    33  
    34  	if accessKey != "" && randString != "" {
    35  		secret, errForExchange := secretExchanger(accessKey)
    36  		if errForExchange != nil {
    37  			err = status_error.InvalidSecret.StatusError().WithDesc(errForExchange.Error())
    38  			return
    39  		}
    40  		bodyBytes := make([]byte, 0)
    41  		if req.Body != nil {
    42  			bodyBytes, err = transform.CloneRequestBody(req)
    43  			if err != nil {
    44  				return
    45  			}
    46  		}
    47  		sign, origin = Secret(secret).Encode(query, bodyBytes)
    48  	}
    49  	return
    50  }
    51  
    52  type Secret string
    53  
    54  func (secret Secret) Encode(query url.Values, body []byte) (sign []byte, origin []byte) {
    55  	keyList := make([]string, 0)
    56  	for key := range query {
    57  		keyList = append(keyList, key)
    58  	}
    59  	sort.Strings(keyList)
    60  
    61  	rawSignStr := &bytes.Buffer{}
    62  	for _, key := range keyList {
    63  		values := query[key]
    64  		if len(values) == 0 || key == Sign {
    65  			continue
    66  		}
    67  		for _, v := range values {
    68  			rawSignStr.WriteString(key)
    69  			rawSignStr.WriteString("=")
    70  			rawSignStr.WriteString(v)
    71  			rawSignStr.WriteString("&")
    72  		}
    73  	}
    74  
    75  	if len(body) > 0 {
    76  		rawSignStr.WriteString("body")
    77  		rawSignStr.WriteString("=")
    78  		rawSignStr.Write(genMd5(body))
    79  		rawSignStr.WriteString("&")
    80  	}
    81  
    82  	rawSignStr.WriteString("secret")
    83  	rawSignStr.WriteString("=")
    84  	rawSignStr.WriteString(string(secret))
    85  
    86  	origin = rawSignStr.Bytes()
    87  	sign = genMd5(origin)
    88  	return
    89  }
    90  
    91  func genMd5(src []byte) (dst []byte) {
    92  	hasher := md5.New()
    93  	hasher.Write(src)
    94  	sum := hasher.Sum(nil)
    95  
    96  	dst = make([]byte, hex.EncodedLen(len(sum)))
    97  	hex.Encode(dst, sum)
    98  	return
    99  }
   100  
   101  func GenSign(secret string, queryData map[string]interface{}, body []byte) string {
   102  	keyList := []string{}
   103  	for key, _ := range queryData {
   104  		keyList = append(keyList, key)
   105  	}
   106  	sort.Strings(keyList)
   107  	var rawSignStr string
   108  	for _, key := range keyList {
   109  		rawSignStr += fmt.Sprintf("%v=%v%s", key, queryData[key], "&")
   110  	}
   111  	if len(body) > 0 {
   112  		rawSignStr += fmt.Sprintf("body=%s%s", string(genMd5(body)), "&")
   113  	}
   114  	rawSignStr += fmt.Sprintf("secret=%s", secret)
   115  
   116  	return string(genMd5([]byte(rawSignStr)))
   117  }