github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiclient/client_test.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"path"
    11  	"runtime"
    12  	"strings"
    13  	"testing"
    14  
    15  	log "github.com/sirupsen/logrus"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/crowdsecurity/go-cs-lib/cstest"
    20  	"github.com/crowdsecurity/go-cs-lib/version"
    21  )
    22  
/*
This is adapted from the google/go-github testing approach:
- set up a test HTTP server along with a client configured to talk to that server
- each test then binds handlers for the method(s) it wants to exercise
*/
    27  
    28  func setup() (*http.ServeMux, string, func()) {
    29  	return setupWithPrefix("v1")
    30  }
    31  
    32  func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) {
    33  	// mux is the HTTP request multiplexer used with the test server.
    34  	mux := http.NewServeMux()
    35  	baseURLPath := "/" + urlPrefix
    36  
    37  	apiHandler := http.NewServeMux()
    38  	apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
    39  
    40  	server := httptest.NewServer(apiHandler)
    41  
    42  	return mux, server.URL, server.Close
    43  }
    44  
    45  // toUNCPath converts a Windows file path to a UNC path.
    46  // This is necessary because the Go http package does not support Windows file paths.
    47  func toUNCPath(path string) (string, error) {
    48  	colonIdx := strings.Index(path, ":")
    49  	if colonIdx == -1 {
    50  		return "", fmt.Errorf("invalid path format, missing drive letter: %s", path)
    51  	}
    52  
    53  	// URL parsing does not like backslashes
    54  	remaining := strings.ReplaceAll(path[colonIdx+1:], "\\", "/")
    55  	uncPath := "//localhost/" + path[:colonIdx] + "$" + remaining
    56  
    57  	return uncPath, nil
    58  }
    59  
    60  func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) {
    61  	var err error
    62  	if runtime.GOOS == "windows" {
    63  		socket, err = toUNCPath(socket)
    64  		if err != nil {
    65  			log.Fatalf("converting to UNC path: %s", err)
    66  		}
    67  	}
    68  
    69  	mux = http.NewServeMux()
    70  	baseURLPath := "/" + urlPrefix
    71  
    72  	apiHandler := http.NewServeMux()
    73  	apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
    74  
    75  	server := httptest.NewUnstartedServer(apiHandler)
    76  	l, _ := net.Listen("unix", socket)
    77  	_ = server.Listener.Close()
    78  	server.Listener = l
    79  	server.Start()
    80  
    81  	return mux, socket, server.Close
    82  }
    83  
    84  func testMethod(t *testing.T, r *http.Request, want string) {
    85  	t.Helper()
    86  	assert.Equal(t, want, r.Method)
    87  }
    88  
    89  func TestNewClientOk(t *testing.T) {
    90  	mux, urlx, teardown := setup()
    91  	defer teardown()
    92  
    93  	apiURL, err := url.Parse(urlx + "/")
    94  	require.NoError(t, err)
    95  
    96  	client, err := NewClient(&Config{
    97  		MachineID:     "test_login",
    98  		Password:      "test_password",
    99  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   100  		URL:           apiURL,
   101  		VersionPrefix: "v1",
   102  	})
   103  	require.NoError(t, err)
   104  
   105  	/*mock login*/
   106  	mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
   107  		w.WriteHeader(http.StatusOK)
   108  		w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
   109  	})
   110  
   111  	mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
   112  		testMethod(t, r, "GET")
   113  		w.WriteHeader(http.StatusOK)
   114  	})
   115  
   116  	_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
   117  	require.NoError(t, err)
   118  	assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
   119  }
   120  
   121  func TestNewClientOk_UnixSocket(t *testing.T) {
   122  	tmpDir := t.TempDir()
   123  	socket := path.Join(tmpDir, "socket")
   124  
   125  	mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
   126  	defer teardown()
   127  
   128  	apiURL, err := url.Parse(urlx)
   129  	if err != nil {
   130  		t.Fatalf("parsing api url: %s", apiURL)
   131  	}
   132  
   133  	client, err := NewClient(&Config{
   134  		MachineID:     "test_login",
   135  		Password:      "test_password",
   136  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   137  		URL:           apiURL,
   138  		VersionPrefix: "v1",
   139  	})
   140  	if err != nil {
   141  		t.Fatalf("new api client: %s", err)
   142  	}
   143  	/*mock login*/
   144  	mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
   145  		w.WriteHeader(http.StatusOK)
   146  		w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
   147  	})
   148  
   149  	mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
   150  		testMethod(t, r, "GET")
   151  		w.WriteHeader(http.StatusOK)
   152  	})
   153  
   154  	_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
   155  	if err != nil {
   156  		t.Fatalf("test Unable to list alerts : %+v", err)
   157  	}
   158  
   159  	if resp.Response.StatusCode != http.StatusOK {
   160  		t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
   161  	}
   162  }
   163  
   164  func TestNewClientKo(t *testing.T) {
   165  	mux, urlx, teardown := setup()
   166  	defer teardown()
   167  
   168  	apiURL, err := url.Parse(urlx + "/")
   169  	require.NoError(t, err)
   170  
   171  	client, err := NewClient(&Config{
   172  		MachineID:     "test_login",
   173  		Password:      "test_password",
   174  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   175  		URL:           apiURL,
   176  		VersionPrefix: "v1",
   177  	})
   178  	require.NoError(t, err)
   179  
   180  	/*mock login*/
   181  	mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
   182  		w.WriteHeader(http.StatusUnauthorized)
   183  		w.Write([]byte(`{"code": 401, "message" : "bad login/password"}`))
   184  	})
   185  
   186  	mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
   187  		testMethod(t, r, "GET")
   188  		w.WriteHeader(http.StatusOK)
   189  	})
   190  
   191  	_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
   192  	cstest.RequireErrorContains(t, err, `API error: bad login/password`)
   193  
   194  	log.Printf("err-> %s", err)
   195  }
   196  
   197  func TestNewDefaultClient(t *testing.T) {
   198  	mux, urlx, teardown := setup()
   199  	defer teardown()
   200  
   201  	apiURL, err := url.Parse(urlx + "/")
   202  	require.NoError(t, err)
   203  
   204  	client, err := NewDefaultClient(apiURL, "/v1", "", nil)
   205  	require.NoError(t, err)
   206  
   207  	mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
   208  		w.WriteHeader(http.StatusUnauthorized)
   209  		w.Write([]byte(`{"code": 401, "message" : "brr"}`))
   210  	})
   211  
   212  	_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
   213  	cstest.RequireErrorMessage(t, err, "performing request: API error: brr")
   214  
   215  	log.Printf("err-> %s", err)
   216  }
   217  
   218  func TestNewDefaultClient_UnixSocket(t *testing.T) {
   219  	tmpDir := t.TempDir()
   220  	socket := path.Join(tmpDir, "socket")
   221  
   222  	mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
   223  	defer teardown()
   224  
   225  	apiURL, err := url.Parse(urlx)
   226  	if err != nil {
   227  		t.Fatalf("parsing api url: %s", apiURL)
   228  	}
   229  
   230  	client, err := NewDefaultClient(apiURL, "/v1", "", nil)
   231  	if err != nil {
   232  		t.Fatalf("new api client: %s", err)
   233  	}
   234  
   235  	mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
   236  		w.WriteHeader(http.StatusUnauthorized)
   237  		w.Write([]byte(`{"code": 401, "message" : "brr"}`))
   238  	})
   239  
   240  	_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
   241  	assert.Contains(t, err.Error(), `performing request: API error: brr`)
   242  	log.Printf("err-> %s", err)
   243  }
   244  
   245  func TestNewClientRegisterKO(t *testing.T) {
   246  	apiURL, err := url.Parse("http://127.0.0.1:4242/")
   247  	require.NoError(t, err)
   248  
   249  	_, err = RegisterClient(&Config{
   250  		MachineID:     "test_login",
   251  		Password:      "test_password",
   252  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   253  		URL:           apiURL,
   254  		VersionPrefix: "v1",
   255  	}, &http.Client{})
   256  
   257  	if runtime.GOOS == "windows" {
   258  		cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.")
   259  	} else {
   260  		cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused")
   261  	}
   262  }
   263  
   264  func TestNewClientRegisterOK(t *testing.T) {
   265  	log.SetLevel(log.TraceLevel)
   266  
   267  	mux, urlx, teardown := setup()
   268  	defer teardown()
   269  
   270  	/*mock login*/
   271  	mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
   272  		testMethod(t, r, "POST")
   273  		w.WriteHeader(http.StatusOK)
   274  		w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
   275  	})
   276  
   277  	apiURL, err := url.Parse(urlx + "/")
   278  	require.NoError(t, err)
   279  
   280  	client, err := RegisterClient(&Config{
   281  		MachineID:     "test_login",
   282  		Password:      "test_password",
   283  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   284  		URL:           apiURL,
   285  		VersionPrefix: "v1",
   286  	}, &http.Client{})
   287  	require.NoError(t, err)
   288  
   289  	log.Printf("->%T", client)
   290  }
   291  
   292  func TestNewClientRegisterOK_UnixSocket(t *testing.T) {
   293  	log.SetLevel(log.TraceLevel)
   294  
   295  	tmpDir := t.TempDir()
   296  	socket := path.Join(tmpDir, "socket")
   297  
   298  	mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
   299  	defer teardown()
   300  
   301  	/*mock login*/
   302  	mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
   303  		testMethod(t, r, "POST")
   304  		w.WriteHeader(http.StatusOK)
   305  		w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
   306  	})
   307  
   308  	apiURL, err := url.Parse(urlx)
   309  	if err != nil {
   310  		t.Fatalf("parsing api url: %s", apiURL)
   311  	}
   312  
   313  	client, err := RegisterClient(&Config{
   314  		MachineID:     "test_login",
   315  		Password:      "test_password",
   316  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   317  		URL:           apiURL,
   318  		VersionPrefix: "v1",
   319  	}, &http.Client{})
   320  	if err != nil {
   321  		t.Fatalf("while registering client : %s", err)
   322  	}
   323  
   324  	log.Printf("->%T", client)
   325  }
   326  
   327  func TestNewClientBadAnswer(t *testing.T) {
   328  	log.SetLevel(log.TraceLevel)
   329  
   330  	mux, urlx, teardown := setup()
   331  	defer teardown()
   332  
   333  	/*mock login*/
   334  	mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
   335  		testMethod(t, r, "POST")
   336  		w.WriteHeader(http.StatusUnauthorized)
   337  		w.Write([]byte(`bad`))
   338  	})
   339  
   340  	apiURL, err := url.Parse(urlx + "/")
   341  	require.NoError(t, err)
   342  
   343  	_, err = RegisterClient(&Config{
   344  		MachineID:     "test_login",
   345  		Password:      "test_password",
   346  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   347  		URL:           apiURL,
   348  		VersionPrefix: "v1",
   349  	}, &http.Client{})
   350  	cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value")
   351  }