github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiserver/apic.go (about)

     1  package apiserver
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"slices"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/go-openapi/strfmt"
    18  	log "github.com/sirupsen/logrus"
    19  	"gopkg.in/tomb.v2"
    20  
    21  	"github.com/crowdsecurity/go-cs-lib/ptr"
    22  	"github.com/crowdsecurity/go-cs-lib/trace"
    23  	"github.com/crowdsecurity/go-cs-lib/version"
    24  
    25  	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
    26  	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
    27  	"github.com/crowdsecurity/crowdsec/pkg/database"
    28  	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
    29  	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
    30  	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
    31  	"github.com/crowdsecurity/crowdsec/pkg/models"
    32  	"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
    33  	"github.com/crowdsecurity/crowdsec/pkg/types"
    34  )
    35  
    36  const (
    37  	// delta values must be smaller than the interval
    38  	pullIntervalDefault    = time.Hour * 2
    39  	pullIntervalDelta      = 5 * time.Minute
    40  	pushIntervalDefault    = time.Second * 10
    41  	pushIntervalDelta      = time.Second * 7
    42  	metricsIntervalDefault = time.Minute * 30
    43  	metricsIntervalDelta   = time.Minute * 15
    44  )
    45  
// apic holds the state shared by the CAPI (Central API) background routines
// (signal push, community-blocklist pull, metrics). Instances are built by
// NewAPIC.
type apic struct {
	// when changing the intervals in tests, always set *First too
	// or they can be negative
	//
	// NewAPIC sets each *Interval to its default and each *IntervalFirst to
	// randomDuration(default, delta). Push ticks first after pushIntervalFirst
	// and then every pushInterval; the pull and metrics loops are outside this
	// view and presumably follow the same pattern (TODO confirm).
	pullInterval         time.Duration
	pullIntervalFirst    time.Duration
	pushInterval         time.Duration
	pushIntervalFirst    time.Duration
	metricsInterval      time.Duration
	metricsIntervalFirst time.Duration
	dbClient             *database.Client
	apiClient            *apiclient.ApiClient
	// AlertsAddChan carries batches of alerts that Push turns into signals.
	// NewAPIC creates it unbuffered; the producers are not visible here.
	AlertsAddChan        chan []*models.Alert

	// mu is taken by Push around every update of its signal cache.
	mu            sync.Mutex
	// When pushTomb is dying, Push also kills pullTomb and metricsTomb.
	pushTomb      tomb.Tomb
	pullTomb      tomb.Tomb
	metricsTomb   tomb.Tomb
	// startup is sent as the Startup option of GetStreamV3 and cleared by
	// PullTop after a successful stream fetch.
	startup       bool
	// credentials are the CAPI credentials copied from the online client config.
	credentials   *csconfig.ApiCredentialsCfg
	// scenarioList is the de-duplicated scenario list loaded from the DB by NewAPIC.
	scenarioList  []string
	// consoleConfig holds the sharing options read by shouldShareAlert and Push.
	consoleConfig *csconfig.ConsoleConfig
	// isPulling is a 1-slot channel that PullTop uses as a non-blocking
	// "pull in progress" flag.
	isPulling     chan bool
	// whitelists lists the IPs/CIDRs whose CAPI decisions are dropped by
	// ApplyApicWhitelists (nil means no whitelist).
	whitelists    *csconfig.CapiWhitelist
}
    70  
    71  // randomDuration returns a duration value between d-delta and d+delta
    72  func randomDuration(d time.Duration, delta time.Duration) time.Duration {
    73  	ret := d + time.Duration(rand.Int63n(int64(2*delta))) - delta
    74  	// ticker interval must be > 0 (nanoseconds)
    75  	if ret <= 0 {
    76  		return 1
    77  	}
    78  
    79  	return ret
    80  }
    81  
    82  func (a *apic) FetchScenariosListFromDB() ([]string, error) {
    83  	scenarios := make([]string, 0)
    84  	machines, err := a.dbClient.ListMachines()
    85  
    86  	if err != nil {
    87  		return nil, fmt.Errorf("while listing machines: %w", err)
    88  	}
    89  	//merge all scenarios together
    90  	for _, v := range machines {
    91  		machineScenarios := strings.Split(v.Scenarios, ",")
    92  		log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID)
    93  
    94  		for _, sv := range machineScenarios {
    95  			if !slices.Contains(scenarios, sv) && sv != "" {
    96  				scenarios = append(scenarios, sv)
    97  			}
    98  		}
    99  	}
   100  
   101  	log.Debugf("Returning list of scenarios : %+v", scenarios)
   102  
   103  	return scenarios, nil
   104  }
   105  
   106  func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions {
   107  	apiDecisions := models.AddSignalsRequestItemDecisions{}
   108  
   109  	for _, decision := range decisions {
   110  		x := &models.AddSignalsRequestItemDecisionsItem{
   111  			Duration: ptr.Of(*decision.Duration),
   112  			ID:       new(int64),
   113  			Origin:   ptr.Of(*decision.Origin),
   114  			Scenario: ptr.Of(*decision.Scenario),
   115  			Scope:    ptr.Of(*decision.Scope),
   116  			//Simulated: *decision.Simulated,
   117  			Type:  ptr.Of(*decision.Type),
   118  			Until: decision.Until,
   119  			Value: ptr.Of(*decision.Value),
   120  			UUID:  decision.UUID,
   121  		}
   122  		*x.ID = decision.ID
   123  
   124  		if decision.Simulated != nil {
   125  			x.Simulated = *decision.Simulated
   126  		}
   127  
   128  		apiDecisions = append(apiDecisions, x)
   129  	}
   130  
   131  	return apiDecisions
   132  }
   133  
   134  func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) *models.AddSignalsRequestItem {
   135  	signal := &models.AddSignalsRequestItem{
   136  		Message:         alert.Message,
   137  		Scenario:        alert.Scenario,
   138  		ScenarioHash:    alert.ScenarioHash,
   139  		ScenarioVersion: alert.ScenarioVersion,
   140  		Source: &models.AddSignalsRequestItemSource{
   141  			AsName:    alert.Source.AsName,
   142  			AsNumber:  alert.Source.AsNumber,
   143  			Cn:        alert.Source.Cn,
   144  			IP:        alert.Source.IP,
   145  			Latitude:  alert.Source.Latitude,
   146  			Longitude: alert.Source.Longitude,
   147  			Range:     alert.Source.Range,
   148  			Scope:     alert.Source.Scope,
   149  			Value:     alert.Source.Value,
   150  		},
   151  		StartAt:       alert.StartAt,
   152  		StopAt:        alert.StopAt,
   153  		CreatedAt:     alert.CreatedAt,
   154  		MachineID:     alert.MachineID,
   155  		ScenarioTrust: scenarioTrust,
   156  		Decisions:     decisionsToApiDecisions(alert.Decisions),
   157  		UUID:          alert.UUID,
   158  	}
   159  	if shareContext {
   160  		signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0)
   161  
   162  		for _, meta := range alert.Meta {
   163  			contextItem := models.AddSignalsRequestItemContextItems0{
   164  				Key:   meta.Key,
   165  				Value: meta.Value,
   166  			}
   167  			signal.Context = append(signal.Context, &contextItem)
   168  		}
   169  	}
   170  
   171  	return signal
   172  }
   173  
   174  func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) {
   175  	var err error
   176  
   177  	ret := &apic{
   178  		AlertsAddChan:        make(chan []*models.Alert),
   179  		dbClient:             dbClient,
   180  		mu:                   sync.Mutex{},
   181  		startup:              true,
   182  		credentials:          config.Credentials,
   183  		pullTomb:             tomb.Tomb{},
   184  		pushTomb:             tomb.Tomb{},
   185  		metricsTomb:          tomb.Tomb{},
   186  		scenarioList:         make([]string, 0),
   187  		consoleConfig:        consoleConfig,
   188  		pullInterval:         pullIntervalDefault,
   189  		pullIntervalFirst:    randomDuration(pullIntervalDefault, pullIntervalDelta),
   190  		pushInterval:         pushIntervalDefault,
   191  		pushIntervalFirst:    randomDuration(pushIntervalDefault, pushIntervalDelta),
   192  		metricsInterval:      metricsIntervalDefault,
   193  		metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta),
   194  		isPulling:            make(chan bool, 1),
   195  		whitelists:           apicWhitelist,
   196  	}
   197  
   198  	password := strfmt.Password(config.Credentials.Password)
   199  	apiURL, err := url.Parse(config.Credentials.URL)
   200  
   201  	if err != nil {
   202  		return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err)
   203  	}
   204  
   205  	papiURL, err := url.Parse(config.Credentials.PapiURL)
   206  	if err != nil {
   207  		return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
   208  	}
   209  
   210  	ret.scenarioList, err = ret.FetchScenariosListFromDB()
   211  	if err != nil {
   212  		return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
   213  	}
   214  
   215  	ret.apiClient, err = apiclient.NewClient(&apiclient.Config{
   216  		MachineID:      config.Credentials.Login,
   217  		Password:       password,
   218  		UserAgent:      fmt.Sprintf("crowdsec/%s", version.String()),
   219  		URL:            apiURL,
   220  		PapiURL:        papiURL,
   221  		VersionPrefix:  "v3",
   222  		Scenarios:      ret.scenarioList,
   223  		UpdateScenario: ret.FetchScenariosListFromDB,
   224  	})
   225  	if err != nil {
   226  		return nil, fmt.Errorf("while creating api client: %w", err)
   227  	}
   228  
   229  	// The watcher will be authenticated by the RoundTripper the first time it will call CAPI
   230  	// Explicit authentication will provoke a useless supplementary call to CAPI
   231  	scenarios, err := ret.FetchScenariosListFromDB()
   232  	if err != nil {
   233  		return ret, fmt.Errorf("get scenario in db: %w", err)
   234  	}
   235  
   236  	authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
   237  		MachineID: &config.Credentials.Login,
   238  		Password:  &password,
   239  		Scenarios: scenarios,
   240  	})
   241  	if err != nil {
   242  		return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err)
   243  	}
   244  
   245  	if err = ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
   246  		return ret, fmt.Errorf("unable to parse jwt expiration: %w", err)
   247  	}
   248  
   249  	ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
   250  
   251  	return ret, err
   252  }
   253  
   254  // keep track of all alerts in cache and push it to CAPI every PushInterval.
   255  func (a *apic) Push() error {
   256  	defer trace.CatchPanic("lapi/pushToAPIC")
   257  
   258  	var cache models.AddSignalsRequest
   259  
   260  	ticker := time.NewTicker(a.pushIntervalFirst)
   261  
   262  	log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval)
   263  
   264  	for {
   265  		select {
   266  		case <-a.pushTomb.Dying(): // if one apic routine is dying, do we kill the others?
   267  			a.pullTomb.Kill(nil)
   268  			a.metricsTomb.Kill(nil)
   269  			log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache))
   270  
   271  			if len(cache) == 0 {
   272  				return nil
   273  			}
   274  
   275  			go a.Send(&cache)
   276  
   277  			return nil
   278  		case <-ticker.C:
   279  			ticker.Reset(a.pushInterval)
   280  
   281  			if len(cache) > 0 {
   282  				a.mu.Lock()
   283  				cacheCopy := cache
   284  				cache = make(models.AddSignalsRequest, 0)
   285  				a.mu.Unlock()
   286  				log.Infof("Signal push: %d signals to push", len(cacheCopy))
   287  
   288  				go a.Send(&cacheCopy)
   289  			}
   290  		case alerts := <-a.AlertsAddChan:
   291  			var signals []*models.AddSignalsRequestItem
   292  
   293  			for _, alert := range alerts {
   294  				if ok := shouldShareAlert(alert, a.consoleConfig); ok {
   295  					signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext))
   296  				}
   297  			}
   298  
   299  			a.mu.Lock()
   300  			cache = append(cache, signals...)
   301  			a.mu.Unlock()
   302  		}
   303  	}
   304  }
   305  
   306  func getScenarioTrustOfAlert(alert *models.Alert) string {
   307  	scenarioTrust := "certified"
   308  	if alert.ScenarioHash == nil || *alert.ScenarioHash == "" {
   309  		scenarioTrust = "custom"
   310  	} else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" {
   311  		scenarioTrust = "tainted"
   312  	}
   313  
   314  	if len(alert.Decisions) > 0 {
   315  		if *alert.Decisions[0].Origin == types.CscliOrigin {
   316  			scenarioTrust = "manual"
   317  		}
   318  	}
   319  
   320  	return scenarioTrust
   321  }
   322  
   323  func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig) bool {
   324  	if *alert.Simulated {
   325  		log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID)
   326  		return false
   327  	}
   328  
   329  	switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust {
   330  	case "manual":
   331  		if !*consoleConfig.ShareManualDecisions {
   332  			log.Debugf("manual decision generated an alert, doesn't send it to CAPI because options is disabled")
   333  			return false
   334  		}
   335  	case "tainted":
   336  		if !*consoleConfig.ShareTaintedScenarios {
   337  			log.Debugf("tainted scenario generated an alert, doesn't send it to CAPI because options is disabled")
   338  			return false
   339  		}
   340  	case "custom":
   341  		if !*consoleConfig.ShareCustomScenarios {
   342  			log.Debugf("custom scenario generated an alert, doesn't send it to CAPI because options is disabled")
   343  			return false
   344  		}
   345  	}
   346  
   347  	return true
   348  }
   349  
   350  func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
   351  	/*we do have a problem with this :
   352  	The apic.Push background routine reads from alertToPush chan.
   353  	This chan is filled by Controller.CreateAlert
   354  
   355  	If the chan apic.Send hangs, the alertToPush chan will become full,
   356  	with means that Controller.CreateAlert is going to hang, blocking API worker(s).
   357  
   358  	So instead, we prefer to cancel write.
   359  
   360  	I don't know enough about gin to tell how much of an issue it can be.
   361  	*/
   362  	var (
   363  		cache []*models.AddSignalsRequestItem = *cacheOrig
   364  		send  models.AddSignalsRequest
   365  	)
   366  
   367  	bulkSize := 50
   368  	pageStart := 0
   369  	pageEnd := bulkSize
   370  
   371  	for {
   372  		if pageEnd >= len(cache) {
   373  			send = cache[pageStart:]
   374  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   375  
   376  			defer cancel()
   377  
   378  			_, _, err := a.apiClient.Signal.Add(ctx, &send)
   379  
   380  			if err != nil {
   381  				log.Errorf("sending signal to central API: %s", err)
   382  				return
   383  			}
   384  
   385  			break
   386  		}
   387  
   388  		send = cache[pageStart:pageEnd]
   389  		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   390  
   391  		defer cancel()
   392  
   393  		_, _, err := a.apiClient.Signal.Add(ctx, &send)
   394  
   395  		if err != nil {
   396  			//we log it here as well, because the return value of func might be discarded
   397  			log.Errorf("sending signal to central API: %s", err)
   398  		}
   399  
   400  		pageStart += bulkSize
   401  		pageEnd += bulkSize
   402  	}
   403  }
   404  
   405  func (a *apic) CAPIPullIsOld() (bool, error) {
   406  	/*only pull community blocklist if it's older than 1h30 */
   407  	alerts := a.dbClient.Ent.Alert.Query()
   408  	alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
   409  	alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
   410  	count, err := alerts.Count(a.dbClient.CTX)
   411  
   412  	if err != nil {
   413  		return false, fmt.Errorf("while looking for CAPI alert: %w", err)
   414  	}
   415  
   416  	if count > 0 {
   417  		log.Printf("last CAPI pull is newer than 1h30, skip.")
   418  		return false, nil
   419  	}
   420  
   421  	return true, nil
   422  }
   423  
   424  func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
   425  	nbDeleted := 0
   426  
   427  	for _, decision := range deletedDecisions {
   428  		filter := map[string][]string{
   429  			"value":  {*decision.Value},
   430  			"origin": {*decision.Origin},
   431  		}
   432  		if strings.ToLower(*decision.Scope) != "ip" {
   433  			filter["type"] = []string{*decision.Type}
   434  			filter["scopes"] = []string{*decision.Scope}
   435  		}
   436  
   437  		dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
   438  		if err != nil {
   439  			return 0, fmt.Errorf("deleting decisions error: %w", err)
   440  		}
   441  
   442  		dbCliDel, err := strconv.Atoi(dbCliRet)
   443  		if err != nil {
   444  			return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
   445  		}
   446  
   447  		updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel)
   448  		nbDeleted += dbCliDel
   449  	}
   450  
   451  	return nbDeleted, nil
   452  }
   453  
   454  func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
   455  	var nbDeleted int
   456  
   457  	for _, decisions := range deletedDecisions {
   458  		scope := decisions.Scope
   459  
   460  		for _, decision := range decisions.Decisions {
   461  			filter := map[string][]string{
   462  				"value":  {decision},
   463  				"origin": {types.CAPIOrigin},
   464  			}
   465  			if strings.ToLower(*scope) != "ip" {
   466  				filter["scopes"] = []string{*scope}
   467  			}
   468  
   469  			dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
   470  			if err != nil {
   471  				return 0, fmt.Errorf("deleting decisions error: %w", err)
   472  			}
   473  
   474  			dbCliDel, err := strconv.Atoi(dbCliRet)
   475  			if err != nil {
   476  				return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
   477  			}
   478  
   479  			updateCounterForDecision(deleteCounters, ptr.Of(types.CAPIOrigin), nil, dbCliDel)
   480  			nbDeleted += dbCliDel
   481  		}
   482  	}
   483  
   484  	return nbDeleted, nil
   485  }
   486  
   487  func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
   488  	newAlerts := make([]*models.Alert, 0)
   489  
   490  	for _, decision := range decisions {
   491  		found := false
   492  
   493  		for _, sub := range newAlerts {
   494  			if sub.Source.Scope == nil {
   495  				log.Warningf("nil scope in %+v", sub)
   496  				continue
   497  			}
   498  
   499  			if *decision.Origin == types.CAPIOrigin {
   500  				if *sub.Source.Scope == types.CAPIOrigin {
   501  					found = true
   502  					break
   503  				}
   504  			} else if *decision.Origin == types.ListOrigin {
   505  				if *sub.Source.Scope == *decision.Origin {
   506  					if sub.Scenario == nil {
   507  						log.Warningf("nil scenario in %+v", sub)
   508  					}
   509  					if *sub.Scenario == *decision.Scenario {
   510  						found = true
   511  						break
   512  					}
   513  				}
   514  			} else {
   515  				log.Warningf("unknown origin %s : %+v", *decision.Origin, decision)
   516  			}
   517  		}
   518  
   519  		if !found {
   520  			log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario)
   521  			newAlerts = append(newAlerts, createAlertForDecision(decision))
   522  		}
   523  	}
   524  
   525  	return newAlerts
   526  }
   527  
   528  func createAlertForDecision(decision *models.Decision) *models.Alert {
   529  	var (
   530  		scenario string
   531  		scope    string
   532  	)
   533  
   534  	switch *decision.Origin {
   535  	case types.CAPIOrigin:
   536  		scenario = types.CAPIOrigin
   537  		scope = types.CAPIOrigin
   538  	case types.ListOrigin:
   539  		scenario = *decision.Scenario
   540  		scope = types.ListOrigin
   541  	default:
   542  		scenario = ""
   543  		scope = ""
   544  
   545  		log.Warningf("unknown origin %s", *decision.Origin)
   546  	}
   547  
   548  	return &models.Alert{
   549  		Source: &models.Source{
   550  			Scope: ptr.Of(scope),
   551  			Value: ptr.Of(""),
   552  		},
   553  		Scenario:        ptr.Of(scenario),
   554  		Message:         ptr.Of(""),
   555  		StartAt:         ptr.Of(time.Now().UTC().Format(time.RFC3339)),
   556  		StopAt:          ptr.Of(time.Now().UTC().Format(time.RFC3339)),
   557  		Capacity:        ptr.Of(int32(0)),
   558  		Simulated:       ptr.Of(false),
   559  		EventsCount:     ptr.Of(int32(0)),
   560  		Leakspeed:       ptr.Of(""),
   561  		ScenarioHash:    ptr.Of(""),
   562  		ScenarioVersion: ptr.Of(""),
   563  		MachineID:       database.CapiMachineID,
   564  	}
   565  }
   566  
   567  // This function takes in list of parent alerts and decisions and then pairs them up.
   568  func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert {
   569  	for _, decision := range decisions {
   570  		//count and create separate alerts for each list
   571  		updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1)
   572  
   573  		/*CAPI might send lower case scopes, unify it.*/
   574  		switch strings.ToLower(*decision.Scope) {
   575  		case "ip":
   576  			*decision.Scope = types.Ip
   577  		case "range":
   578  			*decision.Scope = types.Range
   579  		}
   580  
   581  		found := false
   582  		//add the individual decisions to the right list
   583  		for idx, alert := range alerts {
   584  			if *decision.Origin == types.CAPIOrigin {
   585  				if *alert.Source.Scope == types.CAPIOrigin {
   586  					alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
   587  					found = true
   588  
   589  					break
   590  				}
   591  			} else if *decision.Origin == types.ListOrigin {
   592  				if *alert.Source.Scope == types.ListOrigin && *alert.Scenario == *decision.Scenario {
   593  					alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
   594  					found = true
   595  					break
   596  				}
   597  			} else {
   598  				log.Warningf("unknown origin %s", *decision.Origin)
   599  			}
   600  		}
   601  
   602  		if !found {
   603  			log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario)
   604  		}
   605  	}
   606  
   607  	return alerts
   608  }
   609  
   610  // we receive a list of decisions and links for blocklist and we need to create a list of alerts :
   611  // one alert for "community blocklist"
   612  // one alert per list we're subscribed to
   613  func (a *apic) PullTop(forcePull bool) error {
   614  	var err error
   615  
   616  	//A mutex with TryLock would be a bit simpler
   617  	//But go does not guarantee that TryLock will be able to acquire the lock even if it is available
   618  	select {
   619  	case a.isPulling <- true:
   620  		defer func() {
   621  			<-a.isPulling
   622  		}()
   623  	default:
   624  		return errors.New("pull already in progress")
   625  	}
   626  
   627  	if !forcePull {
   628  		if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
   629  			return err
   630  		} else if !lastPullIsOld {
   631  			return nil
   632  		}
   633  	}
   634  
   635  	log.Debug("Acquiring lock for pullCAPI")
   636  	err = a.dbClient.AcquirePullCAPILock()
   637  	if a.dbClient.IsLocked(err) {
   638  		log.Info("PullCAPI is already running, skipping")
   639  		return nil
   640  	}
   641  
   642  	/*defer lock release*/
   643  	defer func() {
   644  		log.Debug("Releasing lock for pullCAPI")
   645  		if err := a.dbClient.ReleasePullCAPILock(); err != nil {
   646  			log.Errorf("while releasing lock: %v", err)
   647  		}
   648  	}()
   649  
   650  	log.Infof("Starting community-blocklist update")
   651  
   652  	data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
   653  	if err != nil {
   654  		return fmt.Errorf("get stream: %w", err)
   655  	}
   656  
   657  	a.startup = false
   658  	/*to count additions/deletions across lists*/
   659  
   660  	log.Debugf("Received %d new decisions", len(data.New))
   661  	log.Debugf("Received %d deleted decisions", len(data.Deleted))
   662  
   663  	if data.Links != nil {
   664  		log.Debugf("Received %d blocklists links", len(data.Links.Blocklists))
   665  	}
   666  
   667  	addCounters, deleteCounters := makeAddAndDeleteCounters()
   668  
   669  	// process deleted decisions
   670  	nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
   671  	if err != nil {
   672  		return err
   673  	}
   674  
   675  	log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
   676  
   677  	if len(data.New) == 0 {
   678  		log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)")
   679  		return nil
   680  	}
   681  
   682  	// create one alert for community blocklist using the first decision
   683  	decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New)
   684  	//apply APIC specific whitelists
   685  	decisions = a.ApplyApicWhitelists(decisions)
   686  
   687  	alert := createAlertForDecision(decisions[0])
   688  	alertsFromCapi := []*models.Alert{alert}
   689  	alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
   690  
   691  	err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters)
   692  	if err != nil {
   693  		return fmt.Errorf("while saving alerts: %w", err)
   694  	}
   695  
   696  	// update blocklists
   697  	if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
   698  		return fmt.Errorf("while updating blocklists: %w", err)
   699  	}
   700  
   701  	return nil
   702  }
   703  
   704  // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
   705  func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
   706  	addCounters, _ := makeAddAndDeleteCounters()
   707  	if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
   708  		Blocklists: []*modelscapi.BlocklistLink{blocklist},
   709  	}, addCounters, forcePull); err != nil {
   710  		return fmt.Errorf("while pulling blocklist: %w", err)
   711  	}
   712  
   713  	return nil
   714  }
   715  
   716  // if decisions is whitelisted: return representation of the whitelist ip or cidr
   717  // if not whitelisted: empty string
   718  func (a *apic) whitelistedBy(decision *models.Decision) string {
   719  	if decision.Value == nil {
   720  		return ""
   721  	}
   722  
   723  	ipval := net.ParseIP(*decision.Value)
   724  	for _, cidr := range a.whitelists.Cidrs {
   725  		if cidr.Contains(ipval) {
   726  			return cidr.String()
   727  		}
   728  	}
   729  
   730  	for _, ip := range a.whitelists.Ips {
   731  		if ip != nil && ip.Equal(ipval) {
   732  			return ip.String()
   733  		}
   734  	}
   735  
   736  	return ""
   737  }
   738  
   739  func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decision {
   740  	if a.whitelists == nil || len(a.whitelists.Cidrs) == 0 && len(a.whitelists.Ips) == 0 {
   741  		return decisions
   742  	}
   743  	//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
   744  	outIdx := 0
   745  
   746  	for _, decision := range decisions {
   747  		whitelister := a.whitelistedBy(decision)
   748  		if whitelister != "" {
   749  			log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister)
   750  			continue
   751  		}
   752  
   753  		decisions[outIdx] = decision
   754  		outIdx++
   755  	}
   756  	//shrink the list, those are deleted items
   757  	return decisions[:outIdx]
   758  }
   759  
   760  func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
   761  	for _, alert := range alertsFromCapi {
   762  		setAlertScenario(alert, addCounters, deleteCounters)
   763  		log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
   764  
   765  		if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
   766  			log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
   767  		}
   768  
   769  		alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
   770  		if err != nil {
   771  			return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
   772  		}
   773  
   774  		log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID)
   775  	}
   776  
   777  	return nil
   778  }
   779  
   780  func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) {
   781  	// we should force pull if the blocklist decisions are about to expire or there's no decision in the db
   782  	alertQuery := a.dbClient.Ent.Alert.Query()
   783  	alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
   784  	alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
   785  	alertInstance, err := alertQuery.First(context.Background())
   786  
   787  	if err != nil {
   788  		if ent.IsNotFound(err) {
   789  			log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
   790  			return true, nil
   791  		}
   792  
   793  		return false, fmt.Errorf("while getting alert: %w", err)
   794  	}
   795  
   796  	decisionQuery := a.dbClient.Ent.Decision.Query()
   797  	decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
   798  	firstDecision, err := decisionQuery.First(context.Background())
   799  
   800  	if err != nil {
   801  		if ent.IsNotFound(err) {
   802  			log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
   803  			return true, nil
   804  		}
   805  
   806  		return false, fmt.Errorf("while getting decision: %w", err)
   807  	}
   808  
   809  	if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
   810  		log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
   811  		return true, nil
   812  	}
   813  
   814  	return false, nil
   815  }
   816  
   817  func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
   818  	if blocklist.Scope == nil {
   819  		log.Warningf("blocklist has no scope")
   820  		return nil
   821  	}
   822  
   823  	if blocklist.Duration == nil {
   824  		log.Warningf("blocklist has no duration")
   825  		return nil
   826  	}
   827  
   828  	if !forcePull {
   829  		_forcePull, err := a.ShouldForcePullBlocklist(blocklist)
   830  		if err != nil {
   831  			return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
   832  		}
   833  
   834  		forcePull = _forcePull
   835  	}
   836  
   837  	blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
   838  
   839  	var (
   840  		lastPullTimestamp *string
   841  		err               error
   842  	)
   843  
   844  	if !forcePull {
   845  		lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
   846  		if err != nil {
   847  			return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
   848  		}
   849  	}
   850  
   851  	decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
   852  	if err != nil {
   853  		return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
   854  	}
   855  
   856  	if !hasChanged {
   857  		if lastPullTimestamp == nil {
   858  			log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
   859  		} else {
   860  			log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
   861  		}
   862  
   863  		return nil
   864  	}
   865  
   866  	err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
   867  	if err != nil {
   868  		return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
   869  	}
   870  
   871  	if len(decisions) == 0 {
   872  		log.Infof("blocklist %s has no decisions", *blocklist.Name)
   873  		return nil
   874  	}
   875  	//apply APIC specific whitelists
   876  	decisions = a.ApplyApicWhitelists(decisions)
   877  	alert := createAlertForDecision(decisions[0])
   878  	alertsFromCapi := []*models.Alert{alert}
   879  	alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
   880  
   881  	err = a.SaveAlerts(alertsFromCapi, addCounters, nil)
   882  	if err != nil {
   883  		return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
   884  	}
   885  
   886  	return nil
   887  }
   888  
   889  func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
   890  	if links == nil {
   891  		return nil
   892  	}
   893  
   894  	if links.Blocklists == nil {
   895  		return nil
   896  	}
   897  	// we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles
   898  	// we can use the same baseUrl as the urls are absolute and the parse will take care of it
   899  	defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil)
   900  	if err != nil {
   901  		return fmt.Errorf("while creating default client: %w", err)
   902  	}
   903  
   904  	for _, blocklist := range links.Blocklists {
   905  		if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
   906  			return err
   907  		}
   908  	}
   909  
   910  	return nil
   911  }
   912  
   913  func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) {
   914  	if *alert.Source.Scope == types.CAPIOrigin {
   915  		*alert.Source.Scope = types.CommunityBlocklistPullSourceScope
   916  		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"]))
   917  	} else if *alert.Source.Scope == types.ListOrigin {
   918  		*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
   919  		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario]))
   920  	}
   921  }
   922  
   923  func (a *apic) Pull() error {
   924  	defer trace.CatchPanic("lapi/pullFromAPIC")
   925  
   926  	toldOnce := false
   927  
   928  	for {
   929  		scenario, err := a.FetchScenariosListFromDB()
   930  		if err != nil {
   931  			log.Errorf("unable to fetch scenarios from db: %s", err)
   932  		}
   933  
   934  		if len(scenario) > 0 {
   935  			break
   936  		}
   937  
   938  		if !toldOnce {
   939  			log.Warning("scenario list is empty, will not pull yet")
   940  
   941  			toldOnce = true
   942  		}
   943  
   944  		time.Sleep(1 * time.Second)
   945  	}
   946  
   947  	if err := a.PullTop(false); err != nil {
   948  		log.Errorf("capi pull top: %s", err)
   949  	}
   950  
   951  	log.Infof("Start pull from CrowdSec Central API (interval: %s once, then %s)", a.pullIntervalFirst.Round(time.Second), a.pullInterval)
   952  	ticker := time.NewTicker(a.pullIntervalFirst)
   953  
   954  	for {
   955  		select {
   956  		case <-ticker.C:
   957  			ticker.Reset(a.pullInterval)
   958  
   959  			if err := a.PullTop(false); err != nil {
   960  				log.Errorf("capi pull top: %s", err)
   961  				continue
   962  			}
   963  		case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others?
   964  			a.metricsTomb.Kill(nil)
   965  			a.pushTomb.Kill(nil)
   966  
   967  			return nil
   968  		}
   969  	}
   970  }
   971  
   972  func (a *apic) Shutdown() {
   973  	a.pushTomb.Kill(nil)
   974  	a.pullTomb.Kill(nil)
   975  	a.metricsTomb.Kill(nil)
   976  }
   977  
   978  func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) {
   979  	addCounters := make(map[string]map[string]int)
   980  	addCounters[types.CAPIOrigin] = make(map[string]int)
   981  	addCounters[types.ListOrigin] = make(map[string]int)
   982  
   983  	deleteCounters := make(map[string]map[string]int)
   984  	deleteCounters[types.CAPIOrigin] = make(map[string]int)
   985  	deleteCounters[types.ListOrigin] = make(map[string]int)
   986  
   987  	return addCounters, deleteCounters
   988  }
   989  
   990  func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) {
   991  	if *origin == types.CAPIOrigin {
   992  		counter[*origin]["all"] += totalDecisions
   993  	} else if *origin == types.ListOrigin {
   994  		counter[*origin][*scenario] += totalDecisions
   995  	} else {
   996  		log.Warningf("Unknown origin %s", *origin)
   997  	}
   998  }