github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiclient/auth_jwt.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httputil"
    10  	"net/url"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/go-openapi/strfmt"
    15  	log "github.com/sirupsen/logrus"
    16  
    17  	"github.com/crowdsecurity/crowdsec/pkg/models"
    18  )
    19  
    20  type JWTTransport struct {
    21  	MachineID     *string
    22  	Password      *strfmt.Password
    23  	Token         string
    24  	Expiration    time.Time
    25  	Scenarios     []string
    26  	URL           *url.URL
    27  	VersionPrefix string
    28  	UserAgent     string
    29  	// Transport is the underlying HTTP transport to use when making requests.
    30  	// It will default to http.DefaultTransport if nil.
    31  	Transport         http.RoundTripper
    32  	UpdateScenario    func() ([]string, error)
    33  	refreshTokenMutex sync.Mutex
    34  }
    35  
    36  func (t *JWTTransport) refreshJwtToken() error {
    37  	var err error
    38  
    39  	if t.UpdateScenario != nil {
    40  		t.Scenarios, err = t.UpdateScenario()
    41  		if err != nil {
    42  			return fmt.Errorf("can't update scenario list: %w", err)
    43  		}
    44  
    45  		log.Debugf("scenarios list updated for '%s'", *t.MachineID)
    46  	}
    47  
    48  	auth := models.WatcherAuthRequest{
    49  		MachineID: t.MachineID,
    50  		Password:  t.Password,
    51  		Scenarios: t.Scenarios,
    52  	}
    53  
    54  	/*
    55  		we don't use the main client, so let's build the body
    56  	*/
    57  	var buf io.ReadWriter = &bytes.Buffer{}
    58  	enc := json.NewEncoder(buf)
    59  	enc.SetEscapeHTML(false)
    60  	err = enc.Encode(auth)
    61  
    62  	if err != nil {
    63  		return fmt.Errorf("could not encode jwt auth body: %w", err)
    64  	}
    65  
    66  	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
    67  	if err != nil {
    68  		return fmt.Errorf("could not create request: %w", err)
    69  	}
    70  
    71  	req.Header.Add("Content-Type", "application/json")
    72  
    73  	transport := t.Transport
    74  	if transport == nil {
    75  		transport = http.DefaultTransport
    76  	}
    77  
    78  	client := &http.Client{
    79  		Transport: &retryRoundTripper{
    80  			next:             transport,
    81  			maxAttempts:      5,
    82  			withBackOff:      true,
    83  			retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
    84  		},
    85  	}
    86  
    87  	if t.UserAgent != "" {
    88  		req.Header.Add("User-Agent", t.UserAgent)
    89  	}
    90  
    91  	if log.GetLevel() >= log.TraceLevel {
    92  		dump, _ := httputil.DumpRequest(req, true)
    93  		log.Tracef("auth-jwt request: %s", string(dump))
    94  	}
    95  
    96  	log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String())
    97  
    98  	resp, err := client.Do(req)
    99  	if err != nil {
   100  		return fmt.Errorf("could not get jwt token: %w", err)
   101  	}
   102  
   103  	log.Debugf("auth-jwt : http %d", resp.StatusCode)
   104  
   105  	if log.GetLevel() >= log.TraceLevel {
   106  		dump, _ := httputil.DumpResponse(resp, true)
   107  		log.Tracef("auth-jwt response: %s", string(dump))
   108  	}
   109  
   110  	defer resp.Body.Close()
   111  
   112  	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   113  		log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
   114  
   115  		err = CheckResponse(resp)
   116  		if err != nil {
   117  			return err
   118  		}
   119  	}
   120  
   121  	var response models.WatcherAuthResponse
   122  
   123  	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
   124  		return fmt.Errorf("unable to decode response: %w", err)
   125  	}
   126  
   127  	if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
   128  		return fmt.Errorf("unable to parse jwt expiration: %w", err)
   129  	}
   130  
   131  	t.Token = response.Token
   132  
   133  	log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
   134  
   135  	return nil
   136  }
   137  
   138  func (t *JWTTransport) needsTokenRefresh() bool {
   139  	return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())
   140  }
   141  
   142  // prepareRequest returns a copy of the  request with the necessary authentication headers.
   143  func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) {
   144  	// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless
   145  	// and will cause overload on CAPI. We use a mutex to avoid this.
   146  	t.refreshTokenMutex.Lock()
   147  	defer t.refreshTokenMutex.Unlock()
   148  
   149  	// We bypass the refresh if we are requesting the login endpoint, as it does not require a token,
   150  	// and it leads to do 2 requests instead of one (refresh + actual login request).
   151  	if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() {
   152  		if err := t.refreshJwtToken(); err != nil {
   153  			return nil, err
   154  		}
   155  	}
   156  
   157  	if t.UserAgent != "" {
   158  		req.Header.Add("User-Agent", t.UserAgent)
   159  	}
   160  
   161  	req.Header.Add("Authorization", "Bearer "+t.Token)
   162  
   163  	return req, nil
   164  }
   165  
   166  // RoundTrip implements the RoundTripper interface.
   167  func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   168  	req, err := t.prepareRequest(req)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	if log.GetLevel() >= log.TraceLevel {
   174  		// requestToDump := cloneRequest(req)
   175  		dump, _ := httputil.DumpRequest(req, true)
   176  		log.Tracef("req-jwt: %s", string(dump))
   177  	}
   178  
   179  	// Make the HTTP request.
   180  	resp, err := t.transport().RoundTrip(req)
   181  	if log.GetLevel() >= log.TraceLevel {
   182  		dump, _ := httputil.DumpResponse(resp, true)
   183  		log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
   184  	}
   185  
   186  	if err != nil {
   187  		// we had an error (network error for example, or 401 because token is refused), reset the token?
   188  		t.ResetToken()
   189  
   190  		return resp, fmt.Errorf("performing jwt auth: %w", err)
   191  	}
   192  
   193  	if resp != nil {
   194  		log.Debugf("resp-jwt: %d", resp.StatusCode)
   195  	}
   196  
   197  	return resp, nil
   198  }
   199  
   200  func (t *JWTTransport) Client() *http.Client {
   201  	return &http.Client{Transport: t}
   202  }
   203  
   204  func (t *JWTTransport) ResetToken() {
   205  	log.Debug("resetting jwt token")
   206  	t.refreshTokenMutex.Lock()
   207  	t.Token = ""
   208  	t.refreshTokenMutex.Unlock()
   209  }
   210  
   211  // transport() returns a round tripper that retries once when the status is unauthorized,
   212  // and 5 times when the infrastructure is overloaded.
   213  func (t *JWTTransport) transport() http.RoundTripper {
   214  	transport := t.Transport
   215  	if transport == nil {
   216  		transport = http.DefaultTransport
   217  	}
   218  
   219  	return &retryRoundTripper{
   220  		next: &retryRoundTripper{
   221  			next:             transport,
   222  			maxAttempts:      5,
   223  			withBackOff:      true,
   224  			retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
   225  		},
   226  		maxAttempts:      2,
   227  		withBackOff:      false,
   228  		retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
   229  		onBeforeRequest: func(attempt int) {
   230  			// reset the token only in the second attempt as this is when we know we had a 401 or 403
   231  			// the second attempt is supposed to refresh the token
   232  			if attempt > 0 {
   233  				t.ResetToken()
   234  			}
   235  		},
   236  	}
   237  }