git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/jst/jst.go (about)

     1  package jst
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/subtle"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"strings"
    12  	"time"
    13  
    14  	"git.sr.ht/~pingoo/stdx/crypto/chacha20"
    15  	"github.com/zeebo/blake3"
    16  )
    17  
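// Token format: jst.v1.[header].[payload].[signature]
// Each segment is unpadded base64url: the JSON header, the JSON payload
// encrypted with chacha20.NewX, and a keyed BLAKE3 MAC over everything before
// the last '.'.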
    19  
    20  const (
    21  	V1KeySize = 32
    22  	// nonceSize is the size of the nonce to encrypt tokens, in bytes.
    23  	v1nonceSize = chacha20.NonceSizeX
    24  
    25  	v1EncryptionKeyContext     = "jst-v1 2023-12-31 23:59:59.999 encryption-key"
    26  	v1AuthenticationKeyContext = "jst-v1 2024-01-01 00:00:00.000 authentication-key"
    27  )
    28  
    29  var (
    30  	ErrTokenIsNotValid     = errors.New("jst: token is not valid")
    31  	ErrSignatureIsNotValid = errors.New("jst: signature is not valid")
    32  	ErrTokenHasExpired     = errors.New("jst: token has expired")
    33  )
    34  
    35  type Provider struct {
    36  	defaultKey  string
    37  	keyProvider KeyProvider
    38  }
    39  
    40  func NewProvider(keyProvider KeyProvider, defaultKey string) (provider *Provider, err error) {
    41  	provider = &Provider{
    42  		defaultKey:  defaultKey,
    43  		keyProvider: keyProvider,
    44  	}
    45  	return
    46  }
    47  
    48  type TokenOptions struct {
    49  	NotBefore   *time.Time
    50  	IssuedAt    *time.Time
    51  	KeyID       string
    52  	Compression string
    53  }
    54  
    55  type HeaderV1 struct {
    56  	ExpiresAt   *time.Time `json:"expires_at,omitempty"`
    57  	NotBefore   *time.Time `json:"not_before,omitempty"`
    58  	IssuedAt    *time.Time `json:"issued_at,omitempty"`
    59  	KeyID       string     `json:"key_id,omitempty"`
    60  	Compression string     `json:"compression,omitempty"`
    61  	Nonce       []byte     `json:"nonce"`
    62  }
    63  
    64  func (provider *Provider) IssueToken(payload any, expiresAt *time.Time, options *TokenOptions) (token string, err error) {
    65  	tokenBuffer := bytes.NewBuffer(make([]byte, 0, 120))
    66  
    67  	if options == nil {
    68  		options = &TokenOptions{}
    69  	}
    70  
    71  	keyId := options.KeyID
    72  	if keyId == "" {
    73  		keyId = provider.defaultKey
    74  	}
    75  
    76  	maskterKey, err := provider.keyProvider.GetKey(keyId)
    77  	if err != nil {
    78  		return
    79  	}
    80  	if len(maskterKey) != V1KeySize {
    81  		err = fmt.Errorf("jst: key %s is invalid. Expected size: %d bytes", keyId, V1KeySize)
    82  		return
    83  	}
    84  
    85  	nonce := make([]byte, v1nonceSize)
    86  	_, err = rand.Read(nonce)
    87  	if err != nil {
    88  		err = fmt.Errorf("jst: error generating random nonce: %w", err)
    89  		return
    90  	}
    91  
    92  	// derive keys
    93  	encryptionKey := deriveKey(maskterKey, v1EncryptionKeyContext, nonce)
    94  	authenticationKey := deriveKey(maskterKey, v1AuthenticationKeyContext, nonce)
    95  
    96  	// prefix
    97  	_, err = tokenBuffer.WriteString("jst.v1.")
    98  	if err != nil {
    99  		err = fmt.Errorf("jst: generating token: %w", err)
   100  		return
   101  	}
   102  
   103  	// header
   104  	header := HeaderV1{
   105  		ExpiresAt: expiresAt,
   106  		NotBefore: options.NotBefore,
   107  		IssuedAt:  options.IssuedAt,
   108  		KeyID:     keyId,
   109  		Nonce:     nonce,
   110  	}
   111  	headerJson, err := json.Marshal(header)
   112  	if err != nil {
   113  		err = fmt.Errorf("jst: error encoding header to JSON: %w", err)
   114  		return
   115  	}
   116  	headerBase64 := base64.RawURLEncoding.EncodeToString(headerJson)
   117  
   118  	// we can ignore some errors as Buffer.Write* methods never return an error
   119  	_, _ = tokenBuffer.WriteString(headerBase64)
   120  
   121  	// payload
   122  	payloadJSON, err := json.Marshal(payload)
   123  	if err != nil {
   124  		err = fmt.Errorf("jst: error encoding payload to JSON: %w", err)
   125  		return
   126  	}
   127  
   128  	// we can ignore error as we already checked that the key and nonce are of the correct size
   129  	cipher, _ := chacha20.NewX(encryptionKey, nonce)
   130  	cipherTextBuffer := make([]byte, len(payloadJSON))
   131  	cipher.XORKeyStream(cipherTextBuffer, payloadJSON)
   132  	payloadBase64 := base64.RawURLEncoding.EncodeToString(cipherTextBuffer)
   133  
   134  	_ = tokenBuffer.WriteByte('.')
   135  	_, _ = tokenBuffer.WriteString(payloadBase64)
   136  
   137  	// we can ignore error as we are sure that the key is of the good size
   138  	macHasher, _ := blake3.NewKeyed(authenticationKey)
   139  	macHasher.Write(tokenBuffer.Bytes())
   140  	signature := macHasher.Sum(nil)
   141  	signatureBase64 := base64.RawURLEncoding.EncodeToString(signature)
   142  
   143  	_ = tokenBuffer.WriteByte('.')
   144  	_, _ = tokenBuffer.WriteString(signatureBase64)
   145  
   146  	token = tokenBuffer.String()
   147  
   148  	return
   149  }
   150  
   151  func (provider *Provider) VerifyToken(token string, data any) (header HeaderV1, err error) {
   152  	if strings.Count(token, ".") != 4 {
   153  		err = ErrTokenIsNotValid
   154  		return
   155  	}
   156  
   157  	if !strings.HasPrefix(token, "jst.v1.") {
   158  		err = ErrTokenIsNotValid
   159  		return
   160  	}
   161  
   162  	// Header
   163  	headerEnd := strings.IndexByte(token[7:], '.') + 7
   164  	encodedHeader := token[7:headerEnd]
   165  	headerJson, err := base64.RawURLEncoding.DecodeString(encodedHeader)
   166  	if err != nil {
   167  		err = ErrTokenIsNotValid
   168  		return
   169  	}
   170  	err = json.Unmarshal(headerJson, &header)
   171  	if err != nil {
   172  		err = ErrTokenIsNotValid
   173  		return
   174  	}
   175  
   176  	if len(header.Nonce) != v1nonceSize {
   177  		err = ErrTokenIsNotValid
   178  		return
   179  	}
   180  
   181  	maskterKey, err := provider.keyProvider.GetKey(header.KeyID)
   182  	if err != nil {
   183  		return
   184  	}
   185  
   186  	// derive keys
   187  	encryptionKey := deriveKey(maskterKey, v1EncryptionKeyContext, header.Nonce)
   188  	authenticationKey := deriveKey(maskterKey, v1AuthenticationKeyContext, header.Nonce)
   189  
   190  	signatureStart := strings.LastIndexByte(token, '.')
   191  	encodedSignature := token[signatureStart+1:]
   192  	tokenSignature, err := base64.RawURLEncoding.DecodeString(encodedSignature)
   193  	if err != nil {
   194  		err = ErrTokenIsNotValid
   195  		return
   196  	}
   197  
   198  	encodedHeaderAndPayload := token[:signatureStart]
   199  
   200  	// we can ignore error as we are sure that the key is of the good size
   201  	macHasher, _ := blake3.NewKeyed(authenticationKey)
   202  	macHasher.Write([]byte(encodedHeaderAndPayload))
   203  	signature := macHasher.Sum(nil)
   204  
   205  	if subtle.ConstantTimeCompare(tokenSignature, signature) != 1 {
   206  		err = ErrSignatureIsNotValid
   207  		return
   208  	}
   209  
   210  	// Payload
   211  	encodedPayload := token[headerEnd+1 : signatureStart]
   212  	encryptedPayload, err := base64.RawURLEncoding.DecodeString(encodedPayload)
   213  	if err != nil {
   214  		err = ErrTokenIsNotValid
   215  		return
   216  	}
   217  	// we can ignore error as we already checked that the key and nonce are of the correct size
   218  	cipher, _ := chacha20.NewX(encryptionKey, header.Nonce)
   219  	payloadJson := make([]byte, len(encryptedPayload))
   220  	cipher.XORKeyStream(payloadJson, encryptedPayload)
   221  
   222  	err = json.Unmarshal(payloadJson, data)
   223  	if err != nil {
   224  		err = ErrTokenIsNotValid
   225  		return
   226  	}
   227  
   228  	return
   229  }
   230  
   231  func deriveKey(parentKey []byte, context string, nonce []byte) []byte {
   232  	hasher := blake3.NewDeriveKey(context)
   233  	hasher.Write(nonce)
   234  	hasher.Write(parentKey)
   235  	// hasher.Write(binary.LittleEndian.AppendUint64([]byte{}, uint64(len(nonce))))
   236  	// hasher.Write(binary.LittleEndian.AppendUint64([]byte{}, uint64(len(parentKey))))
   237  	return hasher.Sum(nil)
   238  }