github.com/avenga/couper@v1.12.2/eval/lib/oauth2.go (about)

     1  package lib
     2  
     3  import (
     4  	"fmt"
     5  	"net/url"
     6  
     7  	"github.com/hashicorp/hcl/v2"
     8  	pkce "github.com/jimlambrt/go-oauth-pkce-code-verifier"
     9  	"github.com/zclconf/go-cty/cty"
    10  	"github.com/zclconf/go-cty/cty/function"
    11  
    12  	"github.com/avenga/couper/config"
    13  )
    14  
    15  const (
    16  	CodeVerifier                  = "code_verifier"
    17  	FnOAuthAuthorizationURL       = "oauth2_authorization_url"
    18  	FnOAuthVerifier               = "oauth2_verifier"
    19  	InternalFnOAuthHashedVerifier = "internal_oauth_hashed_verifier"
    20  )
    21  
    22  var NoOpOAuthAuthorizationURLFunction = function.New(&function.Spec{
    23  	Params: []function.Parameter{
    24  		{
    25  			Name: "oauth2_label",
    26  			Type: cty.String,
    27  		},
    28  	},
    29  	Type: function.StaticReturnType(cty.String),
    30  	Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
    31  		if len(args) > 0 {
    32  			return cty.StringVal(""), fmt.Errorf("missing oidc or beta_oauth2 block with referenced label %q", args[0].AsString())
    33  		}
    34  		return cty.StringVal(""), fmt.Errorf("missing oidc or beta_oauth2 definitions")
    35  	},
    36  })
    37  
    38  func NewOAuthAuthorizationURLFunction(ctx *hcl.EvalContext, oauth2s map[string]config.OAuth2Authorization,
    39  	verifier func() (*pkce.CodeVerifier, error), origin *url.URL,
    40  	evalFn func(*hcl.EvalContext, hcl.Expression) (cty.Value, error)) function.Function {
    41  
    42  	emptyStringVal := cty.StringVal("")
    43  
    44  	return function.New(&function.Spec{
    45  		Params: []function.Parameter{
    46  			{
    47  				Name: "oauth2_label",
    48  				Type: cty.String,
    49  			},
    50  		},
    51  		Type: function.StaticReturnType(cty.String),
    52  		Impl: func(args []cty.Value, _ cty.Type) (cty.Value, error) {
    53  			label := args[0].AsString()
    54  			oauth2, exist := oauth2s[label]
    55  			if !exist {
    56  				return NoOpOAuthAuthorizationURLFunction.Call(args)
    57  			}
    58  
    59  			authorizationEndpoint, err := oauth2.GetAuthorizationEndpoint()
    60  			if err != nil {
    61  				return emptyStringVal, err
    62  			}
    63  
    64  			oauthAuthorizationURL, err := url.Parse(authorizationEndpoint)
    65  			if err != nil {
    66  				return emptyStringVal, err
    67  			}
    68  
    69  			redirectURI := oauth2.GetRedirectURI()
    70  			if redirectURI == "" {
    71  				return emptyStringVal, fmt.Errorf("redirect_uri is required")
    72  			}
    73  
    74  			absRedirectURI, err := AbsoluteURL(redirectURI, origin)
    75  			if err != nil {
    76  				return emptyStringVal, err
    77  			}
    78  
    79  			query := oauthAuthorizationURL.Query()
    80  			query.Set("response_type", "code")
    81  			query.Set("client_id", oauth2.GetClientID())
    82  			query.Set("redirect_uri", absRedirectURI)
    83  			if scope := oauth2.GetScope(); scope != "" {
    84  				query.Set("scope", scope)
    85  			}
    86  
    87  			verifierMethod, err := oauth2.GetVerifierMethod()
    88  			if err != nil {
    89  				return cty.StringVal(""), err
    90  			}
    91  
    92  			if verifierMethod == config.CcmS256 {
    93  				codeChallenge, err := createCodeChallenge(verifier)
    94  				if err != nil {
    95  					return cty.StringVal(""), err
    96  				}
    97  
    98  				query.Set("code_challenge_method", "S256")
    99  				query.Set("code_challenge", codeChallenge)
   100  			} else {
   101  				hashedVerifier, err := createCodeChallenge(verifier)
   102  				if err != nil {
   103  					return cty.StringVal(""), err
   104  				}
   105  
   106  				query.Set(verifierMethod, hashedVerifier)
   107  			}
   108  			oauthAuthorizationURL.RawQuery = query.Encode()
   109  
   110  			return cty.StringVal(oauthAuthorizationURL.String()), nil
   111  		},
   112  	})
   113  }
   114  
   115  func NewOAuthCodeVerifierFunction(verifier func() (*pkce.CodeVerifier, error)) function.Function {
   116  	return function.New(&function.Spec{
   117  		Params: []function.Parameter{},
   118  		Type:   function.StaticReturnType(cty.String),
   119  		Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
   120  			codeVerifier, err := verifier()
   121  			if err != nil {
   122  				return cty.StringVal(""), err
   123  			}
   124  
   125  			return cty.StringVal(codeVerifier.String()), nil
   126  		},
   127  	})
   128  }
   129  
   130  func NewOAuthCodeChallengeFunction(verifier func() (*pkce.CodeVerifier, error)) function.Function {
   131  	return function.New(&function.Spec{
   132  		Params: []function.Parameter{},
   133  		Type:   function.StaticReturnType(cty.String),
   134  		Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
   135  			codeChallenge, err := createCodeChallenge(verifier)
   136  			if err != nil {
   137  				return cty.StringVal(""), err
   138  			}
   139  
   140  			return cty.StringVal(codeChallenge), nil
   141  		},
   142  	})
   143  }
   144  
   145  func createCodeChallenge(verifier func() (*pkce.CodeVerifier, error)) (string, error) {
   146  	codeVerifier, err := verifier()
   147  	if err != nil {
   148  		return "", err
   149  	}
   150  
   151  	return codeVerifier.CodeChallengeS256(), nil
   152  }