git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"crypto/sha512"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"strings"
    12  	"time"
    13  )
    14  
    15  type Algorithm string
    16  type Type string
    17  
    18  const (
    19  	AlgorithmHS256 Algorithm = "HS256"
    20  	AlgorithmHS512 Algorithm = "HAS512"
    21  )
    22  
    23  const (
    24  	TypeJWT Type = "JWT"
    25  )
    26  
    27  var (
    28  	ErrTokenIsNotValid     = errors.New("The token is not valid")
    29  	ErrSignatureIsNotValid = errors.New("Signature is not valid")
    30  	ErrTokenHasExpired     = errors.New("The token has expired")
    31  	ErrAlgorithmIsNotValid = fmt.Errorf("Algorithm is not valid. Valid algorithms values are: [%s, %s]", AlgorithmHS256, AlgorithmHS512)
    32  )
    33  
    34  type Provider struct {
    35  	signingSecretKey []byte
    36  	algorithm        Algorithm
    37  	verifyingKeys    [][]byte
    38  }
    39  
    40  type header struct {
    41  	Algorithm Algorithm `json:"alg"`
    42  	Type      Type      `json:"typ"`
    43  }
    44  
    45  // registered claim names from https://www.rfc-editor.org/rfc/rfc7519#section-4.1
    46  type reservedClaims struct {
    47  	ExpirationTime int64 `json:"exp,omitempty"`
    48  	NotBefore      int64 `json:"nbf,omitempty"`
    49  }
    50  
    51  type NewProviderOptions struct {
    52  	VerifyingKeys [][]byte
    53  }
    54  
    55  func NewProvider(signingSecretKey []byte, algorithm Algorithm, options *NewProviderOptions) (provider *Provider, err error) {
    56  	if len(signingSecretKey) < 32 {
    57  		err = errors.New("jwt: secretKey is too short. Min length: 32 bytes")
    58  		return
    59  	}
    60  
    61  	if algorithm != AlgorithmHS256 && algorithm != AlgorithmHS512 {
    62  		err = ErrAlgorithmIsNotValid
    63  		return
    64  	}
    65  
    66  	defaultOptions := defaultNewProviderOptions()
    67  	if options == nil {
    68  		options = defaultOptions
    69  	} else {
    70  		if options.VerifyingKeys == nil {
    71  			options.VerifyingKeys = defaultOptions.VerifyingKeys
    72  		}
    73  	}
    74  
    75  	provider = &Provider{
    76  		signingSecretKey: signingSecretKey,
    77  		algorithm:        algorithm,
    78  		verifyingKeys:    options.VerifyingKeys,
    79  	}
    80  	return
    81  }
    82  
    83  func defaultNewProviderOptions() *NewProviderOptions {
    84  	return &NewProviderOptions{
    85  		VerifyingKeys: [][]byte{},
    86  	}
    87  }
    88  
    89  type TokenOptions struct {
    90  	ExpirationTime *time.Time
    91  	NotBefore      *time.Time
    92  }
    93  
    94  func (provider *Provider) IssueToken(data any, options *TokenOptions) (token string, err error) {
    95  	tokenBuffer := bytes.NewBuffer(make([]byte, 0, 100))
    96  
    97  	header := header{Algorithm: provider.algorithm, Type: TypeJWT}
    98  	headerJson, err := json.Marshal(header)
    99  	if err != nil {
   100  		err = fmt.Errorf("jwt: encoding the header to JSON: %w", err)
   101  		return
   102  	}
   103  	encodedHeader := base64.RawURLEncoding.EncodeToString(headerJson)
   104  	tokenBuffer.WriteString(encodedHeader)
   105  	tokenBuffer.WriteString(".")
   106  
   107  	var claimsJson []byte
   108  	if options != nil && (options.ExpirationTime != nil || options.NotBefore != nil) {
   109  		var dataJson []byte
   110  		var reservedClaims = reservedClaims{}
   111  
   112  		if options.ExpirationTime != nil {
   113  			reservedClaims.ExpirationTime = options.ExpirationTime.Unix()
   114  			if reservedClaims.ExpirationTime < 1 {
   115  				err = fmt.Errorf("jwt: ExpirationTime should not be < 1")
   116  				return
   117  			}
   118  		}
   119  		if options.NotBefore != nil {
   120  			reservedClaims.NotBefore = options.NotBefore.Unix()
   121  			if reservedClaims.NotBefore < 1 {
   122  				err = fmt.Errorf("jwt: NotBefore should not be < 1")
   123  				return
   124  			}
   125  		}
   126  
   127  		claimsJson, err = json.Marshal(reservedClaims)
   128  		if err != nil {
   129  			err = fmt.Errorf("jwt: encoding claims to JSON: %w", err)
   130  			return
   131  		}
   132  		dataJson, err = json.Marshal(data)
   133  		if err != nil {
   134  			err = fmt.Errorf("jwt: encoding claims to JSON: %w", err)
   135  			return
   136  		}
   137  		if string(dataJson) != "{}" {
   138  			dataJson[0] = ','
   139  			claimsJson = append(claimsJson[:len(claimsJson)-1], dataJson...)
   140  		}
   141  	} else {
   142  		claimsJson, err = json.Marshal(data)
   143  		if err != nil {
   144  			err = fmt.Errorf("jwt: encoding claims to JSON: %w", err)
   145  			return
   146  		}
   147  		if err != nil {
   148  			err = fmt.Errorf("jwt: encoding claims to JSON: %w", err)
   149  			return
   150  		}
   151  	}
   152  
   153  	encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJson)
   154  	tokenBuffer.WriteString(encodedClaims)
   155  
   156  	var rawSignature []byte
   157  	switch provider.algorithm {
   158  	case AlgorithmHS256:
   159  		rawSignature = signTokenHMAC(sha256.New, provider.signingSecretKey, tokenBuffer.Bytes())
   160  	case AlgorithmHS512:
   161  		rawSignature = signTokenHMAC(sha512.New, provider.signingSecretKey, tokenBuffer.Bytes())
   162  	default:
   163  		err = ErrAlgorithmIsNotValid
   164  		return
   165  	}
   166  	encodedSignature := base64.RawURLEncoding.EncodeToString(rawSignature)
   167  	tokenBuffer.WriteString(".")
   168  	tokenBuffer.WriteString(encodedSignature)
   169  
   170  	token = tokenBuffer.String()
   171  
   172  	return
   173  }
   174  
   175  func (provider *Provider) VerifyToken(token string, data any) (err error) {
   176  	if strings.Count(token, ".") != 2 {
   177  		err = ErrTokenIsNotValid
   178  		return
   179  	}
   180  
   181  	// Signature
   182  	signatureStart := strings.LastIndexByte(token, '.')
   183  	encodedSignature := token[signatureStart+1:]
   184  	signature, err := base64.RawURLEncoding.DecodeString(encodedSignature)
   185  	if err != nil {
   186  		err = ErrTokenIsNotValid
   187  		return
   188  	}
   189  
   190  	encodedHeaderAndClaims := token[:signatureStart]
   191  
   192  	switch provider.algorithm {
   193  	case AlgorithmHS256:
   194  		err = verifyTokenHMAC(sha256.New, provider.signingSecretKey, signature, []byte(encodedHeaderAndClaims))
   195  	case AlgorithmHS512:
   196  		err = verifyTokenHMAC(sha512.New, provider.signingSecretKey, signature, []byte(encodedHeaderAndClaims))
   197  	default:
   198  		err = ErrTokenIsNotValid
   199  	}
   200  	if err != nil {
   201  		return
   202  	}
   203  
   204  	// Header
   205  	var header header
   206  	headerEnd := strings.IndexByte(token, '.')
   207  	encodedHeader := token[:headerEnd]
   208  	headerJson, err := base64.RawURLEncoding.DecodeString(encodedHeader)
   209  	if err != nil {
   210  		err = ErrTokenIsNotValid
   211  		return
   212  	}
   213  	err = json.Unmarshal(headerJson, &header)
   214  	if err != nil {
   215  		err = ErrTokenIsNotValid
   216  		return
   217  	}
   218  
   219  	if header.Algorithm != provider.algorithm || header.Type != TypeJWT {
   220  		err = ErrTokenIsNotValid
   221  		return
   222  	}
   223  
   224  	// Reserved Claims
   225  	encodedClaims := token[headerEnd+1 : signatureStart]
   226  	claimsJson, err := base64.RawURLEncoding.DecodeString(encodedClaims)
   227  	if err != nil {
   228  		err = ErrTokenIsNotValid
   229  		return
   230  	}
   231  
   232  	var reservedClaims reservedClaims
   233  	err = json.Unmarshal(claimsJson, &reservedClaims)
   234  	if err != nil {
   235  		err = ErrTokenIsNotValid
   236  		return
   237  	}
   238  
   239  	now := time.Now().Unix()
   240  	if reservedClaims.ExpirationTime != 0 {
   241  		if now > reservedClaims.ExpirationTime {
   242  			err = ErrTokenHasExpired
   243  			return
   244  		}
   245  	}
   246  	if reservedClaims.NotBefore != 0 {
   247  		if now < reservedClaims.NotBefore {
   248  			err = ErrTokenIsNotValid
   249  			return
   250  		}
   251  	}
   252  
   253  	err = json.Unmarshal(claimsJson, data)
   254  	if err != nil {
   255  		err = ErrTokenIsNotValid
   256  		return
   257  	}
   258  
   259  	return
   260  }