github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/oauth/validator.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"fmt"
    19  	jwtlib "github.com/golang-jwt/jwt/v4"
    20  	"github.com/greenpau/go-authcrunch/pkg/errors"
    21  	"github.com/greenpau/go-authcrunch/pkg/kms"
    22  	"strings"
    23  )
    24  
    25  var (
    26  	tokenFields = []string{
    27  		"sub", "name", "email", "iat", "exp", "jti",
    28  		"iss", "groups", "picture",
    29  		"roles", "role", "groups", "group",
    30  		"given_name", "family_name",
    31  	}
    32  )
    33  
    34  func (b *IdentityProvider) validateAccessToken(state string, data map[string]interface{}) (map[string]interface{}, error) {
    35  	var tokenString string
    36  	if v, exists := data[b.config.IdentityTokenName]; exists {
    37  		tokenString = v.(string)
    38  	} else {
    39  		return nil, errors.ErrIdentityProviderOAuthAccessTokenNotFound.WithArgs(b.config.IdentityTokenName)
    40  	}
    41  
    42  	token, err := jwtlib.Parse(tokenString, func(token *jwtlib.Token) (interface{}, error) {
    43  		switch {
    44  		case strings.HasPrefix(token.Method.Alg(), "RS"):
    45  			if _, validMethod := token.Method.(*jwtlib.SigningMethodRSA); !validMethod {
    46  				return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Header["alg"])
    47  			}
    48  		case strings.HasPrefix(token.Method.Alg(), "ES"):
    49  			if _, validMethod := token.Method.(*jwtlib.SigningMethodECDSA); !validMethod {
    50  				return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Header["alg"])
    51  			}
    52  		case strings.HasPrefix(token.Method.Alg(), "HS"):
    53  			return nil, errors.ErrIdentityProviderOAuthAccessTokenSignMethodNotSupported.WithArgs(b.config.IdentityTokenName, token.Method.Alg())
    54  		}
    55  
    56  		keyID, found := token.Header["kid"].(string)
    57  		if !found {
    58  			// If key id is not found in the header, then try the first available key.
    59  			for _, key := range b.keys {
    60  				return key.GetPublic(), nil
    61  			}
    62  			// return nil, errors.ErrIdentityProviderOAuthAccessTokenKeyIDNotFound.WithArgs(b.config.IdentityTokenName)
    63  		}
    64  		key, exists := b.keys[keyID]
    65  		if !exists {
    66  			if !b.disableKeyVerification {
    67  				if err := b.fetchKeysURL(); err != nil {
    68  					return nil, errors.ErrIdentityProviderOauthKeyFetchFailed.WithArgs(err)
    69  				}
    70  			}
    71  			key, exists = b.keys[keyID]
    72  			if !exists {
    73  				return nil, errors.ErrIdentityProviderOAuthAccessTokenKeyIDNotRegistered.WithArgs(b.config.IdentityTokenName, keyID)
    74  			}
    75  		}
    76  		return key.GetPublic(), nil
    77  	})
    78  
    79  	if err != nil {
    80  		return nil, errors.ErrIdentityProviderOAuthParseToken.WithArgs(b.config.IdentityTokenName, err)
    81  	}
    82  
    83  	if _, ok := token.Claims.(jwtlib.Claims); !ok && !token.Valid {
    84  		return nil, errors.ErrIdentityProviderOAuthInvalidToken.WithArgs(b.config.IdentityTokenName, tokenString)
    85  	}
    86  	claims := token.Claims.(jwtlib.MapClaims)
    87  	if _, exists := claims["nonce"]; !exists {
    88  		return nil, errors.ErrIdentityProviderOAuthNonceValidationFailed.WithArgs(b.config.IdentityTokenName, "nonce not found")
    89  	}
    90  	if err := b.state.validateNonce(state, claims["nonce"].(string)); err != nil {
    91  		return nil, errors.ErrIdentityProviderOAuthNonceValidationFailed.WithArgs(b.config.IdentityTokenName, err)
    92  	}
    93  
    94  	if !b.disableEmailClaimCheck {
    95  		if _, exists := claims["email"]; !exists {
    96  			return nil, errors.ErrIdentityProviderOAuthEmailNotFound.WithArgs(b.config.IdentityTokenName)
    97  		}
    98  	}
    99  
   100  	m := make(map[string]interface{})
   101  	for _, k := range tokenFields {
   102  		if _, exists := claims[k]; !exists {
   103  			continue
   104  		}
   105  		m[k] = claims[k]
   106  	}
   107  
   108  	if _, exists := m["name"]; !exists {
   109  		if _, exists := m["given_name"]; exists {
   110  			if _, exists := m["family_name"]; exists {
   111  				m["name"] = fmt.Sprintf("%s %s", m["given_name"].(string), m["family_name"].(string))
   112  				delete(m, "given_name")
   113  				delete(m, "family_name")
   114  			}
   115  		}
   116  	}
   117  
   118  	switch b.config.Driver {
   119  	case "cognito":
   120  		if v, exists := data["id_token"]; exists {
   121  			if tp, err := kms.ParsePayloadFromToken(v.(string)); err == nil {
   122  				roles := []string{}
   123  				for k, val := range tp {
   124  					switch k {
   125  					case "custom:roles", "cognito:groups", "cognito:roles":
   126  						switch values := val.(type) {
   127  						case string:
   128  							if k == "custom:roles" {
   129  								for _, roleName := range strings.Split(values, "|") {
   130  									roles = append(roles, roleName)
   131  								}
   132  							} else {
   133  								roles = append(roles, values)
   134  							}
   135  						case []interface{}:
   136  							for _, value := range values {
   137  								switch roleName := value.(type) {
   138  								case string:
   139  									roles = append(roles, roleName)
   140  								}
   141  							}
   142  						}
   143  					case "custom:timezone":
   144  						m["timezone"] = val.(string)
   145  					case "cognito:username":
   146  						m["username"] = val.(string)
   147  					case "zoneinfo":
   148  						m["timezone"] = val.(string)
   149  					}
   150  				}
   151  				if len(roles) > 0 {
   152  					m["roles"] = roles
   153  				}
   154  			}
   155  		}
   156  	}
   157  
   158  	return m, nil
   159  }