k8s.io/client-go@v0.31.1/tools/portforward/tunneling_connection_test.go (about) 1 /* 2 Copyright 2024 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package portforward 18 19 import ( 20 "io" 21 "net" 22 "net/http" 23 "net/http/httptest" 24 "net/url" 25 "strings" 26 "testing" 27 "time" 28 29 gwebsocket "github.com/gorilla/websocket" 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 33 "k8s.io/apimachinery/pkg/util/httpstream" 34 "k8s.io/apimachinery/pkg/util/httpstream/spdy" 35 constants "k8s.io/apimachinery/pkg/util/portforward" 36 "k8s.io/apimachinery/pkg/util/wait" 37 "k8s.io/client-go/rest" 38 "k8s.io/client-go/transport/websocket" 39 ) 40 41 func TestTunnelingConnection_ReadWriteClose(t *testing.T) { 42 // Stream channel that will receive streams created on upstream SPDY server. 43 streamChan := make(chan httpstream.Stream) 44 defer close(streamChan) 45 stopServerChan := make(chan struct{}) 46 defer close(stopServerChan) 47 // Create tunneling connection server endpoint with fake upstream SPDY server. 48 tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 49 var upgrader = gwebsocket.Upgrader{ 50 CheckOrigin: func(r *http.Request) bool { return true }, 51 Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, 52 } 53 conn, err := upgrader.Upgrade(w, req, nil) 54 require.NoError(t, err) 55 defer conn.Close() //nolint:errcheck 56 require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) 57 tunnelingConn := NewTunnelingConnection("server", conn) 58 spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan)) 59 require.NoError(t, err) 60 defer spdyConn.Close() //nolint:errcheck 61 <-stopServerChan 62 })) 63 defer tunnelingServer.Close() 64 // Dial the client tunneling connection to the tunneling server. 65 url, err := url.Parse(tunnelingServer.URL) 66 require.NoError(t, err) 67 dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host}) 68 require.NoError(t, err) 69 spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name) 70 require.NoError(t, err) 71 assert.Equal(t, constants.PortForwardV1Name, protocol) 72 defer spdyClient.Close() //nolint:errcheck 73 // Create a SPDY client stream, which will queue a SPDY server stream 74 // on the stream creation channel. Send data on the client stream 75 // reading off the SPDY server stream, and validating it was tunneled. 76 expected := "This is a test tunneling SPDY data through websockets." 77 var actual []byte 78 go func() { 79 clientStream, err := spdyClient.CreateStream(http.Header{}) 80 require.NoError(t, err) 81 _, err = io.Copy(clientStream, strings.NewReader(expected)) 82 require.NoError(t, err) 83 clientStream.Close() //nolint:errcheck 84 }() 85 select { 86 case serverStream := <-streamChan: 87 actual, err = io.ReadAll(serverStream) 88 require.NoError(t, err) 89 defer serverStream.Close() //nolint:errcheck 90 case <-time.After(wait.ForeverTestTimeout): 91 t.Fatalf("timeout waiting for spdy stream to arrive on channel.") 92 } 93 assert.Equal(t, expected, string(actual), "error validating tunneled string") 94 } 95 96 func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) { 97 stopServerChan := make(chan struct{}) 98 defer close(stopServerChan) 99 tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 100 var upgrader = gwebsocket.Upgrader{ 101 CheckOrigin: func(r *http.Request) bool { return true }, 102 Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, 103 } 104 conn, err := upgrader.Upgrade(w, req, nil) 105 require.NoError(t, err) 106 defer conn.Close() //nolint:errcheck 107 require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) 108 <-stopServerChan 109 })) 110 defer tunnelingServer.Close() 111 // Create the client side tunneling connection. 112 url, err := url.Parse(tunnelingServer.URL) 113 require.NoError(t, err) 114 tConn, err := dialForTunnelingConnection(url) 115 require.NoError(t, err, "error creating client tunneling connection") 116 defer tConn.Close() //nolint:errcheck 117 // Validate "LocalAddr()" and "RemoteAddr()" 118 localAddr := tConn.LocalAddr() 119 remoteAddr := tConn.RemoteAddr() 120 assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP") 121 assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP") 122 _, err = net.ResolveTCPAddr("tcp", localAddr.String()) 123 assert.NoError(t, err, "tunneling connection local addr should parse") 124 _, err = net.ResolveTCPAddr("tcp", remoteAddr.String()) 125 assert.NoError(t, err, "tunneling connection remote addr should parse") 126 } 127 128 func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) { 129 stopServerChan := make(chan struct{}) 130 defer close(stopServerChan) 131 tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 132 var upgrader = gwebsocket.Upgrader{ 133 CheckOrigin: func(r *http.Request) bool { return true }, 134 Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, 135 } 136 conn, err := upgrader.Upgrade(w, req, nil) 137 require.NoError(t, err) 138 defer conn.Close() //nolint:errcheck 139 require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) 140 <-stopServerChan 141 })) 142 defer tunnelingServer.Close() 143 // Create the client side tunneling connection. 144 url, err := url.Parse(tunnelingServer.URL) 145 require.NoError(t, err) 146 tConn, err := dialForTunnelingConnection(url) 147 require.NoError(t, err, "error creating client tunneling connection") 148 defer tConn.Close() //nolint:errcheck 149 // Validate the read and write deadlines. 150 err = tConn.SetReadDeadline(time.Time{}) 151 assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") 152 err = tConn.SetWriteDeadline(time.Time{}) 153 assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") 154 err = tConn.SetDeadline(time.Time{}) 155 assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") 156 err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0)) 157 assert.NoError(t, err, "setting deadline 10 year from now succeeds") 158 err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0)) 159 assert.NoError(t, err, "setting deadline 10 year from now succeeds") 160 err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0)) 161 assert.NoError(t, err, "setting deadline 10 year from now succeeds") 162 } 163 164 // dialForTunnelingConnection upgrades a request at the passed "url", creating 165 // a websocket connection. Returns the TunnelingConnection injected with the 166 // websocket connection or an error if one occurs. 167 func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) { 168 req, err := http.NewRequest("GET", url.String(), nil) 169 if err != nil { 170 return nil, err 171 } 172 // Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol. 173 tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1} 174 transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host}) 175 if err != nil { 176 return nil, err 177 } 178 conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...) 179 if err != nil { 180 return nil, err 181 } 182 return NewTunnelingConnection("client", conn), nil 183 } 184 185 func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { 186 return func(stream httpstream.Stream, replySent <-chan struct{}) error { 187 streams <- stream 188 return nil 189 } 190 }