github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/oauth/validator.go (about) 1 // Copyright 2022 Paul Greenberg greenpau@outlook.com 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package oauth 16 17 import ( 18 "fmt" 19 jwtlib "github.com/golang-jwt/jwt/v4" 20 "github.com/greenpau/go-authcrunch/pkg/errors" 21 "github.com/greenpau/go-authcrunch/pkg/kms" 22 "strings" 23 ) 24 25 var ( 26 tokenFields = []string{ 27 "sub", "name", "email", "iat", "exp", "jti", 28 "iss", "groups", "picture", 29 "roles", "role", "groups", "group", 30 "given_name", "family_name", 31 } 32 ) 33 34 func (b *IdentityProvider) validateAccessToken(state string, data map[string]interface{}) (map[string]interface{}, error) { 35 var tokenString string 36 if v, exists := data[b.config.IdentityTokenName]; exists { 37 tokenString = v.(string) 38 } else { 39 return nil, errors.ErrIdentityProviderOAuthAccessTokenNotFound.WithArgs(b.config.IdentityTokenName) 40 } 41 42 token, err := jwtlib.Parse(tokenString, func(token *jwtlib.Token) (interface{}, error) { 43 switch { 44 case strings.HasPrefix(token.Method.Alg(), "RS"): 45 if _, validMethod := token.Method.(*jwtlib.SigningMethodRSA); !validMethod { 46 return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Header["alg"]) 47 } 48 case strings.HasPrefix(token.Method.Alg(), "ES"): 49 if _, validMethod := token.Method.(*jwtlib.SigningMethodECDSA); !validMethod { 50 return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Header["alg"]) 51 } 52 case strings.HasPrefix(token.Method.Alg(), "HS"): 53 return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Method.Alg()) 54 } 55 56 keyID, found := token.Header["kid"].(string) 57 if !found { 58 // If key id is not found in the header, then try the first available key. 59 for _, key := range b.keys { 60 return key.GetPublic(), nil 61 } 62 // return nil, errors.ErrIdentityProviderOAuthAccessTokenKeyIDNotFound.WithArgs(b.config.IdentityTokenName) 63 } 64 key, exists := b.keys[keyID] 65 if !exists { 66 if !b.disableKeyVerification { 67 if err := b.fetchKeysURL(); err != nil { 68 return nil, errors.ErrIdentityProviderOauthKeyFetchFailed.WithArgs(err) 69 } 70 } 71 key, exists = b.keys[keyID] 72 if !exists { 73 return nil, errors.ErrIdentityProviderOAuthAccessTokenKeyIDNotRegistered.WithArgs(b.config.IdentityTokenName, keyID) 74 } 75 } 76 return key.GetPublic(), nil 77 }) 78 79 if err != nil { 80 return nil, errors.ErrIdentityProviderOAuthParseToken.WithArgs(b.config.IdentityTokenName, err) 81 } 82 83 if _, ok := token.Claims.(jwtlib.Claims); !ok && !token.Valid { 84 return nil, errors.ErrIdentityProviderOAuthInvalidToken.WithArgs(b.config.IdentityTokenName, tokenString) 85 } 86 claims := token.Claims.(jwtlib.MapClaims) 87 if _, exists := claims["nonce"]; !exists { 88 return nil, errors.ErrIdentityProviderOAuthNonceValidationFailed.WithArgs(b.config.IdentityTokenName, "nonce not found") 89 } 90 if err := b.state.validateNonce(state, claims["nonce"].(string)); err != nil { 91 return nil, errors.ErrIdentityProviderOAuthNonceValidationFailed.WithArgs(b.config.IdentityTokenName, err) 92 } 93 94 if !b.disableEmailClaimCheck { 95 if _, exists := claims["email"]; !exists { 96 return nil, errors.ErrIdentityProviderOAuthEmailNotFound.WithArgs(b.config.IdentityTokenName) 97 } 98 } 99 100 m := make(map[string]interface{}) 101 for _, k := range tokenFields { 102 if _, exists := claims[k]; !exists { 103 continue 104 } 105 m[k] = claims[k] 106 } 107 108 if _, exists := m["name"]; !exists { 109 if _, exists := m["given_name"]; exists { 110 if _, exists := m["family_name"]; exists { 111 m["name"] = fmt.Sprintf("%s %s", m["given_name"].(string), m["family_name"].(string)) 112 delete(m, "given_name") 113 delete(m, "family_name") 114 } 115 } 116 } 117 118 switch b.config.Driver { 119 case "cognito": 120 if v, exists := data["id_token"]; exists { 121 if tp, err := kms.ParsePayloadFromToken(v.(string)); err == nil { 122 roles := []string{} 123 for k, val := range tp { 124 switch k { 125 case "custom:roles", "cognito:groups", "cognito:roles": 126 switch values := val.(type) { 127 case string: 128 if k == "custom:roles" { 129 for _, roleName := range strings.Split(values, "|") { 130 roles = append(roles, roleName) 131 } 132 } else { 133 roles = append(roles, values) 134 } 135 case []interface{}: 136 for _, value := range values { 137 switch roleName := value.(type) { 138 case string: 139 roles = append(roles, roleName) 140 } 141 } 142 } 143 case "custom:timezone": 144 m["timezone"] = val.(string) 145 case "cognito:username": 146 m["username"] = val.(string) 147 case "zoneinfo": 148 m["timezone"] = val.(string) 149 } 150 } 151 if len(roles) > 0 { 152 m["roles"] = roles 153 } 154 } 155 } 156 } 157 158 return m, nil 159 }