k8s.io/kubernetes@v1.31.0-alpha.0.0.20240520171757-56147500dadc/pkg/kubelet/server/server_websocket_test.go (about) 1 /* 2 Copyright 2016 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package server 18 19 import ( 20 "encoding/binary" 21 "fmt" 22 "io" 23 "strconv" 24 "sync" 25 "testing" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "golang.org/x/net/websocket" 30 31 "k8s.io/apimachinery/pkg/types" 32 "k8s.io/kubelet/pkg/cri/streaming/portforward" 33 ) 34 35 const ( 36 dataChannel = iota 37 errorChannel 38 ) 39 40 func TestServeWSPortForward(t *testing.T) { 41 tests := map[string]struct { 42 port string 43 uid bool 44 clientData string 45 containerData string 46 shouldError bool 47 }{ 48 "no port": {port: "", shouldError: true}, 49 "none number port": {port: "abc", shouldError: true}, 50 "negative port": {port: "-1", shouldError: true}, 51 "too large port": {port: "65536", shouldError: true}, 52 "0 port": {port: "0", shouldError: true}, 53 "min port": {port: "1", shouldError: false}, 54 "normal port": {port: "8000", shouldError: false}, 55 "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, 56 "max port": {port: "65535", shouldError: false}, 57 "normal port with uid": {port: "8000", uid: true, shouldError: false}, 58 } 59 60 podNamespace := "other" 61 podName := "foo" 62 63 for desc := range tests { 64 test := tests[desc] 65 t.Run(desc, func(t *testing.T) { 66 ss, err := newTestStreamingServer(0) 67 require.NoError(t, err) 68 defer ss.testHTTPServer.Close() 69 fw := newServerTestWithDebug(true, ss) 70 defer fw.testHTTPServer.Close() 71 72 portForwardFuncDone := make(chan struct{}) 73 74 fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { 75 assert.Equal(t, podName, name, "pod name") 76 assert.Equal(t, podNamespace, namespace, "pod namespace") 77 if test.uid { 78 assert.Equal(t, testUID, string(uid), "uid") 79 } 80 } 81 82 ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { 83 defer close(portForwardFuncDone) 84 assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") 85 // The port should be valid if it reaches here. 86 testPort, err := strconv.ParseInt(test.port, 10, 32) 87 require.NoError(t, err, "parse port") 88 assert.Equal(t, int32(testPort), port, "port") 89 90 if test.clientData != "" { 91 fromClient := make([]byte, 32) 92 n, err := stream.Read(fromClient) 93 assert.NoError(t, err, "reading client data") 94 assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data") 95 } 96 97 if test.containerData != "" { 98 _, err := stream.Write([]byte(test.containerData)) 99 assert.NoError(t, err, "writing container data") 100 } 101 102 return nil 103 } 104 105 var url string 106 if test.uid { 107 url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port) 108 } else { 109 url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) 110 } 111 112 ws, err := websocket.Dial(url, "", "http://127.0.0.1/") 113 assert.Equal(t, test.shouldError, err != nil, "websocket dial") 114 if test.shouldError { 115 return 116 } 117 defer ws.Close() 118 119 p, err := strconv.ParseUint(test.port, 10, 16) 120 require.NoError(t, err, "parse port") 121 p16 := uint16(p) 122 123 channel, data, err := wsRead(ws) 124 require.NoError(t, err, "read") 125 assert.Equal(t, dataChannel, int(channel), "channel") 126 assert.Len(t, data, binary.Size(p16), "data size") 127 assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") 128 129 channel, data, err = wsRead(ws) 130 assert.NoError(t, err, "read") 131 assert.Equal(t, errorChannel, int(channel), "channel") 132 assert.Len(t, data, binary.Size(p16), "data size") 133 assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") 134 135 if test.clientData != "" { 136 println("writing the client data") 137 err := wsWrite(ws, dataChannel, []byte(test.clientData)) 138 assert.NoError(t, err, "writing client data") 139 } 140 141 if test.containerData != "" { 142 _, data, err = wsRead(ws) 143 assert.NoError(t, err, "reading container data") 144 assert.Equal(t, test.containerData, string(data), "container data") 145 } 146 147 <-portForwardFuncDone 148 }) 149 } 150 } 151 152 func TestServeWSMultiplePortForward(t *testing.T) { 153 portsText := []string{"7000,8000", "9000"} 154 ports := []uint16{7000, 8000, 9000} 155 podNamespace := "other" 156 podName := "foo" 157 158 ss, err := newTestStreamingServer(0) 159 require.NoError(t, err) 160 defer ss.testHTTPServer.Close() 161 fw := newServerTestWithDebug(true, ss) 162 defer fw.testHTTPServer.Close() 163 164 portForwardWG := sync.WaitGroup{} 165 portForwardWG.Add(len(ports)) 166 167 portsMutex := sync.Mutex{} 168 portsForwarded := map[int32]struct{}{} 169 170 fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { 171 assert.Equal(t, podName, name, "pod name") 172 assert.Equal(t, podNamespace, namespace, "pod namespace") 173 } 174 175 ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { 176 defer portForwardWG.Done() 177 assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") 178 179 portsMutex.Lock() 180 portsForwarded[port] = struct{}{} 181 portsMutex.Unlock() 182 183 fromClient := make([]byte, 32) 184 n, err := stream.Read(fromClient) 185 assert.NoError(t, err, "reading client data") 186 assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data") 187 188 _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) 189 assert.NoError(t, err, "writing container data") 190 191 return nil 192 } 193 194 url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName) 195 for _, port := range portsText { 196 url = url + fmt.Sprintf("port=%s&", port) 197 } 198 199 ws, err := websocket.Dial(url, "", "http://127.0.0.1/") 200 require.NoError(t, err, "websocket dial") 201 202 defer ws.Close() 203 204 for i, port := range ports { 205 channel, data, err := wsRead(ws) 206 assert.NoError(t, err, "port %d read", port) 207 assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) 208 assert.Len(t, data, binary.Size(port), "port %d data size", port) 209 assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) 210 211 channel, data, err = wsRead(ws) 212 assert.NoError(t, err, "port %d read", port) 213 assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port) 214 assert.Len(t, data, binary.Size(port), "port %d data size", port) 215 assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) 216 } 217 218 for i, port := range ports { 219 t.Logf("port %d writing the client data", port) 220 err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) 221 assert.NoError(t, err, "port %d write client data", port) 222 223 channel, data, err := wsRead(ws) 224 assert.NoError(t, err, "port %d read container data", port) 225 assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) 226 assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port) 227 } 228 229 portForwardWG.Wait() 230 231 portsMutex.Lock() 232 defer portsMutex.Unlock() 233 assert.Len(t, portsForwarded, len(ports), "all ports forwarded") 234 } 235 236 func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { 237 frame := make([]byte, len(data)+1) 238 frame[0] = channel 239 copy(frame[1:], data) 240 err := websocket.Message.Send(conn, frame) 241 return err 242 } 243 244 func wsRead(conn *websocket.Conn) (byte, []byte, error) { 245 for { 246 var data []byte 247 err := websocket.Message.Receive(conn, &data) 248 if err != nil { 249 return 0, nil, err 250 } 251 252 if len(data) == 0 { 253 continue 254 } 255 256 channel := data[0] 257 data = data[1:] 258 259 return channel, data, err 260 } 261 }