github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/utils/jwtauth/validate.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package jwtauth 16 17 import ( 18 "errors" 19 "fmt" 20 "time" 21 22 jose "gopkg.in/square/go-jose.v2" 23 "gopkg.in/square/go-jose.v2/jwt" 24 ) 25 26 type KeyProvider interface { 27 GetKey(kid string) ([]jose.JSONWebKey, error) 28 } 29 30 var ErrKeyNotFound = errors.New("Key not found") 31 32 func ValidateJWT(unparsed string, reqTime time.Time, keyProvider KeyProvider, expectedClaims jwt.Expected) (*Claims, error) { 33 parsed, err := jwt.ParseSigned(unparsed) 34 if err != nil { 35 return nil, err 36 } 37 38 if len(parsed.Headers) != 1 { 39 return nil, fmt.Errorf("ValidateJWT: Unexpected JWT headers length %v.", len(parsed.Headers)) 40 } 41 42 if parsed.Headers[0].Algorithm != "RS512" && 43 parsed.Headers[0].Algorithm != "RS256" && 44 parsed.Headers[0].Algorithm != "EdDSA" { 45 return nil, fmt.Errorf("ValidateJWT: Currently only support RS256, RS512 and EdDSA signatures. Unexpected algorithm: %v", parsed.Headers[0].Algorithm) 46 } 47 48 keyID := parsed.Headers[0].KeyID 49 50 keys, err := keyProvider.GetKey(keyID) 51 if err != nil { 52 return nil, err 53 } 54 55 var claims Claims 56 claimsError := fmt.Errorf("ValidateJWT: KeyID: %v. Err: %w", keyID, ErrKeyNotFound) 57 for _, key := range keys { 58 claimsError = parsed.Claims(key.Key, &claims) 59 if claimsError == nil { 60 break 61 } 62 } 63 if claimsError != nil { 64 return nil, claimsError 65 } 66 67 if err := claims.Validate(expectedClaims.WithTime(reqTime)); err != nil { 68 return nil, err 69 } 70 71 return &claims, nil 72 }