github.com/avenga/couper@v1.12.2/handler/transport/oauth2_req_auth.go (about)

     1  package transport
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/url"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/hashicorp/hcl/v2"
    11  	"github.com/zclconf/go-cty/cty"
    12  
    13  	"github.com/avenga/couper/cache"
    14  	"github.com/avenga/couper/config"
    15  	"github.com/avenga/couper/config/request"
    16  	"github.com/avenga/couper/errors"
    17  	"github.com/avenga/couper/eval"
    18  	"github.com/avenga/couper/eval/lib"
    19  	"github.com/avenga/couper/internal/seetie"
    20  	"github.com/avenga/couper/oauth2"
    21  )
    22  
    23  var supportedGrantTypes = map[string]struct{}{
    24  	config.ClientCredentials: {},
    25  	config.JwtBearer:         {},
    26  	config.Password:          {},
    27  }
    28  
    29  var (
    30  	_ RequestAuthorizer = &OAuth2ReqAuth{}
    31  )
    32  
    33  type assertionCreator interface {
    34  	createAssertion(ctx *hcl.EvalContext) (string, error)
    35  }
    36  
    37  var (
    38  	_ assertionCreator = &assertionCreatorFromExpr{}
    39  	_ assertionCreator = &assertionCreatorFromJSP{}
    40  )
    41  
    42  type assertionCreatorFromExpr struct {
    43  	expr hcl.Expression
    44  }
    45  
    46  func newAssertionCreatorFromExpr(expr hcl.Expression) assertionCreator {
    47  	return &assertionCreatorFromExpr{
    48  		expr,
    49  	}
    50  }
    51  
    52  func (ac *assertionCreatorFromExpr) createAssertion(ctx *hcl.EvalContext) (string, error) {
    53  	assertionValue, err := eval.Value(ctx, ac.expr)
    54  	if err != nil {
    55  		return "", err
    56  	}
    57  
    58  	if assertionValue.IsNull() {
    59  		return "", fmt.Errorf("assertion expression evaluates to null")
    60  	}
    61  	if assertionValue.Type() != cty.String {
    62  		return "", fmt.Errorf("assertion expression must evaluate to a string")
    63  	}
    64  
    65  	return assertionValue.AsString(), nil
    66  }
    67  
    68  type assertionCreatorFromJSP struct {
    69  	*lib.JWTSigningConfig
    70  	headers map[string]interface{}
    71  	claims  map[string]interface{}
    72  }
    73  
    74  func newAssertionCreatorFromJSP(evalCtx *hcl.EvalContext, jsp *config.JWTSigningProfile) (assertionCreator, error) {
    75  	signingConfig, err := lib.NewJWTSigningConfigFromJWTSigningProfile(jsp, nil)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	var headers, claims map[string]interface{}
    81  
    82  	if signingConfig.Headers != nil {
    83  		v, err := eval.Value(evalCtx, signingConfig.Headers)
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  		headers = seetie.ValueToMap(v)
    88  	}
    89  
    90  	if signingConfig.Claims != nil {
    91  		cl, err := eval.Value(evalCtx, signingConfig.Claims)
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  		claims = seetie.ValueToMap(cl)
    96  	}
    97  
    98  	return &assertionCreatorFromJSP{
    99  		signingConfig,
   100  		headers,
   101  		claims,
   102  	}, nil
   103  }
   104  
   105  func (ac *assertionCreatorFromJSP) createAssertion(_ *hcl.EvalContext) (string, error) {
   106  	claims := make(map[string]interface{})
   107  	for k, v := range ac.claims {
   108  		claims[k] = v
   109  	}
   110  	now := time.Now().Unix()
   111  	claims["exp"] = now + ac.TTL
   112  
   113  	return lib.CreateJWT(ac.SignatureAlgorithm, ac.Key, claims, ac.headers)
   114  }
   115  
   116  // OAuth2ReqAuth represents the transport <OAuth2ReqAuth> object.
   117  type OAuth2ReqAuth struct {
   118  	config           *config.OAuth2ReqAuth
   119  	mu               sync.Mutex
   120  	memStore         *cache.MemoryStore
   121  	oauth2Client     *oauth2.Client
   122  	storageKey       string
   123  	assertionCreator assertionCreator
   124  }
   125  
   126  // NewOAuth2ReqAuth implements the http.RoundTripper interface to wrap an existing Backend / http.RoundTripper
   127  // to retrieve a valid token before passing the initial out request.
   128  func NewOAuth2ReqAuth(evalCtx *hcl.EvalContext, conf *config.OAuth2ReqAuth, memStore *cache.MemoryStore,
   129  	asBackend http.RoundTripper) (RequestAuthorizer, error) {
   130  
   131  	if _, supported := supportedGrantTypes[conf.GrantType]; !supported {
   132  		return nil, fmt.Errorf("grant_type %s not supported", conf.GrantType)
   133  	}
   134  
   135  	if conf.GrantType == config.Password {
   136  		if conf.Username == "" {
   137  			return nil, fmt.Errorf("username must not be empty with grant_type=password")
   138  		}
   139  		if conf.Password == "" {
   140  			return nil, fmt.Errorf("password must not be empty with grant_type=password")
   141  		}
   142  	} else {
   143  		if conf.Username != "" {
   144  			return nil, fmt.Errorf("username attribute must not be set with grant_type=%s", conf.GrantType)
   145  		}
   146  		if conf.Password != "" {
   147  			return nil, fmt.Errorf("password attribute must not be set with grant_type=%s", conf.GrantType)
   148  		}
   149  	}
   150  
   151  	var assertionCreator assertionCreator
   152  	assertionRange := conf.AssertionExpr.Range()
   153  	assertionSet := assertionRange.Start != assertionRange.End
   154  	if conf.GrantType == config.JwtBearer {
   155  		if !assertionSet && conf.JWTSigningProfile == nil {
   156  			return nil, fmt.Errorf("missing assertion attribute or jwt_signing_profile block with grant_type=%s", conf.GrantType)
   157  		}
   158  		if assertionSet {
   159  			assertionCreator = newAssertionCreatorFromExpr(conf.AssertionExpr)
   160  		} else {
   161  			var err error
   162  			assertionCreator, err = newAssertionCreatorFromJSP(evalCtx, conf.JWTSigningProfile)
   163  			if err != nil {
   164  				return nil, err
   165  			}
   166  		}
   167  	} else {
   168  		if assertionSet {
   169  			return nil, fmt.Errorf("assertion attribute must not be set with grant_type=%s", conf.GrantType)
   170  		}
   171  	}
   172  
   173  	oauth2Client, err := oauth2.NewClient(evalCtx, conf.GrantType, conf, conf, asBackend)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	reqAuth := &OAuth2ReqAuth{
   179  		config:           conf,
   180  		oauth2Client:     oauth2Client,
   181  		memStore:         memStore,
   182  		assertionCreator: assertionCreator,
   183  	}
   184  	reqAuth.storageKey = fmt.Sprintf("oauth2-%p", reqAuth)
   185  	return reqAuth, nil
   186  }
   187  
   188  func (oa *OAuth2ReqAuth) GetToken(req *http.Request) error {
   189  	token := oa.readAccessToken()
   190  	if token != "" {
   191  		req.Header.Set("Authorization", "Bearer "+token)
   192  		return nil
   193  	}
   194  
   195  	oa.mu.Lock()
   196  	defer oa.mu.Unlock()
   197  
   198  	token = oa.readAccessToken()
   199  	if token != "" {
   200  		req.Header.Set("Authorization", "Bearer "+token)
   201  		return nil
   202  	}
   203  
   204  	requestError := errors.Request.Label("oauth2")
   205  	formParams := url.Values{}
   206  
   207  	if oa.config.GrantType == config.JwtBearer {
   208  		requestContext := eval.ContextFromRequest(req).HCLContext()
   209  		assertion, err := oa.assertionCreator.createAssertion(requestContext)
   210  		if err != nil {
   211  			return requestError.With(err)
   212  		}
   213  
   214  		formParams.Set("assertion", assertion)
   215  	}
   216  	if oa.config.Scope != "" {
   217  		formParams.Set("scope", oa.config.Scope)
   218  	}
   219  	if oa.config.Password != "" || oa.config.Username != "" {
   220  		formParams.Set("username", oa.config.Username)
   221  		formParams.Set("password", oa.config.Password)
   222  	}
   223  
   224  	tokenResponseData, token, err := oa.oauth2Client.GetTokenResponse(req.Context(), formParams)
   225  	if err != nil {
   226  		return requestError.Message("token request failed") // don't propagate token request roundtrip error
   227  	}
   228  
   229  	oa.updateAccessToken(token, tokenResponseData)
   230  
   231  	req.Header.Set("Authorization", "Bearer "+token)
   232  	return nil
   233  }
   234  
   235  func (oa *OAuth2ReqAuth) RetryWithToken(req *http.Request, res *http.Response) (bool, error) {
   236  	if res == nil || res.StatusCode != http.StatusUnauthorized {
   237  		return false, nil
   238  	}
   239  
   240  	oa.memStore.Del(oa.storageKey)
   241  
   242  	ctx := req.Context()
   243  	if retries, ok := ctx.Value(request.TokenRequestRetries).(*uint8); !ok || *retries < *oa.config.Retries {
   244  		*retries++ // increase ptr value instead of context value
   245  		req.Header.Del("Authorization")
   246  		err := oa.GetToken(req)
   247  		return true, err
   248  	}
   249  	return false, nil
   250  }
   251  
   252  func (oa *OAuth2ReqAuth) readAccessToken() string {
   253  	if data := oa.memStore.Get(oa.storageKey); data != nil {
   254  		return data.(string)
   255  	}
   256  
   257  	return ""
   258  }
   259  
   260  func (oa *OAuth2ReqAuth) updateAccessToken(token string, jData map[string]interface{}) {
   261  	if oa.memStore != nil {
   262  		var ttl int64
   263  		if t, ok := jData["expires_in"].(float64); ok {
   264  			ttl = (int64)(t * 0.9)
   265  		}
   266  
   267  		oa.memStore.Set(oa.storageKey, token, ttl)
   268  	}
   269  }
   270  
   271  func (oa *OAuth2ReqAuth) value() (string, string) {
   272  	token := oa.readAccessToken()
   273  	return "oauth2", token
   274  }