github.com/avenga/couper@v1.12.2/handler/transport/token_request.go (about)

     1  package transport
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"sync"
     8  
     9  	"github.com/zclconf/go-cty/cty"
    10  
    11  	"github.com/avenga/couper/cache"
    12  	"github.com/avenga/couper/config"
    13  	"github.com/avenga/couper/config/request"
    14  	"github.com/avenga/couper/errors"
    15  	"github.com/avenga/couper/eval"
    16  	"github.com/avenga/couper/eval/buffer"
    17  	"github.com/avenga/couper/handler/producer"
    18  )
    19  
    20  var (
    21  	_ RequestAuthorizer = &TokenRequest{}
    22  )
    23  
    24  type TokenRequest struct {
    25  	config      *config.TokenRequest
    26  	mu          sync.Mutex
    27  	memStore    *cache.MemoryStore
    28  	reqProducer producer.Roundtrip
    29  	storageKey  string
    30  }
    31  
    32  func NewTokenRequest(conf *config.TokenRequest, memStore *cache.MemoryStore, reqProducer producer.Roundtrip) (RequestAuthorizer, error) {
    33  	tr := &TokenRequest{
    34  		config:      conf,
    35  		memStore:    memStore,
    36  		reqProducer: reqProducer,
    37  	}
    38  	tr.storageKey = fmt.Sprintf("TokenRequest-%p", tr)
    39  	return tr, nil
    40  }
    41  
    42  func (t *TokenRequest) GetToken(req *http.Request) error {
    43  	token := t.readToken()
    44  	if token != "" {
    45  		return nil
    46  	}
    47  
    48  	// block during read/request process
    49  	t.mu.Lock()
    50  	defer t.mu.Unlock()
    51  
    52  	token = t.readToken()
    53  	if token != "" {
    54  		return nil
    55  	}
    56  
    57  	var (
    58  		ttl int64
    59  		err error
    60  	)
    61  	token, ttl, err = t.requestToken(req)
    62  	if err != nil {
    63  		return errors.Request.Label(t.config.Name).With(err)
    64  	}
    65  
    66  	t.memStore.Set(t.storageKey, token, ttl)
    67  	return nil
    68  }
    69  
    70  func (t *TokenRequest) RetryWithToken(_ *http.Request, _ *http.Response) (bool, error) {
    71  	return false, nil
    72  }
    73  
    74  func (t *TokenRequest) readToken() string {
    75  	if data := t.memStore.Get(t.storageKey); data != nil {
    76  		return data.(string)
    77  	}
    78  
    79  	return ""
    80  }
    81  
    82  func (t *TokenRequest) requestToken(req *http.Request) (string, int64, error) {
    83  	ctx := context.WithValue(req.Context(), request.Wildcard, nil)       // disable handling this
    84  	ctx = context.WithValue(ctx, request.BufferOptions, buffer.Response) // always read out a possible token
    85  	ctx = context.WithValue(ctx, request.TokenRequest, t.config.Name)    // set the name for variable mapping purposes
    86  	outreq, _ := http.NewRequestWithContext(ctx, req.Method, "", nil)
    87  	result := t.reqProducer.Produce(outreq)
    88  	if result.Err != nil {
    89  		return "", 0, fmt.Errorf("token request failed") // don't propagate token request roundtrip error
    90  	}
    91  
    92  	// obtain synced and already read beresp value; map to context variables
    93  	hclCtx := eval.ContextFromRequest(req).HCLContextSync()
    94  	eval.MapTokenResponse(hclCtx, t.config.Name)
    95  
    96  	tokenRequestBody := t.config.HCLBody()
    97  	tokenVal, err := eval.ValueFromBodyAttribute(hclCtx, tokenRequestBody, "token")
    98  	if err != nil {
    99  		return "", 0, err
   100  	}
   101  	if tokenVal.IsNull() {
   102  		return "", 0, fmt.Errorf("token expression evaluates to null")
   103  	}
   104  	if tokenVal.Type() != cty.String {
   105  		return "", 0, fmt.Errorf("token expression must evaluate to a string")
   106  	}
   107  
   108  	ttlVal, err := eval.ValueFromBodyAttribute(hclCtx, tokenRequestBody, "ttl")
   109  	if err != nil {
   110  		return "", 0, err
   111  	}
   112  	if ttlVal.IsNull() {
   113  		return "", 0, fmt.Errorf("ttl expression evaluates to null")
   114  	}
   115  	if ttlVal.Type() != cty.String {
   116  		return "", 0, fmt.Errorf("ttl expression must evaluate to a string")
   117  	}
   118  
   119  	token := tokenVal.AsString()
   120  	ttl := ttlVal.AsString()
   121  	dur, parseErr := config.ParseDuration("ttl", ttl, 0)
   122  	if parseErr != nil {
   123  		return "", 0, parseErr
   124  	}
   125  
   126  	return token, int64(dur.Seconds()), nil
   127  }
   128  
   129  func (t *TokenRequest) value() (string, string) {
   130  	token := t.readToken()
   131  	return t.config.Name, token
   132  }