k8s.io/kubernetes@v1.29.3/pkg/client/tests/remotecommand_test.go (about) 1 /* 2 Copyright 2015 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tests 18 19 import ( 20 "bytes" 21 "context" 22 "errors" 23 "fmt" 24 "io" 25 "io/ioutil" 26 "net/http" 27 "net/http/httptest" 28 "net/url" 29 "strings" 30 "testing" 31 "time" 32 33 "github.com/stretchr/testify/require" 34 35 "k8s.io/apimachinery/pkg/runtime" 36 "k8s.io/apimachinery/pkg/runtime/schema" 37 "k8s.io/apimachinery/pkg/types" 38 "k8s.io/apimachinery/pkg/util/httpstream" 39 remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" 40 restclient "k8s.io/client-go/rest" 41 remoteclient "k8s.io/client-go/tools/remotecommand" 42 "k8s.io/client-go/transport/spdy" 43 "k8s.io/kubelet/pkg/cri/streaming/remotecommand" 44 "k8s.io/kubernetes/pkg/api/legacyscheme" 45 api "k8s.io/kubernetes/pkg/apis/core" 46 ) 47 48 type fakeExecutor struct { 49 t *testing.T 50 testName string 51 errorData string 52 stdoutData string 53 stderrData string 54 expectStdin bool 55 stdinReceived bytes.Buffer 56 tty bool 57 messageCount int 58 command []string 59 exec bool 60 } 61 62 func (ex *fakeExecutor) ExecInContainer(_ context.Context, name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize, timeout time.Duration) error { 63 return ex.run(name, uid, container, cmd, in, out, err, tty) 64 } 65 66 func (ex *fakeExecutor) AttachContainer(_ context.Context, name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize) error { 67 return ex.run(name, uid, container, nil, in, out, err, tty) 68 } 69 70 func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error { 71 ex.command = cmd 72 ex.tty = tty 73 74 if e, a := "pod", name; e != a { 75 ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a) 76 } 77 if e, a := "uid", uid; e != string(a) { 78 ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a) 79 } 80 if ex.exec { 81 if e, a := "ls /", strings.Join(ex.command, " "); e != a { 82 ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a) 83 } 84 } else { 85 if len(ex.command) > 0 { 86 ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command) 87 } 88 } 89 90 if len(ex.errorData) > 0 { 91 return errors.New(ex.errorData) 92 } 93 94 if len(ex.stdoutData) > 0 { 95 for i := 0; i < ex.messageCount; i++ { 96 fmt.Fprint(out, ex.stdoutData) 97 } 98 } 99 100 if len(ex.stderrData) > 0 { 101 for i := 0; i < ex.messageCount; i++ { 102 fmt.Fprint(err, ex.stderrData) 103 } 104 } 105 106 if ex.expectStdin { 107 io.Copy(&ex.stdinReceived, in) 108 } 109 110 return nil 111 } 112 113 func fakeServer(t *testing.T, requestReceived chan struct{}, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc { 114 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 115 executor := &fakeExecutor{ 116 t: t, 117 testName: testName, 118 errorData: errorData, 119 stdoutData: stdoutData, 120 stderrData: stderrData, 121 expectStdin: len(stdinData) > 0, 122 tty: tty, 123 messageCount: messageCount, 124 exec: exec, 125 } 126 127 opts, err := remotecommand.NewOptions(req) 128 require.NoError(t, err) 129 if exec { 130 cmd := req.URL.Query()[api.ExecCommandParam] 131 remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", cmd, opts, 0, 10*time.Second, serverProtocols) 132 } else { 133 remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", opts, 0, 10*time.Second, serverProtocols) 134 } 135 136 if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a { 137 t.Errorf("%s: stdin: expected %q, got %q", testName, e, a) 138 } 139 close(requestReceived) 140 }) 141 } 142 143 func TestStream(t *testing.T) { 144 testCases := []struct { 145 TestName string 146 Stdin string 147 Stdout string 148 Stderr string 149 Error string 150 Tty bool 151 MessageCount int 152 ClientProtocols []string 153 ServerProtocols []string 154 }{ 155 { 156 TestName: "error", 157 Error: "bail", 158 Stdout: "a", 159 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 160 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 161 }, 162 { 163 TestName: "in/out/err", 164 Stdin: "a", 165 Stdout: "b", 166 Stderr: "c", 167 MessageCount: 100, 168 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 169 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 170 }, 171 { 172 TestName: "oversized stdin", 173 Stdin: strings.Repeat("a", 20*1024*1024), 174 Stdout: "b", 175 Stderr: "", 176 MessageCount: 1, 177 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 178 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 179 }, 180 { 181 TestName: "in/out/tty", 182 Stdin: "a", 183 Stdout: "b", 184 Tty: true, 185 MessageCount: 100, 186 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 187 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name}, 188 }, 189 } 190 191 for _, testCase := range testCases { 192 for _, exec := range []bool{true, false} { 193 var name string 194 if exec { 195 name = testCase.TestName + " (exec)" 196 } else { 197 name = testCase.TestName + " (attach)" 198 } 199 200 t.Run(name, func(t *testing.T) { 201 var ( 202 streamIn io.Reader 203 streamOut, streamErr io.Writer 204 ) 205 localOut := &bytes.Buffer{} 206 localErr := &bytes.Buffer{} 207 208 requestReceived := make(chan struct{}) 209 server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols)) 210 defer server.Close() 211 212 url, _ := url.ParseRequestURI(server.URL) 213 config := restclient.ClientContentConfig{ 214 GroupVersion: schema.GroupVersion{Group: "x"}, 215 Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}), 216 } 217 c, err := restclient.NewRESTClient(url, "", config, nil, nil) 218 if err != nil { 219 t.Fatalf("failed to create a client: %v", err) 220 } 221 req := c.Post().Resource("testing") 222 223 if exec { 224 req.Param("command", "ls") 225 req.Param("command", "/") 226 } 227 228 if len(testCase.Stdin) > 0 { 229 req.Param(api.ExecStdinParam, "1") 230 streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)) 231 } 232 233 if len(testCase.Stdout) > 0 { 234 req.Param(api.ExecStdoutParam, "1") 235 streamOut = localOut 236 } 237 238 if testCase.Tty { 239 req.Param(api.ExecTTYParam, "1") 240 } else if len(testCase.Stderr) > 0 { 241 req.Param(api.ExecStderrParam, "1") 242 streamErr = localErr 243 } 244 245 conf := &restclient.Config{ 246 Host: server.URL, 247 } 248 transport, upgradeTransport, err := spdy.RoundTripperFor(conf) 249 if err != nil { 250 t.Fatalf("%s: unexpected error: %v", name, err) 251 } 252 e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...) 253 if err != nil { 254 t.Fatalf("%s: unexpected error: %v", name, err) 255 } 256 err = e.StreamWithContext(context.Background(), remoteclient.StreamOptions{ 257 Stdin: streamIn, 258 Stdout: streamOut, 259 Stderr: streamErr, 260 Tty: testCase.Tty, 261 }) 262 hasErr := err != nil 263 264 if len(testCase.Error) > 0 { 265 if !hasErr { 266 t.Errorf("%s: expected an error", name) 267 } else { 268 if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { 269 t.Errorf("%s: expected error stream read %q, got %q", name, e, a) 270 } 271 } 272 return 273 } 274 275 if hasErr { 276 t.Fatalf("%s: unexpected error: %v", name, err) 277 } 278 279 if len(testCase.Stdout) > 0 { 280 if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { 281 t.Fatalf("%s: expected stdout data %q, got %q", name, e, a) 282 } 283 } 284 285 if testCase.Stderr != "" { 286 if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { 287 t.Fatalf("%s: expected stderr data %q, got %q", name, e, a) 288 } 289 } 290 291 select { 292 case <-requestReceived: 293 case <-time.After(time.Minute): 294 t.Errorf("%s: expected fakeServerInstance to receive request", name) 295 } 296 }) 297 } 298 } 299 } 300 301 type fakeUpgrader struct { 302 req *http.Request 303 resp *http.Response 304 conn httpstream.Connection 305 err, connErr error 306 checkResponse bool 307 called bool 308 309 t *testing.T 310 } 311 312 func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) { 313 u.called = true 314 u.req = req 315 return u.resp, u.err 316 } 317 318 func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) { 319 if u.checkResponse && u.resp != resp { 320 u.t.Errorf("response objects passed did not match: %#v", resp) 321 } 322 return u.conn, u.connErr 323 } 324 325 type fakeConnection struct { 326 httpstream.Connection 327 } 328 329 // Dial is the common functionality between any stream based upgrader, regardless of protocol. 330 // This method ensures that someone can use a generic stream executor without being dependent 331 // on the core Kube client config behavior. 332 func TestDial(t *testing.T) { 333 upgrader := &fakeUpgrader{ 334 t: t, 335 checkResponse: true, 336 conn: &fakeConnection{}, 337 resp: &http.Response{ 338 StatusCode: http.StatusSwitchingProtocols, 339 Body: ioutil.NopCloser(&bytes.Buffer{}), 340 }, 341 } 342 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: upgrader}, "POST", &url.URL{Host: "something.com", Scheme: "https"}) 343 conn, protocol, err := dialer.Dial("protocol1") 344 if err != nil { 345 t.Fatal(err) 346 } 347 if conn != upgrader.conn { 348 t.Errorf("unexpected connection: %#v", conn) 349 } 350 if !upgrader.called { 351 t.Errorf("request not called") 352 } 353 _ = protocol 354 }