github.com/angenalZZZ/gofunc@v0.0.0-20210507121333-48ff1be3917b/f/jwt.go (about)

     1  package f
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/sha256"
     6  	"encoding/base64"
     7  	"errors"
     8  	"fmt"
     9  	json "github.com/json-iterator/go"
    10  	"strings"
    11  )
    12  
    13  var (
    14  	JwtDefaultKey    = []byte("HGJ766GR")
    15  	jwtDefaultHeader = jwt1header{Typ: "JWT", Alg: "HS256"}
    16  	jwtDefaultClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti"}
    17  	jwtEncodeString  = base64.RawURLEncoding.EncodeToString
    18  	jwtDecodeString  = base64.RawURLEncoding.DecodeString
    19  	jwtEncodeJson    = json.Marshal
    20  	jwtDecodeJson    = json.Unmarshal
    21  )
    22  
    23  // NewJwtToken returns a token (string) and error.
    24  // The token is a fully qualified JWT to be sent to a client via HTTP Header or other method.
    25  // Error returned will be from the jwtNewEncoded unexported function.
    26  func NewJwtToken(claims map[string]interface{}) (string, error) {
    27  	enc, err := jwt1NewEncoded(claims)
    28  	if err != nil {
    29  		return "", err
    30  	}
    31  	return enc.token, nil
    32  }
    33  
    34  // IsJwtToken returns a bool indicating whether a token (string) provided has been signed by our server.
    35  // If true, the client is authenticated and may proceed.
    36  func IsJwtToken(token string) (map[string]interface{}, bool) {
    37  	var (
    38  		err error
    39  		dec jwt1decoded
    40  	)
    41  
    42  	// decode the token
    43  	if dec, err = jwt1NewDecoded(token); err != nil {
    44  		// may want to log some error here so we have visibility
    45  		// intentionally simplifying return type to bool for ease
    46  		// of use in API. Caller should only do `if auth.Passes(str) {}`
    47  		return nil, false
    48  	}
    49  
    50  	// base64 decode payload
    51  	var payload []byte
    52  	if payload, err = jwtDecodeString(dec.payload); err != nil {
    53  		return nil, false
    54  	}
    55  	dst := map[string]interface{}{}
    56  	if err = jwtDecodeJson(payload, &dst); err != nil {
    57  		return nil, false
    58  	}
    59  	if signed, err := dec.sign(); err != nil || signed.token() != token {
    60  		return nil, false
    61  	}
    62  	return dst, true
    63  }
    64  
    65  func jwt1NewEncoded(claims map[string]interface{}) (jwt1encoded, error) {
    66  	jwt1header, err := jwtEncodeJson(jwtDefaultHeader)
    67  	if err != nil {
    68  		return jwt1encoded{}, err
    69  	}
    70  
    71  	for _, claim := range jwtDefaultClaims {
    72  		if _, ok := claims[claim]; !ok {
    73  			claims[claim] = nil
    74  		}
    75  	}
    76  
    77  	payload, err := jwtEncodeJson(claims)
    78  	if err != nil {
    79  		return jwt1encoded{}, err
    80  	}
    81  
    82  	d := jwt1decoded{jwt1header: string(jwt1header), payload: string(payload)}
    83  	d.jwt1header = jwtEncodeString([]byte(d.jwt1header))
    84  	d.payload = jwtEncodeString([]byte(d.payload))
    85  	signed, err := d.sign()
    86  	if err != nil {
    87  		return jwt1encoded{}, err
    88  	}
    89  	return jwt1encoded{token: signed.token()}, nil
    90  }
    91  
    92  func jwt1NewDecoded(token string) (jwt1decoded, error) {
    93  	e := jwt1encoded{token: token}
    94  	d, err := e.parseToken()
    95  	if err != nil {
    96  		return d, nil
    97  	}
    98  	return d, nil
    99  }
   100  
   101  type jwt1header struct {
   102  	Typ string `json:"typ"`
   103  	Alg string `json:"alg"`
   104  }
   105  
   106  type jwt1encoded struct {
   107  	token string
   108  }
   109  
   110  type jwt1decoded struct {
   111  	jwt1header string
   112  	payload    string
   113  }
   114  
   115  type jwt1signedDecoded struct {
   116  	jwt1decoded
   117  	signature string
   118  }
   119  
   120  func (s jwt1signedDecoded) token() string {
   121  	return fmt.Sprintf("%s.%s.%s", s.jwt1header, s.payload, s.signature)
   122  }
   123  
   124  func (d *jwt1decoded) sign() (jwt1signedDecoded, error) {
   125  	if d.jwt1header == "" || d.payload == "" {
   126  		return jwt1signedDecoded{}, errors.New("missing jwt1header or payload on Decoded")
   127  	}
   128  
   129  	hashed := hmac.New(sha256.New, JwtDefaultKey)
   130  	unsigned := strings.Join([]string{d.jwt1header, d.payload}, ".")
   131  	_, err := hashed.Write([]byte(unsigned))
   132  	if err != nil {
   133  		return jwt1signedDecoded{}, err
   134  	}
   135  
   136  	signed := jwt1signedDecoded{jwt1decoded: *d}
   137  	signed.signature = jwtEncodeString(hashed.Sum(nil))
   138  
   139  	return signed, nil
   140  }
   141  
   142  func (e jwt1encoded) parseToken() (jwt1decoded, error) {
   143  	parts := strings.Split(e.token, ".")
   144  	if len(parts) != 3 {
   145  		return jwt1decoded{}, errors.New("error: incorrect # of results from string parsing")
   146  	}
   147  
   148  	d := jwt1decoded{
   149  		jwt1header: parts[0],
   150  		payload:    parts[1],
   151  	}
   152  	return d, nil
   153  }