github.com/argoproj/argo-cd/v3@v3.2.1/server/application/websocket_test.go (about) 1 package application 2 3 import ( 4 "context" 5 "encoding/json" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 11 corev1 "k8s.io/api/core/v1" 12 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 13 "k8s.io/client-go/kubernetes/fake" 14 15 "github.com/argoproj/argo-cd/v3/common" 16 "github.com/argoproj/argo-cd/v3/util/assets" 17 "github.com/argoproj/argo-cd/v3/util/rbac" 18 19 "github.com/golang-jwt/jwt/v5" 20 "github.com/gorilla/websocket" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 ) 24 25 func newTestTerminalSession(w http.ResponseWriter, r *http.Request) terminalSession { 26 upgrader := websocket.Upgrader{} 27 c, err := upgrader.Upgrade(w, r, nil) 28 if err != nil { 29 return terminalSession{} 30 } 31 32 return terminalSession{wsConn: c} 33 } 34 35 func newEnforcer() *rbac.Enforcer { 36 additionalConfig := make(map[string]string, 0) 37 kubeclientset := fake.NewClientset(&corev1.ConfigMap{ 38 ObjectMeta: metav1.ObjectMeta{ 39 Namespace: testNamespace, 40 Name: "argocd-cm", 41 Labels: map[string]string{ 42 "app.kubernetes.io/part-of": "argocd", 43 }, 44 }, 45 Data: additionalConfig, 46 }, &corev1.Secret{ 47 ObjectMeta: metav1.ObjectMeta{ 48 Name: "argocd-secret", 49 Namespace: testNamespace, 50 }, 51 Data: map[string][]byte{ 52 "admin.password": []byte("test"), 53 "server.secretkey": []byte("test"), 54 }, 55 }) 56 57 enforcer := rbac.NewEnforcer(kubeclientset, testNamespace, common.ArgoCDRBACConfigMapName, nil) 58 return enforcer 59 } 60 61 func reconnect(w http.ResponseWriter, r *http.Request) { 62 ts := newTestTerminalSession(w, r) 63 _, _ = ts.reconnect() 64 } 65 66 func TestReconnect(t *testing.T) { 67 s := httptest.NewServer(http.HandlerFunc(reconnect)) 68 defer s.Close() 69 70 u := "ws" + strings.TrimPrefix(s.URL, "http") 71 72 // Connect to the server 73 ws, _, err := websocket.DefaultDialer.Dial(u, nil) 74 require.NoError(t, err) 75 76 defer ws.Close() 77 78 _, p, _ := ws.ReadMessage() 79 80 var message TerminalMessage 81 82 err = json.Unmarshal(p, &message) 83 84 require.NoError(t, err) 85 assert.Equal(t, ReconnectMessage, message.Data) 86 } 87 88 func testServerConnection(t *testing.T, testFunc func(w http.ResponseWriter, r *http.Request), expectPermissionDenied bool) { 89 t.Helper() 90 s := httptest.NewServer(http.HandlerFunc(testFunc)) 91 defer s.Close() 92 93 u := "ws" + strings.TrimPrefix(s.URL, "http") 94 95 // Connect to the server 96 ws, _, err := websocket.DefaultDialer.Dial(u, nil) 97 require.NoError(t, err) 98 99 defer ws.Close() 100 if expectPermissionDenied { 101 _, p, _ := ws.ReadMessage() 102 103 var message TerminalMessage 104 105 err = json.Unmarshal(p, &message) 106 107 require.NoError(t, err) 108 assert.Equal(t, "Permission denied", message.Data) 109 } 110 } 111 112 func TestVerifyAndReconnectDisableAuthTrue(t *testing.T) { 113 validate := func(w http.ResponseWriter, r *http.Request) { 114 ts := newTestTerminalSession(w, r) 115 // Currently testing only the usecase of disableAuth: true since the disableAuth: false case 116 // requires a valid token to be passed in the request. 117 // Note that running with disableAuth: false will surprisingly succeed as well, because 118 // the underlying token nil pointer dereference is swallowed in a location I didn't find, 119 // or even swallowed by the test framework. 120 ts.terminalOpts = &TerminalOptions{DisableAuth: true} 121 code, err := ts.performValidationsAndReconnect([]byte{}) 122 assert.Equal(t, 0, code) 123 require.NoError(t, err) 124 } 125 testServerConnection(t, validate, false) 126 } 127 128 func TestValidateWithAdminPermissions(t *testing.T) { 129 validate := func(w http.ResponseWriter, r *http.Request) { 130 enf := newEnforcer() 131 _ = enf.SetBuiltinPolicy(assets.BuiltinPolicyCSV) 132 enf.SetDefaultRole("role:admin") 133 enf.SetClaimsEnforcerFunc(func(_ jwt.Claims, _ ...any) bool { 134 return true 135 }) 136 ts := newTestTerminalSession(w, r) 137 ts.terminalOpts = &TerminalOptions{Enf: enf} 138 ts.appRBACName = "test" 139 //nolint:staticcheck 140 ts.ctx = context.WithValue(t.Context(), "claims", &jwt.MapClaims{"groups": []string{"admin"}}) 141 _, err := ts.validatePermissions([]byte{}) 142 require.NoError(t, err) 143 } 144 145 testServerConnection(t, validate, false) 146 } 147 148 func TestValidateWithoutPermissions(t *testing.T) { 149 validate := func(w http.ResponseWriter, r *http.Request) { 150 enf := newEnforcer() 151 _ = enf.SetBuiltinPolicy(assets.BuiltinPolicyCSV) 152 enf.SetDefaultRole("role:test") 153 enf.SetClaimsEnforcerFunc(func(_ jwt.Claims, _ ...any) bool { 154 return false 155 }) 156 ts := newTestTerminalSession(w, r) 157 ts.terminalOpts = &TerminalOptions{Enf: enf} 158 ts.appRBACName = "test" 159 //nolint:staticcheck 160 ts.ctx = context.WithValue(t.Context(), "claims", &jwt.MapClaims{"groups": []string{"test"}}) 161 _, err := ts.validatePermissions([]byte{}) 162 require.Error(t, err) 163 assert.EqualError(t, err, common.PermissionDeniedAPIError.Error()) 164 } 165 166 testServerConnection(t, validate, true) 167 } 168 169 func TestTerminalSession_Write(t *testing.T) { 170 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 171 upgrader := websocket.Upgrader{} 172 conn, err := upgrader.Upgrade(w, r, nil) 173 require.NoError(t, err) 174 defer conn.Close() 175 176 for { 177 // Read the message from the WebSocket connection 178 messageType, message, err := conn.ReadMessage() 179 if err != nil { 180 return 181 } 182 // Respond back the same message 183 err = conn.WriteMessage(messageType, message) 184 require.NoError(t, err) 185 } 186 })) 187 defer server.Close() 188 189 u := "ws" + strings.TrimPrefix(server.URL, "http") 190 wsConn, _, err := websocket.DefaultDialer.Dial(u, nil) 191 require.NoError(t, err) 192 defer wsConn.Close() 193 194 ts := terminalSession{ 195 wsConn: wsConn, 196 } 197 198 testData := []byte("hello world") 199 expectedMessage, err := json.Marshal(TerminalMessage{ 200 Operation: "stdout", 201 Data: string(testData), 202 }) 203 require.NoError(t, err) 204 205 n, err := ts.Write(testData) 206 require.NoError(t, err) 207 208 assert.Equal(t, len(testData), n) 209 210 _, receivedMessage, err := wsConn.ReadMessage() 211 require.NoError(t, err) 212 213 assert.Equal(t, expectedMessage, receivedMessage) 214 }