github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiclient/auth_jwt.go (about) 1 package apiclient 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "net/http/httputil" 10 "net/url" 11 "sync" 12 "time" 13 14 "github.com/go-openapi/strfmt" 15 log "github.com/sirupsen/logrus" 16 17 "github.com/crowdsecurity/crowdsec/pkg/models" 18 ) 19 20 type JWTTransport struct { 21 MachineID *string 22 Password *strfmt.Password 23 Token string 24 Expiration time.Time 25 Scenarios []string 26 URL *url.URL 27 VersionPrefix string 28 UserAgent string 29 // Transport is the underlying HTTP transport to use when making requests. 30 // It will default to http.DefaultTransport if nil. 31 Transport http.RoundTripper 32 UpdateScenario func() ([]string, error) 33 refreshTokenMutex sync.Mutex 34 } 35 36 func (t *JWTTransport) refreshJwtToken() error { 37 var err error 38 39 if t.UpdateScenario != nil { 40 t.Scenarios, err = t.UpdateScenario() 41 if err != nil { 42 return fmt.Errorf("can't update scenario list: %w", err) 43 } 44 45 log.Debugf("scenarios list updated for '%s'", *t.MachineID) 46 } 47 48 auth := models.WatcherAuthRequest{ 49 MachineID: t.MachineID, 50 Password: t.Password, 51 Scenarios: t.Scenarios, 52 } 53 54 /* 55 we don't use the main client, so let's build the body 56 */ 57 var buf io.ReadWriter = &bytes.Buffer{} 58 enc := json.NewEncoder(buf) 59 enc.SetEscapeHTML(false) 60 err = enc.Encode(auth) 61 62 if err != nil { 63 return fmt.Errorf("could not encode jwt auth body: %w", err) 64 } 65 66 req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf) 67 if err != nil { 68 return fmt.Errorf("could not create request: %w", err) 69 } 70 71 req.Header.Add("Content-Type", "application/json") 72 73 transport := t.Transport 74 if transport == nil { 75 transport = http.DefaultTransport 76 } 77 78 client := &http.Client{ 79 Transport: &retryRoundTripper{ 80 next: transport, 81 maxAttempts: 5, 82 withBackOff: true, 83 retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, 84 }, 85 } 86 87 if t.UserAgent != "" { 88 req.Header.Add("User-Agent", t.UserAgent) 89 } 90 91 if log.GetLevel() >= log.TraceLevel { 92 dump, _ := httputil.DumpRequest(req, true) 93 log.Tracef("auth-jwt request: %s", string(dump)) 94 } 95 96 log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String()) 97 98 resp, err := client.Do(req) 99 if err != nil { 100 return fmt.Errorf("could not get jwt token: %w", err) 101 } 102 103 log.Debugf("auth-jwt : http %d", resp.StatusCode) 104 105 if log.GetLevel() >= log.TraceLevel { 106 dump, _ := httputil.DumpResponse(resp, true) 107 log.Tracef("auth-jwt response: %s", string(dump)) 108 } 109 110 defer resp.Body.Close() 111 112 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 113 log.Debugf("received response status %q when fetching %v", resp.Status, req.URL) 114 115 err = CheckResponse(resp) 116 if err != nil { 117 return err 118 } 119 } 120 121 var response models.WatcherAuthResponse 122 123 if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { 124 return fmt.Errorf("unable to decode response: %w", err) 125 } 126 127 if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { 128 return fmt.Errorf("unable to parse jwt expiration: %w", err) 129 } 130 131 t.Token = response.Token 132 133 log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) 134 135 return nil 136 } 137 138 func (t *JWTTransport) needsTokenRefresh() bool { 139 return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) 140 } 141 142 // prepareRequest returns a copy of the request with the necessary authentication headers. 143 func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { 144 // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless 145 // and will cause overload on CAPI. We use a mutex to avoid this. 146 t.refreshTokenMutex.Lock() 147 defer t.refreshTokenMutex.Unlock() 148 149 // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, 150 // and it leads to do 2 requests instead of one (refresh + actual login request). 151 if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { 152 if err := t.refreshJwtToken(); err != nil { 153 return nil, err 154 } 155 } 156 157 if t.UserAgent != "" { 158 req.Header.Add("User-Agent", t.UserAgent) 159 } 160 161 req.Header.Add("Authorization", "Bearer "+t.Token) 162 163 return req, nil 164 } 165 166 // RoundTrip implements the RoundTripper interface. 167 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { 168 req, err := t.prepareRequest(req) 169 if err != nil { 170 return nil, err 171 } 172 173 if log.GetLevel() >= log.TraceLevel { 174 // requestToDump := cloneRequest(req) 175 dump, _ := httputil.DumpRequest(req, true) 176 log.Tracef("req-jwt: %s", string(dump)) 177 } 178 179 // Make the HTTP request. 180 resp, err := t.transport().RoundTrip(req) 181 if log.GetLevel() >= log.TraceLevel { 182 dump, _ := httputil.DumpResponse(resp, true) 183 log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) 184 } 185 186 if err != nil { 187 // we had an error (network error for example, or 401 because token is refused), reset the token? 188 t.ResetToken() 189 190 return resp, fmt.Errorf("performing jwt auth: %w", err) 191 } 192 193 if resp != nil { 194 log.Debugf("resp-jwt: %d", resp.StatusCode) 195 } 196 197 return resp, nil 198 } 199 200 func (t *JWTTransport) Client() *http.Client { 201 return &http.Client{Transport: t} 202 } 203 204 func (t *JWTTransport) ResetToken() { 205 log.Debug("resetting jwt token") 206 t.refreshTokenMutex.Lock() 207 t.Token = "" 208 t.refreshTokenMutex.Unlock() 209 } 210 211 // transport() returns a round tripper that retries once when the status is unauthorized, 212 // and 5 times when the infrastructure is overloaded. 213 func (t *JWTTransport) transport() http.RoundTripper { 214 transport := t.Transport 215 if transport == nil { 216 transport = http.DefaultTransport 217 } 218 219 return &retryRoundTripper{ 220 next: &retryRoundTripper{ 221 next: transport, 222 maxAttempts: 5, 223 withBackOff: true, 224 retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout}, 225 }, 226 maxAttempts: 2, 227 withBackOff: false, 228 retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden}, 229 onBeforeRequest: func(attempt int) { 230 // reset the token only in the second attempt as this is when we know we had a 401 or 403 231 // the second attempt is supposed to refresh the token 232 if attempt > 0 { 233 t.ResetToken() 234 } 235 }, 236 } 237 }