git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/jwt/jwt.go (about) 1 package jwt 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "crypto/sha512" 7 "encoding/base64" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "strings" 12 "time" 13 ) 14 15 type Algorithm string 16 type Type string 17 18 const ( 19 AlgorithmHS256 Algorithm = "HS256" 20 AlgorithmHS512 Algorithm = "HAS512" 21 ) 22 23 const ( 24 TypeJWT Type = "JWT" 25 ) 26 27 var ( 28 ErrTokenIsNotValid = errors.New("The token is not valid") 29 ErrSignatureIsNotValid = errors.New("Signature is not valid") 30 ErrTokenHasExpired = errors.New("The token has expired") 31 ErrAlgorithmIsNotValid = fmt.Errorf("Algorithm is not valid. Valid algorithms values are: [%s, %s]", AlgorithmHS256, AlgorithmHS512) 32 ) 33 34 type Provider struct { 35 signingSecretKey []byte 36 algorithm Algorithm 37 verifyingKeys [][]byte 38 } 39 40 type header struct { 41 Algorithm Algorithm `json:"alg"` 42 Type Type `json:"typ"` 43 } 44 45 // registered claim names from https://www.rfc-editor.org/rfc/rfc7519#section-4.1 46 type reservedClaims struct { 47 ExpirationTime int64 `json:"exp,omitempty"` 48 NotBefore int64 `json:"nbf,omitempty"` 49 } 50 51 type NewProviderOptions struct { 52 VerifyingKeys [][]byte 53 } 54 55 func NewProvider(signingSecretKey []byte, algorithm Algorithm, options *NewProviderOptions) (provider *Provider, err error) { 56 if len(signingSecretKey) < 32 { 57 err = errors.New("jwt: secretKey is too short. Min length: 32 bytes") 58 return 59 } 60 61 if algorithm != AlgorithmHS256 && algorithm != AlgorithmHS512 { 62 err = ErrAlgorithmIsNotValid 63 return 64 } 65 66 defaultOptions := defaultNewProviderOptions() 67 if options == nil { 68 options = defaultOptions 69 } else { 70 if options.VerifyingKeys == nil { 71 options.VerifyingKeys = defaultOptions.VerifyingKeys 72 } 73 } 74 75 provider = &Provider{ 76 signingSecretKey: signingSecretKey, 77 algorithm: algorithm, 78 verifyingKeys: options.VerifyingKeys, 79 } 80 return 81 } 82 83 func defaultNewProviderOptions() *NewProviderOptions { 84 return &NewProviderOptions{ 85 VerifyingKeys: [][]byte{}, 86 } 87 } 88 89 type TokenOptions struct { 90 ExpirationTime *time.Time 91 NotBefore *time.Time 92 } 93 94 func (provider *Provider) IssueToken(data any, options *TokenOptions) (token string, err error) { 95 tokenBuffer := bytes.NewBuffer(make([]byte, 0, 100)) 96 97 header := header{Algorithm: provider.algorithm, Type: TypeJWT} 98 headerJson, err := json.Marshal(header) 99 if err != nil { 100 err = fmt.Errorf("jwt: encoding the header to JSON: %w", err) 101 return 102 } 103 encodedHeader := base64.RawURLEncoding.EncodeToString(headerJson) 104 tokenBuffer.WriteString(encodedHeader) 105 tokenBuffer.WriteString(".") 106 107 var claimsJson []byte 108 if options != nil && (options.ExpirationTime != nil || options.NotBefore != nil) { 109 var dataJson []byte 110 var reservedClaims = reservedClaims{} 111 112 if options.ExpirationTime != nil { 113 reservedClaims.ExpirationTime = options.ExpirationTime.Unix() 114 if reservedClaims.ExpirationTime < 1 { 115 err = fmt.Errorf("jwt: ExpirationTime should not be < 1") 116 return 117 } 118 } 119 if options.NotBefore != nil { 120 reservedClaims.NotBefore = options.NotBefore.Unix() 121 if reservedClaims.NotBefore < 1 { 122 err = fmt.Errorf("jwt: NotBefore should not be < 1") 123 return 124 } 125 } 126 127 claimsJson, err = json.Marshal(reservedClaims) 128 if err != nil { 129 err = fmt.Errorf("jwt: encoding claims to JSON: %w", err) 130 return 131 } 132 dataJson, err = json.Marshal(data) 133 if err != nil { 134 err = fmt.Errorf("jwt: encoding claims to JSON: %w", err) 135 return 136 } 137 if string(dataJson) != "{}" { 138 dataJson[0] = ',' 139 claimsJson = append(claimsJson[:len(claimsJson)-1], dataJson...) 140 } 141 } else { 142 claimsJson, err = json.Marshal(data) 143 if err != nil { 144 err = fmt.Errorf("jwt: encoding claims to JSON: %w", err) 145 return 146 } 147 if err != nil { 148 err = fmt.Errorf("jwt: encoding claims to JSON: %w", err) 149 return 150 } 151 } 152 153 encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJson) 154 tokenBuffer.WriteString(encodedClaims) 155 156 var rawSignature []byte 157 switch provider.algorithm { 158 case AlgorithmHS256: 159 rawSignature = signTokenHMAC(sha256.New, provider.signingSecretKey, tokenBuffer.Bytes()) 160 case AlgorithmHS512: 161 rawSignature = signTokenHMAC(sha512.New, provider.signingSecretKey, tokenBuffer.Bytes()) 162 default: 163 err = ErrAlgorithmIsNotValid 164 return 165 } 166 encodedSignature := base64.RawURLEncoding.EncodeToString(rawSignature) 167 tokenBuffer.WriteString(".") 168 tokenBuffer.WriteString(encodedSignature) 169 170 token = tokenBuffer.String() 171 172 return 173 } 174 175 func (provider *Provider) VerifyToken(token string, data any) (err error) { 176 if strings.Count(token, ".") != 2 { 177 err = ErrTokenIsNotValid 178 return 179 } 180 181 // Signature 182 signatureStart := strings.LastIndexByte(token, '.') 183 encodedSignature := token[signatureStart+1:] 184 signature, err := base64.RawURLEncoding.DecodeString(encodedSignature) 185 if err != nil { 186 err = ErrTokenIsNotValid 187 return 188 } 189 190 encodedHeaderAndClaims := token[:signatureStart] 191 192 switch provider.algorithm { 193 case AlgorithmHS256: 194 err = verifyTokenHMAC(sha256.New, provider.signingSecretKey, signature, []byte(encodedHeaderAndClaims)) 195 case AlgorithmHS512: 196 err = verifyTokenHMAC(sha512.New, provider.signingSecretKey, signature, []byte(encodedHeaderAndClaims)) 197 default: 198 err = ErrTokenIsNotValid 199 } 200 if err != nil { 201 return 202 } 203 204 // Header 205 var header header 206 headerEnd := strings.IndexByte(token, '.') 207 encodedHeader := token[:headerEnd] 208 headerJson, err := base64.RawURLEncoding.DecodeString(encodedHeader) 209 if err != nil { 210 err = ErrTokenIsNotValid 211 return 212 } 213 err = json.Unmarshal(headerJson, &header) 214 if err != nil { 215 err = ErrTokenIsNotValid 216 return 217 } 218 219 if header.Algorithm != provider.algorithm || header.Type != TypeJWT { 220 err = ErrTokenIsNotValid 221 return 222 } 223 224 // Reserved Claims 225 encodedClaims := token[headerEnd+1 : signatureStart] 226 claimsJson, err := base64.RawURLEncoding.DecodeString(encodedClaims) 227 if err != nil { 228 err = ErrTokenIsNotValid 229 return 230 } 231 232 var reservedClaims reservedClaims 233 err = json.Unmarshal(claimsJson, &reservedClaims) 234 if err != nil { 235 err = ErrTokenIsNotValid 236 return 237 } 238 239 now := time.Now().Unix() 240 if reservedClaims.ExpirationTime != 0 { 241 if now > reservedClaims.ExpirationTime { 242 err = ErrTokenHasExpired 243 return 244 } 245 } 246 if reservedClaims.NotBefore != 0 { 247 if now < reservedClaims.NotBefore { 248 err = ErrTokenIsNotValid 249 return 250 } 251 } 252 253 err = json.Unmarshal(claimsJson, data) 254 if err != nil { 255 err = ErrTokenIsNotValid 256 return 257 } 258 259 return 260 }