/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tests

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/types"
	restclient "k8s.io/client-go/rest"
	. "k8s.io/client-go/tools/portforward"
	"k8s.io/client-go/transport/spdy"
	"k8s.io/kubelet/pkg/cri/streaming/portforward"
)

// fakePortForwarder simulates port forwarding for testing. It implements
// portforward.PortForwarder.
43 type fakePortForwarder struct { 44 lock sync.Mutex 45 // stores data expected from the stream per port 46 expected map[int32]string 47 // stores data received from the stream per port 48 received map[int32]string 49 // data to be sent to the stream per port 50 send map[int32]string 51 } 52 53 var _ portforward.PortForwarder = &fakePortForwarder{} 54 55 func (pf *fakePortForwarder) PortForward(_ context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { 56 defer stream.Close() 57 58 // read from the client 59 received := make([]byte, len(pf.expected[port])) 60 n, err := stream.Read(received) 61 if err != nil { 62 return fmt.Errorf("error reading from client for port %d: %v", port, err) 63 } 64 if n != len(pf.expected[port]) { 65 return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received)) 66 } 67 68 // store the received content 69 pf.lock.Lock() 70 pf.received[port] = string(received) 71 pf.lock.Unlock() 72 73 // send the hardcoded data to the client 74 io.Copy(stream, strings.NewReader(pf.send[port])) 75 76 return nil 77 } 78 79 // fakePortForwardServer creates an HTTP server that can handle port forwarding 80 // requests. 
81 func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc { 82 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 83 pf := &fakePortForwarder{ 84 expected: expectedFromClient, 85 received: make(map[int32]string), 86 send: serverSends, 87 } 88 portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols) 89 90 for port, expected := range expectedFromClient { 91 actual, ok := pf.received[port] 92 if !ok { 93 t.Errorf("%s: server didn't receive any data for port %d", testName, port) 94 continue 95 } 96 97 if expected != actual { 98 t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port) 99 } 100 } 101 102 for port, actual := range pf.received { 103 if _, ok := expectedFromClient[port]; !ok { 104 t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port) 105 } 106 } 107 }) 108 } 109 110 func TestForwardPorts(t *testing.T) { 111 tests := map[string]struct { 112 ports []string 113 clientSends map[int32]string 114 serverSends map[int32]string 115 }{ 116 "forward 1 port with no data either direction": { 117 ports: []string{":5000"}, 118 }, 119 "forward 2 ports with bidirectional data": { 120 ports: []string{":5001", ":6000"}, 121 clientSends: map[int32]string{ 122 5001: "abcd", 123 6000: "ghij", 124 }, 125 serverSends: map[int32]string{ 126 5001: "1234", 127 6000: "5678", 128 }, 129 }, 130 } 131 132 for testName, test := range tests { 133 t.Run(testName, func(t *testing.T) { 134 server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) 135 defer server.Close() 136 137 transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) 138 if err != nil { 139 t.Fatal(err) 140 } 141 url, _ := url.Parse(server.URL) 142 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) 143 144 stopChan := 
make(chan struct{}, 1) 145 readyChan := make(chan struct{}) 146 147 pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr) 148 if err != nil { 149 t.Fatalf("%s: unexpected error calling New: %v", testName, err) 150 } 151 152 doneChan := make(chan error) 153 go func() { 154 doneChan <- pf.ForwardPorts() 155 }() 156 <-pf.Ready 157 158 forwardedPorts, err := pf.GetPorts() 159 if err != nil { 160 t.Fatal(err) 161 } 162 163 remoteToLocalMap := map[int32]int32{} 164 for _, forwardedPort := range forwardedPorts { 165 remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local) 166 } 167 168 clientSend := func(port int32, data string) error { 169 clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port])) 170 if err != nil { 171 return fmt.Errorf("%s: error dialing %d: %s", testName, port, err) 172 173 } 174 defer clientConn.Close() 175 176 n, err := clientConn.Write([]byte(data)) 177 if err != nil && err != io.EOF { 178 return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err) 179 } 180 if n == 0 { 181 return fmt.Errorf("%s: unexpected write of 0 bytes", testName) 182 } 183 b := make([]byte, 4) 184 _, err = clientConn.Read(b) 185 if err != nil && err != io.EOF { 186 return fmt.Errorf("%s: Error reading data: %s", testName, err) 187 } 188 if !bytes.Equal([]byte(test.serverSends[port]), b) { 189 return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) 190 } 191 return nil 192 } 193 for port, data := range test.clientSends { 194 if err := clientSend(port, data); err != nil { 195 t.Error(err) 196 } 197 } 198 // tell r.ForwardPorts to stop 199 close(stopChan) 200 201 // wait for r.ForwardPorts to actually return 202 err = <-doneChan 203 if err != nil { 204 t.Errorf("%s: unexpected error: %s", testName, err) 205 } 206 }) 207 } 208 209 } 210 211 func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { 212 server := 
httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil)) 213 defer server.Close() 214 215 transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) 216 if err != nil { 217 t.Fatal(err) 218 } 219 url, _ := url.Parse(server.URL) 220 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) 221 222 stopChan1 := make(chan struct{}, 1) 223 defer close(stopChan1) 224 readyChan1 := make(chan struct{}) 225 226 pf1, err := New(dialer, []string{":5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr) 227 if err != nil { 228 t.Fatalf("error creating pf1: %v", err) 229 } 230 go pf1.ForwardPorts() 231 <-pf1.Ready 232 233 forwardedPorts, err := pf1.GetPorts() 234 if err != nil { 235 t.Fatal(err) 236 } 237 if len(forwardedPorts) != 1 { 238 t.Fatalf("expected 1 forwarded port, got %#v", forwardedPorts) 239 } 240 duplicateSpec := fmt.Sprintf("%d:%d", forwardedPorts[0].Local, forwardedPorts[0].Remote) 241 242 stopChan2 := make(chan struct{}, 1) 243 readyChan2 := make(chan struct{}) 244 pf2, err := New(dialer, []string{duplicateSpec}, stopChan2, readyChan2, os.Stdout, os.Stderr) 245 if err != nil { 246 t.Fatalf("error creating pf2: %v", err) 247 } 248 if err := pf2.ForwardPorts(); err == nil { 249 t.Fatal("expected non-nil error for pf2.ForwardPorts") 250 } 251 }