github.com/crowdsecurity/crowdsec@v1.6.1/pkg/longpollclient/client.go (about)

     1  package longpollclient
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"time"
    11  
    12  	"github.com/gofrs/uuid"
    13  	log "github.com/sirupsen/logrus"
    14  	"gopkg.in/tomb.v2"
    15  )
    16  
    17  type LongPollClient struct {
    18  	t          tomb.Tomb
    19  	c          chan Event
    20  	url        url.URL
    21  	logger     *log.Entry
    22  	since      int64
    23  	timeout    string
    24  	httpClient *http.Client
    25  }
    26  
    27  type LongPollClientConfig struct {
    28  	Url        url.URL
    29  	Logger     *log.Logger
    30  	HttpClient *http.Client
    31  }
    32  
    33  type Event struct {
    34  	Timestamp int64     `json:"timestamp"`
    35  	Category  string    `json:"category"`
    36  	Data      string    `json:"data"`
    37  	ID        uuid.UUID `json:"id"`
    38  	RequestId string
    39  }
    40  
    41  type pollResponse struct {
    42  	Events []Event `json:"events"`
    43  	// Set for timeout responses
    44  	Timestamp int64 `json:"timestamp"`
    45  	// API error responses could have an informative error here. Empty on success.
    46  	ErrorMessage string `json:"error"`
    47  }
    48  
    // errUnauthorized is returned when the API answers 402 Payment Required,
// meaning the caller is not allowed to use PAPI. It is a sentinel value:
// compare against it with errors.Is.
var errUnauthorized = errors.New("user is not authorized to use PAPI")

// timeoutMessage is the error text the API sends when a long-poll request
// expires without events; it marks a normal end of a poll, not a failure.
const timeoutMessage = "no events before timeout"
    52  
    53  func (c *LongPollClient) doQuery() (*http.Response, error) {
    54  	logger := c.logger.WithField("method", "doQuery")
    55  	query := c.url.Query()
    56  	query.Set("since_time", fmt.Sprintf("%d", c.since))
    57  	query.Set("timeout", c.timeout)
    58  	c.url.RawQuery = query.Encode()
    59  
    60  	logger.Debugf("Query parameters: %s", c.url.RawQuery)
    61  
    62  	req, err := http.NewRequest(http.MethodGet, c.url.String(), nil)
    63  	if err != nil {
    64  		logger.Errorf("failed to create request: %s", err)
    65  		return nil, err
    66  	}
    67  	req.Header.Set("Accept", "application/json")
    68  	resp, err := c.httpClient.Do(req)
    69  	if err != nil {
    70  		logger.Errorf("failed to execute request: %s", err)
    71  		return nil, err
    72  	}
    73  	return resp, nil
    74  }
    75  
    76  func (c *LongPollClient) poll() error {
    77  
    78  	logger := c.logger.WithField("method", "poll")
    79  
    80  	resp, err := c.doQuery()
    81  
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	defer resp.Body.Close()
    87  
    88  	requestId := resp.Header.Get("X-Amzn-Trace-Id")
    89  	logger = logger.WithField("request-id", requestId)
    90  	if resp.StatusCode != http.StatusOK {
    91  		c.logger.Errorf("unexpected status code: %d", resp.StatusCode)
    92  		if resp.StatusCode == http.StatusPaymentRequired {
    93  			bodyContent, err := io.ReadAll(resp.Body)
    94  			if err != nil {
    95  				logger.Errorf("failed to read response body: %s", err)
    96  				return err
    97  			}
    98  			logger.Errorf(string(bodyContent))
    99  			return errUnauthorized
   100  		}
   101  		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
   102  	}
   103  
   104  	decoder := json.NewDecoder(resp.Body)
   105  
   106  	for {
   107  		select {
   108  		case <-c.t.Dying():
   109  			logger.Debugf("dying")
   110  			close(c.c)
   111  			return nil
   112  		default:
   113  			var pollResp pollResponse
   114  			err = decoder.Decode(&pollResp)
   115  			if err != nil {
   116  				if errors.Is(err, io.EOF) {
   117  					logger.Debugf("server closed connection")
   118  					return nil
   119  				}
   120  				return fmt.Errorf("error decoding poll response: %v", err)
   121  			}
   122  
   123  			logger.Tracef("got response: %+v", pollResp)
   124  
   125  			if len(pollResp.ErrorMessage) > 0 {
   126  				if pollResp.ErrorMessage == timeoutMessage {
   127  					logger.Debugf("got timeout message")
   128  					return nil
   129  				}
   130  				return fmt.Errorf("longpoll API error message: %s", pollResp.ErrorMessage)
   131  			}
   132  
   133  			if len(pollResp.Events) > 0 {
   134  				logger.Debugf("got %d events", len(pollResp.Events))
   135  				for _, event := range pollResp.Events {
   136  					event.RequestId = requestId
   137  					c.c <- event
   138  					if event.Timestamp > c.since {
   139  						c.since = event.Timestamp
   140  					}
   141  				}
   142  			}
   143  			if pollResp.Timestamp > 0 {
   144  				c.since = pollResp.Timestamp
   145  			}
   146  			logger.Debugf("Since is now %d", c.since)
   147  		}
   148  	}
   149  }
   150  
   151  func (c *LongPollClient) pollEvents() error {
   152  	for {
   153  		select {
   154  		case <-c.t.Dying():
   155  			c.logger.Debug("dying")
   156  			return nil
   157  		default:
   158  			c.logger.Debug("Polling PAPI")
   159  			err := c.poll()
   160  			if err != nil {
   161  				c.logger.Errorf("failed to poll: %s", err)
   162  				if errors.Is(err, errUnauthorized) {
   163  					c.t.Kill(err)
   164  					close(c.c)
   165  					return err
   166  				}
   167  				continue
   168  			}
   169  		}
   170  	}
   171  }
   172  
   173  func (c *LongPollClient) Start(since time.Time) chan Event {
   174  	c.logger.Infof("starting polling client")
   175  	c.c = make(chan Event)
   176  	c.since = since.Unix() * 1000
   177  	c.timeout = "45"
   178  	c.t.Go(c.pollEvents)
   179  	return c.c
   180  }
   181  
   182  func (c *LongPollClient) Stop() error {
   183  	c.t.Kill(nil)
   184  	return nil
   185  }
   186  
   187  func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) {
   188  	c.logger.Debug("Pulling PAPI once")
   189  	c.since = since.Unix() * 1000
   190  	c.timeout = "1"
   191  	resp, err := c.doQuery()
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	defer resp.Body.Close()
   196  	decoder := json.NewDecoder(resp.Body)
   197  	evts := []Event{}
   198  	for {
   199  		var pollResp pollResponse
   200  		err = decoder.Decode(&pollResp)
   201  		if err != nil {
   202  			if errors.Is(err, io.EOF) {
   203  				c.logger.Debugf("server closed connection")
   204  				break
   205  			}
   206  			log.Errorf("error decoding poll response: %v", err)
   207  			break
   208  		}
   209  
   210  		c.logger.Tracef("got response: %+v", pollResp)
   211  
   212  		if len(pollResp.ErrorMessage) > 0 {
   213  			if pollResp.ErrorMessage == timeoutMessage {
   214  				c.logger.Debugf("got timeout message")
   215  				break
   216  			}
   217  			log.Errorf("longpoll API error message: %s", pollResp.ErrorMessage)
   218  			break
   219  		}
   220  		evts = append(evts, pollResp.Events...)
   221  	}
   222  	return evts, nil
   223  }
   224  
   225  func NewLongPollClient(config LongPollClientConfig) (*LongPollClient, error) {
   226  	var logger *log.Entry
   227  	if config.Url == (url.URL{}) {
   228  		return nil, fmt.Errorf("url is required")
   229  	}
   230  	if config.Logger == nil {
   231  		logger = log.WithField("component", "longpollclient")
   232  	} else {
   233  		logger = config.Logger.WithFields(log.Fields{
   234  			"component": "longpollclient",
   235  			"url":       config.Url.String(),
   236  		})
   237  	}
   238  
   239  	return &LongPollClient{
   240  		url:        config.Url,
   241  		logger:     logger,
   242  		httpClient: config.HttpClient,
   243  	}, nil
   244  }