github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiserver/apiserver_test.go (about)

     1  package apiserver
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/gin-gonic/gin"
    15  	"github.com/go-openapi/strfmt"
    16  	log "github.com/sirupsen/logrus"
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/crowdsecurity/go-cs-lib/cstest"
    21  	"github.com/crowdsecurity/go-cs-lib/ptr"
    22  	"github.com/crowdsecurity/go-cs-lib/version"
    23  
    24  	middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
    25  	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
    26  	"github.com/crowdsecurity/crowdsec/pkg/database"
    27  	"github.com/crowdsecurity/crowdsec/pkg/models"
    28  	"github.com/crowdsecurity/crowdsec/pkg/types"
    29  )
    30  
    31  var testMachineID = "test"
    32  var testPassword = strfmt.Password("test")
    33  var MachineTest = models.WatcherAuthRequest{
    34  	MachineID: &testMachineID,
    35  	Password:  &testPassword,
    36  }
    37  
    38  var UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version)
    39  var emptyBody = strings.NewReader("")
    40  
    41  func LoadTestConfig(t *testing.T) csconfig.Config {
    42  	config := csconfig.Config{}
    43  	maxAge := "1h"
    44  	flushConfig := csconfig.FlushDBCfg{
    45  		MaxAge: &maxAge,
    46  	}
    47  
    48  	tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
    49  
    50  	t.Cleanup(func() { os.RemoveAll(tempDir) })
    51  
    52  	dbconfig := csconfig.DatabaseCfg{
    53  		Type:   "sqlite",
    54  		DbPath: filepath.Join(tempDir, "ent"),
    55  		Flush:  &flushConfig,
    56  	}
    57  	apiServerConfig := csconfig.LocalApiServerCfg{
    58  		ListenURI:    "http://127.0.0.1:8080",
    59  		DbConfig:     &dbconfig,
    60  		ProfilesPath: "./tests/profiles.yaml",
    61  		ConsoleConfig: &csconfig.ConsoleConfig{
    62  			ShareManualDecisions:  new(bool),
    63  			ShareTaintedScenarios: new(bool),
    64  			ShareCustomScenarios:  new(bool),
    65  		},
    66  	}
    67  
    68  	apiConfig := csconfig.APICfg{
    69  		Server: &apiServerConfig,
    70  	}
    71  
    72  	config.API = &apiConfig
    73  	err := config.API.Server.LoadProfiles()
    74  	require.NoError(t, err)
    75  
    76  	return config
    77  }
    78  
    79  func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
    80  	config := csconfig.Config{}
    81  	maxAge := "1h"
    82  	flushConfig := csconfig.FlushDBCfg{
    83  		MaxAge: &maxAge,
    84  	}
    85  
    86  	tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
    87  
    88  	t.Cleanup(func() { os.RemoveAll(tempDir) })
    89  
    90  	dbconfig := csconfig.DatabaseCfg{
    91  		Type:   "sqlite",
    92  		DbPath: filepath.Join(tempDir, "ent"),
    93  		Flush:  &flushConfig,
    94  	}
    95  	apiServerConfig := csconfig.LocalApiServerCfg{
    96  		ListenURI:              "http://127.0.0.1:8080",
    97  		DbConfig:               &dbconfig,
    98  		ProfilesPath:           "./tests/profiles.yaml",
    99  		UseForwardedForHeaders: true,
   100  		TrustedProxies:         &[]string{"0.0.0.0/0"},
   101  		ConsoleConfig: &csconfig.ConsoleConfig{
   102  			ShareManualDecisions:  new(bool),
   103  			ShareTaintedScenarios: new(bool),
   104  			ShareCustomScenarios:  new(bool),
   105  		},
   106  	}
   107  	apiConfig := csconfig.APICfg{
   108  		Server: &apiServerConfig,
   109  	}
   110  	config.API = &apiConfig
   111  	err := config.API.Server.LoadProfiles()
   112  	require.NoError(t, err)
   113  
   114  	return config
   115  }
   116  
   117  func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
   118  	config := LoadTestConfig(t)
   119  
   120  	os.Remove("./ent")
   121  
   122  	apiServer, err := NewServer(config.API.Server)
   123  	require.NoError(t, err)
   124  
   125  	log.Printf("Creating new API server")
   126  	gin.SetMode(gin.TestMode)
   127  
   128  	return apiServer, config
   129  }
   130  
   131  func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
   132  	apiServer, config := NewAPIServer(t)
   133  
   134  	err := apiServer.InitController()
   135  	require.NoError(t, err)
   136  
   137  	router, err := apiServer.Router()
   138  	require.NoError(t, err)
   139  
   140  	return router, config
   141  }
   142  
   143  func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
   144  	config := LoadTestConfigForwardedFor(t)
   145  
   146  	os.Remove("./ent")
   147  
   148  	apiServer, err := NewServer(config.API.Server)
   149  	require.NoError(t, err)
   150  
   151  	err = apiServer.InitController()
   152  	require.NoError(t, err)
   153  
   154  	log.Printf("Creating new API server")
   155  	gin.SetMode(gin.TestMode)
   156  
   157  	router, err := apiServer.Router()
   158  	require.NoError(t, err)
   159  
   160  	return router, config
   161  }
   162  
   163  func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
   164  	dbClient, err := database.NewClient(config)
   165  	require.NoError(t, err)
   166  
   167  	err = dbClient.ValidateMachine(machineID)
   168  	require.NoError(t, err)
   169  }
   170  
   171  func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
   172  	dbClient, err := database.NewClient(config)
   173  	require.NoError(t, err)
   174  
   175  	machines, err := dbClient.ListMachines()
   176  	require.NoError(t, err)
   177  
   178  	for _, machine := range machines {
   179  		if machine.MachineId == machineID {
   180  			return machine.IpAddress
   181  		}
   182  	}
   183  
   184  	return ""
   185  }
   186  
   187  func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader {
   188  	alertContentBytes, err := os.ReadFile(path)
   189  	require.NoError(t, err)
   190  
   191  	alerts := make([]*models.Alert, 0)
   192  	err = json.Unmarshal(alertContentBytes, &alerts)
   193  	require.NoError(t, err)
   194  
   195  	for _, alert := range alerts {
   196  		*alert.StartAt = time.Now().UTC().Format(time.RFC3339)
   197  		*alert.StopAt = time.Now().UTC().Format(time.RFC3339)
   198  	}
   199  
   200  	alertContent, err := json.Marshal(alerts)
   201  	require.NoError(t, err)
   202  
   203  	return strings.NewReader(string(alertContent))
   204  }
   205  
   206  func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) {
   207  	var response []*models.Decision
   208  
   209  	require.NotNil(t, resp)
   210  
   211  	err := json.Unmarshal(resp.Body.Bytes(), &response)
   212  	require.NoError(t, err)
   213  
   214  	return response, resp.Code
   215  }
   216  
   217  func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) {
   218  	var response map[string]string
   219  
   220  	require.NotNil(t, resp)
   221  
   222  	err := json.Unmarshal(resp.Body.Bytes(), &response)
   223  	require.NoError(t, err)
   224  
   225  	return response, resp.Code
   226  }
   227  
   228  func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) {
   229  	var response models.DeleteDecisionResponse
   230  
   231  	require.NotNil(t, resp)
   232  	err := json.Unmarshal(resp.Body.Bytes(), &response)
   233  	require.NoError(t, err)
   234  
   235  	return &response, resp.Code
   236  }
   237  
   238  func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) {
   239  	response := make(map[string][]*models.Decision)
   240  
   241  	require.NotNil(t, resp)
   242  	err := json.Unmarshal(resp.Body.Bytes(), &response)
   243  	require.NoError(t, err)
   244  
   245  	return response, resp.Code
   246  }
   247  
   248  func CreateTestMachine(t *testing.T, router *gin.Engine) string {
   249  	b, err := json.Marshal(MachineTest)
   250  	require.NoError(t, err)
   251  
   252  	body := string(b)
   253  
   254  	w := httptest.NewRecorder()
   255  	req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
   256  	req.Header.Set("User-Agent", UserAgent)
   257  	router.ServeHTTP(w, req)
   258  
   259  	return body
   260  }
   261  
   262  func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
   263  	dbClient, err := database.NewClient(config)
   264  	require.NoError(t, err)
   265  
   266  	apiKey, err := middlewares.GenerateAPIKey(keyLength)
   267  	require.NoError(t, err)
   268  
   269  	_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
   270  	require.NoError(t, err)
   271  
   272  	return apiKey
   273  }
   274  
   275  func TestWithWrongDBConfig(t *testing.T) {
   276  	config := LoadTestConfig(t)
   277  	config.API.Server.DbConfig.Type = "test"
   278  	apiServer, err := NewServer(config.API.Server)
   279  
   280  	cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'")
   281  	assert.Nil(t, apiServer)
   282  }
   283  
   284  func TestWithWrongFlushConfig(t *testing.T) {
   285  	config := LoadTestConfig(t)
   286  	maxItems := -1
   287  	config.API.Server.DbConfig.Flush.MaxItems = &maxItems
   288  	apiServer, err := NewServer(config.API.Server)
   289  
   290  	cstest.RequireErrorContains(t, err, "max_items can't be zero or negative number")
   291  	assert.Nil(t, apiServer)
   292  }
   293  
   294  func TestUnknownPath(t *testing.T) {
   295  	router, _ := NewAPITest(t)
   296  
   297  	w := httptest.NewRecorder()
   298  	req, _ := http.NewRequest(http.MethodGet, "/test", nil)
   299  	req.Header.Set("User-Agent", UserAgent)
   300  	router.ServeHTTP(w, req)
   301  
   302  	assert.Equal(t, 404, w.Code)
   303  }
   304  
/*
Reference only: what appears to be a subset of the csconfig.LocalApiServerCfg
fields exercised by the logging tests below. This copy is not compiled and may
drift from the real struct definition.

	ListenURI              string              `yaml:"listen_uri,omitempty"` //127.0.0.1:8080
	TLS                    *TLSCfg             `yaml:"tls"`
	DbConfig               *DatabaseCfg        `yaml:"-"`
	LogDir                 string              `yaml:"-"`
	LogMedia               string              `yaml:"-"`
	OnlineClient           *OnlineApiClientCfg `yaml:"online_client"`
	ProfilesPath           string              `yaml:"profiles_path,omitempty"`
	Profiles               []*ProfileCfg       `yaml:"-"`
	LogLevel               *log.Level          `yaml:"log_level"`
	UseForwardedForHeaders bool                `yaml:"use_forwarded_for_headers,omitempty"`

*/
   319  
   320  func TestLoggingDebugToFileConfig(t *testing.T) {
   321  	/*declare settings*/
   322  	maxAge := "1h"
   323  	flushConfig := csconfig.FlushDBCfg{
   324  		MaxAge: &maxAge,
   325  	}
   326  
   327  	tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
   328  
   329  	t.Cleanup(func() { os.RemoveAll(tempDir) })
   330  
   331  	dbconfig := csconfig.DatabaseCfg{
   332  		Type:   "sqlite",
   333  		DbPath: filepath.Join(tempDir, "ent"),
   334  		Flush:  &flushConfig,
   335  	}
   336  	cfg := csconfig.LocalApiServerCfg{
   337  		ListenURI: "127.0.0.1:8080",
   338  		LogMedia:  "file",
   339  		LogDir:    tempDir,
   340  		DbConfig:  &dbconfig,
   341  	}
   342  	expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
   343  	expectedLines := []string{"/test42"}
   344  	cfg.LogLevel = ptr.Of(log.DebugLevel)
   345  
   346  	// Configure logging
   347  	err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
   348  	require.NoError(t, err)
   349  
   350  	api, err := NewServer(&cfg)
   351  	require.NoError(t, err)
   352  	require.NotNil(t, api)
   353  
   354  	w := httptest.NewRecorder()
   355  	req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
   356  	req.Header.Set("User-Agent", UserAgent)
   357  	api.router.ServeHTTP(w, req)
   358  	assert.Equal(t, 404, w.Code)
   359  	//wait for the request to happen
   360  	time.Sleep(500 * time.Millisecond)
   361  
   362  	//check file content
   363  	data, err := os.ReadFile(expectedFile)
   364  	require.NoError(t, err)
   365  
   366  	for _, expectedStr := range expectedLines {
   367  		assert.Contains(t, string(data), expectedStr)
   368  	}
   369  }
   370  
   371  func TestLoggingErrorToFileConfig(t *testing.T) {
   372  	/*declare settings*/
   373  	maxAge := "1h"
   374  	flushConfig := csconfig.FlushDBCfg{
   375  		MaxAge: &maxAge,
   376  	}
   377  
   378  	tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
   379  
   380  	t.Cleanup(func() { os.RemoveAll(tempDir) })
   381  
   382  	dbconfig := csconfig.DatabaseCfg{
   383  		Type:   "sqlite",
   384  		DbPath: filepath.Join(tempDir, "ent"),
   385  		Flush:  &flushConfig,
   386  	}
   387  	cfg := csconfig.LocalApiServerCfg{
   388  		ListenURI: "127.0.0.1:8080",
   389  		LogMedia:  "file",
   390  		LogDir:    tempDir,
   391  		DbConfig:  &dbconfig,
   392  	}
   393  	expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
   394  	cfg.LogLevel = ptr.Of(log.ErrorLevel)
   395  
   396  	// Configure logging
   397  	err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
   398  	require.NoError(t, err)
   399  
   400  	api, err := NewServer(&cfg)
   401  	require.NoError(t, err)
   402  	require.NotNil(t, api)
   403  
   404  	w := httptest.NewRecorder()
   405  	req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
   406  	req.Header.Set("User-Agent", UserAgent)
   407  	api.router.ServeHTTP(w, req)
   408  	assert.Equal(t, http.StatusNotFound, w.Code)
   409  	//wait for the request to happen
   410  	time.Sleep(500 * time.Millisecond)
   411  
   412  	//check file content
   413  	x, err := os.ReadFile(expectedFile)
   414  	if err == nil {
   415  		require.Empty(t, x)
   416  	}
   417  
   418  	os.Remove("./crowdsec.log")
   419  	os.Remove(expectedFile)
   420  }