github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiclient/auth_service_test.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"testing"
    12  
    13  	log "github.com/sirupsen/logrus"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/crowdsecurity/go-cs-lib/version"
    18  
    19  	"github.com/crowdsecurity/crowdsec/pkg/models"
    20  )
    21  
    22  type BasicMockPayload struct {
    23  	MachineID string `json:"machine_id"`
    24  	Password  string `json:"password"`
    25  }
    26  
    27  func getLoginsForMockErrorCases() map[string]int {
    28  	return map[string]int{
    29  		"login_400": http.StatusBadRequest,
    30  		"login_409": http.StatusConflict,
    31  		"login_500": http.StatusInternalServerError,
    32  	}
    33  }
    34  
    35  func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
    36  	loginsForMockErrorCases := getLoginsForMockErrorCases()
    37  
    38  	mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
    39  		testMethod(t, r, "POST")
    40  		buf := new(bytes.Buffer)
    41  		_, _ = buf.ReadFrom(r.Body)
    42  		newStr := buf.String()
    43  
    44  		var payload BasicMockPayload
    45  		err := json.Unmarshal([]byte(newStr), &payload)
    46  		if err != nil || payload.MachineID == "" || payload.Password == "" {
    47  			log.Printf("Bad payload")
    48  			w.WriteHeader(http.StatusBadRequest)
    49  		}
    50  
    51  		var responseBody string
    52  		responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID]
    53  
    54  		if !hasFoundErrorMock {
    55  			responseCode = http.StatusOK
    56  			responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`
    57  		} else {
    58  			responseBody = fmt.Sprintf("Error %d", responseCode)
    59  		}
    60  
    61  		log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode)
    62  
    63  		w.WriteHeader(responseCode)
    64  		fmt.Fprintf(w, `%s`, responseBody)
    65  	})
    66  }
    67  
    68  /**
    69   * Test the RegisterClient function
    70   * Making sure it handles the different response code potentially coming from CAPI properly
    71   * 200 => OK
    72   * 400, 409, 500 => Error
    73   */
    74  func TestWatcherRegister(t *testing.T) {
    75  	log.SetLevel(log.DebugLevel)
    76  
    77  	mux, urlx, teardown := setup()
    78  	defer teardown()
    79  
    80  	//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
    81  	initBasicMuxMock(t, mux, "/watchers")
    82  	log.Printf("URL is %s", urlx)
    83  
    84  	apiURL, err := url.Parse(urlx + "/")
    85  	require.NoError(t, err)
    86  
    87  	// Valid Registration : should retrieve the client and no err
    88  	clientconfig := Config{
    89  		MachineID:     "test_login",
    90  		Password:      "test_password",
    91  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
    92  		URL:           apiURL,
    93  		VersionPrefix: "v1",
    94  	}
    95  
    96  	client, err := RegisterClient(&clientconfig, &http.Client{})
    97  	require.NoError(t, err)
    98  
    99  	log.Printf("->%T", client)
   100  
   101  	// Testing error handling on Registration (400, 409, 500): should retrieve an error
   102  	errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError}
   103  	for _, errorCodeToTest := range errorCodesToTest {
   104  		clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
   105  
   106  		client, err = RegisterClient(&clientconfig, &http.Client{})
   107  		require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest)
   108  		require.Error(t, err, "error expected for the response code %d", errorCodeToTest)
   109  	}
   110  }
   111  
   112  func TestWatcherAuth(t *testing.T) {
   113  	log.SetLevel(log.DebugLevel)
   114  
   115  	mux, urlx, teardown := setup()
   116  	defer teardown()
   117  	//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
   118  
   119  	initBasicMuxMock(t, mux, "/watchers/login")
   120  	log.Printf("URL is %s", urlx)
   121  
   122  	apiURL, err := url.Parse(urlx + "/")
   123  	require.NoError(t, err)
   124  
   125  	//ok auth
   126  	clientConfig := &Config{
   127  		MachineID:     "test_login",
   128  		Password:      "test_password",
   129  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   130  		URL:           apiURL,
   131  		VersionPrefix: "v1",
   132  		Scenarios:     []string{"crowdsecurity/test"},
   133  	}
   134  
   135  	client, err := NewClient(clientConfig)
   136  	require.NoError(t, err)
   137  
   138  	_, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
   139  		MachineID: &clientConfig.MachineID,
   140  		Password:  &clientConfig.Password,
   141  		Scenarios: clientConfig.Scenarios,
   142  	})
   143  	require.NoError(t, err)
   144  
   145  	// Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error
   146  	// Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array
   147  	errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict}
   148  	for _, errorCodeToTest := range errorCodesToTest {
   149  		clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
   150  
   151  		client, err := NewClient(clientConfig)
   152  		require.NoError(t, err)
   153  
   154  		_, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
   155  			MachineID: &clientConfig.MachineID,
   156  			Password:  &clientConfig.Password,
   157  		})
   158  
   159  		if err == nil {
   160  			resp.Response.Body.Close()
   161  
   162  			bodyBytes, err := io.ReadAll(resp.Response.Body)
   163  			require.NoError(t, err)
   164  
   165  			log.Printf(string(bodyBytes))
   166  			t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
   167  		}
   168  
   169  		log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
   170  	}
   171  }
   172  
   173  func TestWatcherUnregister(t *testing.T) {
   174  	log.SetLevel(log.DebugLevel)
   175  
   176  	mux, urlx, teardown := setup()
   177  	defer teardown()
   178  	//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
   179  
   180  	mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
   181  		testMethod(t, r, "DELETE")
   182  		assert.Equal(t, int64(0), r.ContentLength)
   183  		w.WriteHeader(http.StatusOK)
   184  	})
   185  
   186  	mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
   187  		testMethod(t, r, "POST")
   188  		buf := new(bytes.Buffer)
   189  		_, _ = buf.ReadFrom(r.Body)
   190  
   191  		newStr := buf.String()
   192  		if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]}
   193  ` {
   194  			w.WriteHeader(http.StatusOK)
   195  			fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
   196  		} else {
   197  			w.WriteHeader(http.StatusForbidden)
   198  			fmt.Fprintf(w, `{"message":"access forbidden"}`)
   199  		}
   200  	})
   201  
   202  	log.Printf("URL is %s", urlx)
   203  
   204  	apiURL, err := url.Parse(urlx + "/")
   205  	require.NoError(t, err)
   206  
   207  	mycfg := &Config{
   208  		MachineID:     "test_login",
   209  		Password:      "test_password",
   210  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   211  		URL:           apiURL,
   212  		VersionPrefix: "v1",
   213  		Scenarios:     []string{"crowdsecurity/test"},
   214  	}
   215  
   216  	client, err := NewClient(mycfg)
   217  	require.NoError(t, err)
   218  
   219  	_, err = client.Auth.UnregisterWatcher(context.Background())
   220  	require.NoError(t, err)
   221  
   222  	log.Printf("->%T", client)
   223  }
   224  
   225  func TestWatcherEnroll(t *testing.T) {
   226  	log.SetLevel(log.DebugLevel)
   227  
   228  	mux, urlx, teardown := setup()
   229  	defer teardown()
   230  
   231  	mux.HandleFunc("/watchers/enroll", func(w http.ResponseWriter, r *http.Request) {
   232  		testMethod(t, r, "POST")
   233  		buf := new(bytes.Buffer)
   234  		_, _ = buf.ReadFrom(r.Body)
   235  		newStr := buf.String()
   236  		log.Debugf("body -> %s", newStr)
   237  
   238  		if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false}
   239  ` {
   240  			log.Print("good key")
   241  			w.WriteHeader(http.StatusOK)
   242  			fmt.Fprintf(w, `{"statusCode": 200, "message": "OK"}`)
   243  		} else {
   244  			log.Print("bad key")
   245  			w.WriteHeader(http.StatusForbidden)
   246  			fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`)
   247  		}
   248  	})
   249  
   250  	mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
   251  		testMethod(t, r, "POST")
   252  		w.WriteHeader(http.StatusOK)
   253  		fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
   254  	})
   255  
   256  	log.Printf("URL is %s", urlx)
   257  
   258  	apiURL, err := url.Parse(urlx + "/")
   259  	require.NoError(t, err)
   260  
   261  	mycfg := &Config{
   262  		MachineID:     "test_login",
   263  		Password:      "test_password",
   264  		UserAgent:     fmt.Sprintf("crowdsec/%s", version.String()),
   265  		URL:           apiURL,
   266  		VersionPrefix: "v1",
   267  		Scenarios:     []string{"crowdsecurity/test"},
   268  	}
   269  
   270  	client, err := NewClient(mycfg)
   271  	require.NoError(t, err)
   272  
   273  	_, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false)
   274  	require.NoError(t, err)
   275  
   276  	_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
   277  	assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error())
   278  }