github.com/avenga/couper@v1.12.2/handler/transport/oauth2_req_auth.go (about) 1 package transport 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/url" 7 "sync" 8 "time" 9 10 "github.com/hashicorp/hcl/v2" 11 "github.com/zclconf/go-cty/cty" 12 13 "github.com/avenga/couper/cache" 14 "github.com/avenga/couper/config" 15 "github.com/avenga/couper/config/request" 16 "github.com/avenga/couper/errors" 17 "github.com/avenga/couper/eval" 18 "github.com/avenga/couper/eval/lib" 19 "github.com/avenga/couper/internal/seetie" 20 "github.com/avenga/couper/oauth2" 21 ) 22 23 var supportedGrantTypes = map[string]struct{}{ 24 config.ClientCredentials: {}, 25 config.JwtBearer: {}, 26 config.Password: {}, 27 } 28 29 var ( 30 _ RequestAuthorizer = &OAuth2ReqAuth{} 31 ) 32 33 type assertionCreator interface { 34 createAssertion(ctx *hcl.EvalContext) (string, error) 35 } 36 37 var ( 38 _ assertionCreator = &assertionCreatorFromExpr{} 39 _ assertionCreator = &assertionCreatorFromJSP{} 40 ) 41 42 type assertionCreatorFromExpr struct { 43 expr hcl.Expression 44 } 45 46 func newAssertionCreatorFromExpr(expr hcl.Expression) assertionCreator { 47 return &assertionCreatorFromExpr{ 48 expr, 49 } 50 } 51 52 func (ac *assertionCreatorFromExpr) createAssertion(ctx *hcl.EvalContext) (string, error) { 53 assertionValue, err := eval.Value(ctx, ac.expr) 54 if err != nil { 55 return "", err 56 } 57 58 if assertionValue.IsNull() { 59 return "", fmt.Errorf("assertion expression evaluates to null") 60 } 61 if assertionValue.Type() != cty.String { 62 return "", fmt.Errorf("assertion expression must evaluate to a string") 63 } 64 65 return assertionValue.AsString(), nil 66 } 67 68 type assertionCreatorFromJSP struct { 69 *lib.JWTSigningConfig 70 headers map[string]interface{} 71 claims map[string]interface{} 72 } 73 74 func newAssertionCreatorFromJSP(evalCtx *hcl.EvalContext, jsp *config.JWTSigningProfile) (assertionCreator, error) { 75 signingConfig, err := lib.NewJWTSigningConfigFromJWTSigningProfile(jsp, nil) 76 if err != nil { 77 return nil, err 78 } 79 80 var headers, claims map[string]interface{} 81 82 if signingConfig.Headers != nil { 83 v, err := eval.Value(evalCtx, signingConfig.Headers) 84 if err != nil { 85 return nil, err 86 } 87 headers = seetie.ValueToMap(v) 88 } 89 90 if signingConfig.Claims != nil { 91 cl, err := eval.Value(evalCtx, signingConfig.Claims) 92 if err != nil { 93 return nil, err 94 } 95 claims = seetie.ValueToMap(cl) 96 } 97 98 return &assertionCreatorFromJSP{ 99 signingConfig, 100 headers, 101 claims, 102 }, nil 103 } 104 105 func (ac *assertionCreatorFromJSP) createAssertion(_ *hcl.EvalContext) (string, error) { 106 claims := make(map[string]interface{}) 107 for k, v := range ac.claims { 108 claims[k] = v 109 } 110 now := time.Now().Unix() 111 claims["exp"] = now + ac.TTL 112 113 return lib.CreateJWT(ac.SignatureAlgorithm, ac.Key, claims, ac.headers) 114 } 115 116 // OAuth2ReqAuth represents the transport <OAuth2ReqAuth> object. 117 type OAuth2ReqAuth struct { 118 config *config.OAuth2ReqAuth 119 mu sync.Mutex 120 memStore *cache.MemoryStore 121 oauth2Client *oauth2.Client 122 storageKey string 123 assertionCreator assertionCreator 124 } 125 126 // NewOAuth2ReqAuth implements the http.RoundTripper interface to wrap an existing Backend / http.RoundTripper 127 // to retrieve a valid token before passing the initial out request. 128 func NewOAuth2ReqAuth(evalCtx *hcl.EvalContext, conf *config.OAuth2ReqAuth, memStore *cache.MemoryStore, 129 asBackend http.RoundTripper) (RequestAuthorizer, error) { 130 131 if _, supported := supportedGrantTypes[conf.GrantType]; !supported { 132 return nil, fmt.Errorf("grant_type %s not supported", conf.GrantType) 133 } 134 135 if conf.GrantType == config.Password { 136 if conf.Username == "" { 137 return nil, fmt.Errorf("username must not be empty with grant_type=password") 138 } 139 if conf.Password == "" { 140 return nil, fmt.Errorf("password must not be empty with grant_type=password") 141 } 142 } else { 143 if conf.Username != "" { 144 return nil, fmt.Errorf("username attribute must not be set with grant_type=%s", conf.GrantType) 145 } 146 if conf.Password != "" { 147 return nil, fmt.Errorf("password attribute must not be set with grant_type=%s", conf.GrantType) 148 } 149 } 150 151 var assertionCreator assertionCreator 152 assertionRange := conf.AssertionExpr.Range() 153 assertionSet := assertionRange.Start != assertionRange.End 154 if conf.GrantType == config.JwtBearer { 155 if !assertionSet && conf.JWTSigningProfile == nil { 156 return nil, fmt.Errorf("missing assertion attribute or jwt_signing_profile block with grant_type=%s", conf.GrantType) 157 } 158 if assertionSet { 159 assertionCreator = newAssertionCreatorFromExpr(conf.AssertionExpr) 160 } else { 161 var err error 162 assertionCreator, err = newAssertionCreatorFromJSP(evalCtx, conf.JWTSigningProfile) 163 if err != nil { 164 return nil, err 165 } 166 } 167 } else { 168 if assertionSet { 169 return nil, fmt.Errorf("assertion attribute must not be set with grant_type=%s", conf.GrantType) 170 } 171 } 172 173 oauth2Client, err := oauth2.NewClient(evalCtx, conf.GrantType, conf, conf, asBackend) 174 if err != nil { 175 return nil, err 176 } 177 178 reqAuth := &OAuth2ReqAuth{ 179 config: conf, 180 oauth2Client: oauth2Client, 181 memStore: memStore, 182 assertionCreator: assertionCreator, 183 } 184 reqAuth.storageKey = fmt.Sprintf("oauth2-%p", reqAuth) 185 return reqAuth, nil 186 } 187 188 func (oa *OAuth2ReqAuth) GetToken(req *http.Request) error { 189 token := oa.readAccessToken() 190 if token != "" { 191 req.Header.Set("Authorization", "Bearer "+token) 192 return nil 193 } 194 195 oa.mu.Lock() 196 defer oa.mu.Unlock() 197 198 token = oa.readAccessToken() 199 if token != "" { 200 req.Header.Set("Authorization", "Bearer "+token) 201 return nil 202 } 203 204 requestError := errors.Request.Label("oauth2") 205 formParams := url.Values{} 206 207 if oa.config.GrantType == config.JwtBearer { 208 requestContext := eval.ContextFromRequest(req).HCLContext() 209 assertion, err := oa.assertionCreator.createAssertion(requestContext) 210 if err != nil { 211 return requestError.With(err) 212 } 213 214 formParams.Set("assertion", assertion) 215 } 216 if oa.config.Scope != "" { 217 formParams.Set("scope", oa.config.Scope) 218 } 219 if oa.config.Password != "" || oa.config.Username != "" { 220 formParams.Set("username", oa.config.Username) 221 formParams.Set("password", oa.config.Password) 222 } 223 224 tokenResponseData, token, err := oa.oauth2Client.GetTokenResponse(req.Context(), formParams) 225 if err != nil { 226 return requestError.Message("token request failed") // don't propagate token request roundtrip error 227 } 228 229 oa.updateAccessToken(token, tokenResponseData) 230 231 req.Header.Set("Authorization", "Bearer "+token) 232 return nil 233 } 234 235 func (oa *OAuth2ReqAuth) RetryWithToken(req *http.Request, res *http.Response) (bool, error) { 236 if res == nil || res.StatusCode != http.StatusUnauthorized { 237 return false, nil 238 } 239 240 oa.memStore.Del(oa.storageKey) 241 242 ctx := req.Context() 243 if retries, ok := ctx.Value(request.TokenRequestRetries).(*uint8); !ok || *retries < *oa.config.Retries { 244 *retries++ // increase ptr value instead of context value 245 req.Header.Del("Authorization") 246 err := oa.GetToken(req) 247 return true, err 248 } 249 return false, nil 250 } 251 252 func (oa *OAuth2ReqAuth) readAccessToken() string { 253 if data := oa.memStore.Get(oa.storageKey); data != nil { 254 return data.(string) 255 } 256 257 return "" 258 } 259 260 func (oa *OAuth2ReqAuth) updateAccessToken(token string, jData map[string]interface{}) { 261 if oa.memStore != nil { 262 var ttl int64 263 if t, ok := jData["expires_in"].(float64); ok { 264 ttl = (int64)(t * 0.9) 265 } 266 267 oa.memStore.Set(oa.storageKey, token, ttl) 268 } 269 } 270 271 func (oa *OAuth2ReqAuth) value() (string, string) { 272 token := oa.readAccessToken() 273 return "oauth2", token 274 }