github.com/argoproj/argo-cd/v3@v3.2.1/server/application/websocket_test.go (about)

     1  package application
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  
    11  	corev1 "k8s.io/api/core/v1"
    12  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    13  	"k8s.io/client-go/kubernetes/fake"
    14  
    15  	"github.com/argoproj/argo-cd/v3/common"
    16  	"github.com/argoproj/argo-cd/v3/util/assets"
    17  	"github.com/argoproj/argo-cd/v3/util/rbac"
    18  
    19  	"github.com/golang-jwt/jwt/v5"
    20  	"github.com/gorilla/websocket"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  func newTestTerminalSession(w http.ResponseWriter, r *http.Request) terminalSession {
    26  	upgrader := websocket.Upgrader{}
    27  	c, err := upgrader.Upgrade(w, r, nil)
    28  	if err != nil {
    29  		return terminalSession{}
    30  	}
    31  
    32  	return terminalSession{wsConn: c}
    33  }
    34  
    35  func newEnforcer() *rbac.Enforcer {
    36  	additionalConfig := make(map[string]string, 0)
    37  	kubeclientset := fake.NewClientset(&corev1.ConfigMap{
    38  		ObjectMeta: metav1.ObjectMeta{
    39  			Namespace: testNamespace,
    40  			Name:      "argocd-cm",
    41  			Labels: map[string]string{
    42  				"app.kubernetes.io/part-of": "argocd",
    43  			},
    44  		},
    45  		Data: additionalConfig,
    46  	}, &corev1.Secret{
    47  		ObjectMeta: metav1.ObjectMeta{
    48  			Name:      "argocd-secret",
    49  			Namespace: testNamespace,
    50  		},
    51  		Data: map[string][]byte{
    52  			"admin.password":   []byte("test"),
    53  			"server.secretkey": []byte("test"),
    54  		},
    55  	})
    56  
    57  	enforcer := rbac.NewEnforcer(kubeclientset, testNamespace, common.ArgoCDRBACConfigMapName, nil)
    58  	return enforcer
    59  }
    60  
    61  func reconnect(w http.ResponseWriter, r *http.Request) {
    62  	ts := newTestTerminalSession(w, r)
    63  	_, _ = ts.reconnect()
    64  }
    65  
    66  func TestReconnect(t *testing.T) {
    67  	s := httptest.NewServer(http.HandlerFunc(reconnect))
    68  	defer s.Close()
    69  
    70  	u := "ws" + strings.TrimPrefix(s.URL, "http")
    71  
    72  	// Connect to the server
    73  	ws, _, err := websocket.DefaultDialer.Dial(u, nil)
    74  	require.NoError(t, err)
    75  
    76  	defer ws.Close()
    77  
    78  	_, p, _ := ws.ReadMessage()
    79  
    80  	var message TerminalMessage
    81  
    82  	err = json.Unmarshal(p, &message)
    83  
    84  	require.NoError(t, err)
    85  	assert.Equal(t, ReconnectMessage, message.Data)
    86  }
    87  
    88  func testServerConnection(t *testing.T, testFunc func(w http.ResponseWriter, r *http.Request), expectPermissionDenied bool) {
    89  	t.Helper()
    90  	s := httptest.NewServer(http.HandlerFunc(testFunc))
    91  	defer s.Close()
    92  
    93  	u := "ws" + strings.TrimPrefix(s.URL, "http")
    94  
    95  	// Connect to the server
    96  	ws, _, err := websocket.DefaultDialer.Dial(u, nil)
    97  	require.NoError(t, err)
    98  
    99  	defer ws.Close()
   100  	if expectPermissionDenied {
   101  		_, p, _ := ws.ReadMessage()
   102  
   103  		var message TerminalMessage
   104  
   105  		err = json.Unmarshal(p, &message)
   106  
   107  		require.NoError(t, err)
   108  		assert.Equal(t, "Permission denied", message.Data)
   109  	}
   110  }
   111  
   112  func TestVerifyAndReconnectDisableAuthTrue(t *testing.T) {
   113  	validate := func(w http.ResponseWriter, r *http.Request) {
   114  		ts := newTestTerminalSession(w, r)
   115  		// Currently testing only the usecase of disableAuth: true since the disableAuth: false case
   116  		// requires a valid token to be passed in the request.
   117  		// Note that running with disableAuth: false will surprisingly succeed as well, because
   118  		// the underlying token nil pointer dereference is swallowed in a location I didn't find,
   119  		// or even swallowed by the test framework.
   120  		ts.terminalOpts = &TerminalOptions{DisableAuth: true}
   121  		code, err := ts.performValidationsAndReconnect([]byte{})
   122  		assert.Equal(t, 0, code)
   123  		require.NoError(t, err)
   124  	}
   125  	testServerConnection(t, validate, false)
   126  }
   127  
   128  func TestValidateWithAdminPermissions(t *testing.T) {
   129  	validate := func(w http.ResponseWriter, r *http.Request) {
   130  		enf := newEnforcer()
   131  		_ = enf.SetBuiltinPolicy(assets.BuiltinPolicyCSV)
   132  		enf.SetDefaultRole("role:admin")
   133  		enf.SetClaimsEnforcerFunc(func(_ jwt.Claims, _ ...any) bool {
   134  			return true
   135  		})
   136  		ts := newTestTerminalSession(w, r)
   137  		ts.terminalOpts = &TerminalOptions{Enf: enf}
   138  		ts.appRBACName = "test"
   139  		//nolint:staticcheck
   140  		ts.ctx = context.WithValue(t.Context(), "claims", &jwt.MapClaims{"groups": []string{"admin"}})
   141  		_, err := ts.validatePermissions([]byte{})
   142  		require.NoError(t, err)
   143  	}
   144  
   145  	testServerConnection(t, validate, false)
   146  }
   147  
   148  func TestValidateWithoutPermissions(t *testing.T) {
   149  	validate := func(w http.ResponseWriter, r *http.Request) {
   150  		enf := newEnforcer()
   151  		_ = enf.SetBuiltinPolicy(assets.BuiltinPolicyCSV)
   152  		enf.SetDefaultRole("role:test")
   153  		enf.SetClaimsEnforcerFunc(func(_ jwt.Claims, _ ...any) bool {
   154  			return false
   155  		})
   156  		ts := newTestTerminalSession(w, r)
   157  		ts.terminalOpts = &TerminalOptions{Enf: enf}
   158  		ts.appRBACName = "test"
   159  		//nolint:staticcheck
   160  		ts.ctx = context.WithValue(t.Context(), "claims", &jwt.MapClaims{"groups": []string{"test"}})
   161  		_, err := ts.validatePermissions([]byte{})
   162  		require.Error(t, err)
   163  		assert.EqualError(t, err, common.PermissionDeniedAPIError.Error())
   164  	}
   165  
   166  	testServerConnection(t, validate, true)
   167  }
   168  
   169  func TestTerminalSession_Write(t *testing.T) {
   170  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   171  		upgrader := websocket.Upgrader{}
   172  		conn, err := upgrader.Upgrade(w, r, nil)
   173  		require.NoError(t, err)
   174  		defer conn.Close()
   175  
   176  		for {
   177  			// Read the message from the WebSocket connection
   178  			messageType, message, err := conn.ReadMessage()
   179  			if err != nil {
   180  				return
   181  			}
   182  			// Respond back the same message
   183  			err = conn.WriteMessage(messageType, message)
   184  			require.NoError(t, err)
   185  		}
   186  	}))
   187  	defer server.Close()
   188  
   189  	u := "ws" + strings.TrimPrefix(server.URL, "http")
   190  	wsConn, _, err := websocket.DefaultDialer.Dial(u, nil)
   191  	require.NoError(t, err)
   192  	defer wsConn.Close()
   193  
   194  	ts := terminalSession{
   195  		wsConn: wsConn,
   196  	}
   197  
   198  	testData := []byte("hello world")
   199  	expectedMessage, err := json.Marshal(TerminalMessage{
   200  		Operation: "stdout",
   201  		Data:      string(testData),
   202  	})
   203  	require.NoError(t, err)
   204  
   205  	n, err := ts.Write(testData)
   206  	require.NoError(t, err)
   207  
   208  	assert.Equal(t, len(testData), n)
   209  
   210  	_, receivedMessage, err := wsConn.ReadMessage()
   211  	require.NoError(t, err)
   212  
   213  	assert.Equal(t, expectedMessage, receivedMessage)
   214  }