github.com/OrigamiWang/msd/micro@v0.0.0-20240229032328-b62246268db9/util/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/sha256"
     6  	"encoding/json"
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/OrigamiWang/msd/micro/util/base64"
    12  	logutil "github.com/OrigamiWang/msd/micro/util/log"
    13  )
    14  
    15  const SECRET = "ecef6883c93214e253352b8ac36ea93cf5ac4c34f8bb1f3217bc8376145661fb"
    16  
    17  type JwtPayload struct {
    18  	Uid   int       `json:"uid"`   // id
    19  	Uname string    `json:"uname"` // username
    20  	Exp   time.Time `json:"exp"`   // expire time
    21  }
    22  
    23  func encodeHeader() string {
    24  	header := map[string]string{}
    25  	header["alg"] = "HS256"
    26  	header["typ"] = "JWT"
    27  	headByte, err := json.Marshal(header)
    28  	if err != nil {
    29  		logutil.Error("jwt header marshal failed, err: %v", err)
    30  		return ""
    31  	}
    32  	headerBase64 := base64.EncodeBase64(headByte)
    33  	return headerBase64
    34  }
    35  
    36  func encodePayload(payload *JwtPayload) string {
    37  	payloadByte, err := json.Marshal(payload)
    38  	if err != nil {
    39  		logutil.Error("jwt payload marshal failed, err: %v", err)
    40  		return ""
    41  	}
    42  	payloadBase64 := base64.EncodeBase64(payloadByte)
    43  	return strings.TrimRight(payloadBase64, "=")
    44  }
    45  
    46  func encodeSignature(data string) string {
    47  	hmacHasher := hmac.New(sha256.New, []byte(SECRET))
    48  	hmacHasher.Write([]byte(data))
    49  	hmacHashed := hmacHasher.Sum(nil)
    50  	signature := base64.EncodeBase64(hmacHashed)
    51  	signature = strings.TrimRight(signature, "=") // 移除 Base64 编码的尾部填充字符
    52  	return signature
    53  }
    54  
    55  func EncodeJwt(jwtPayload *JwtPayload) string {
    56  	header := encodeHeader()
    57  	payload := encodePayload(jwtPayload)
    58  	if header == "" || payload == "" {
    59  		logutil.Error("jwt header or payload is empty")
    60  		return ""
    61  	}
    62  	data := header + "." + payload
    63  	signature := encodeSignature(data)
    64  	jwt := data + "." + signature
    65  	return jwt
    66  }
    67  func DecodeJwt(jwt string) (*JwtPayload, error) {
    68  	if jwt == "" {
    69  		logutil.Error("jwt is empty")
    70  		return nil, fmt.Errorf("jwt is empty")
    71  	}
    72  	jwt = strings.TrimSpace(jwt)
    73  	arr := strings.Split(jwt, ".")
    74  	if len(arr) != 3 {
    75  		logutil.Error("jwt is not valid")
    76  		return nil, fmt.Errorf("jwt is not valid")
    77  	}
    78  	// header
    79  	headerBase64 := arr[0]
    80  	err := decodeHeader(headerBase64)
    81  	if err != nil {
    82  		logutil.Error("decode jwt header base64 failed, err: %v", err)
    83  		return nil, fmt.Errorf("decode jwt header base64 failed, err: %v", err)
    84  	}
    85  	// payload
    86  	payloadBase64 := arr[1]
    87  	jwtPayload := &JwtPayload{}
    88  	err = decodePayload(payloadBase64, jwtPayload)
    89  	if err != nil {
    90  		logutil.Error("decode jwt payload failed, err: %v", err)
    91  		return nil, fmt.Errorf("decode jwt payload failed, err: %v", err)
    92  	}
    93  	// signature
    94  	signature := arr[2]
    95  	data := headerBase64 + "." + payloadBase64
    96  	if checkSignature(data, signature) {
    97  		logutil.Info("signature is valid")
    98  	} else {
    99  		logutil.Info("signature is not valid")
   100  		return nil, fmt.Errorf("signature is not valid")
   101  	}
   102  	return jwtPayload, nil
   103  }
   104  
   105  func decodeHeader(headerBase64 string) error {
   106  	headerBase, err := base64.DecodeBase64(headerBase64)
   107  	if err != nil {
   108  		logutil.Error("decode jwt header base64 failed, err: %v", err)
   109  		return err
   110  	}
   111  	header := map[string]string{}
   112  	err = json.Unmarshal(headerBase, &header)
   113  	if err != nil {
   114  		logutil.Error("jwt header json unmarshal failed, err: %v", err)
   115  		return err
   116  	}
   117  	return nil
   118  }
   119  func decodePayload(payloadBase64 string, jwtPayload *JwtPayload) error {
   120  	payloadBase, err := base64.DecodeBase64(payloadBase64)
   121  	if err != nil {
   122  		logutil.Error("decode jwt payload base64 failed, err: %v", err)
   123  		return err
   124  	}
   125  	err = json.Unmarshal(payloadBase, jwtPayload)
   126  	if err != nil {
   127  		logutil.Error("jwt payload json unmarshal failed, err: %v", err)
   128  		return err
   129  	}
   130  	return nil
   131  }
   132  
   133  func checkSignature(data, rawSignature string) bool {
   134  	return rawSignature == encodeSignature(data)
   135  }