k8s.io/client-go@v0.31.1/tools/remotecommand/v2_test.go (about) 1 /* 2 Copyright 2016 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package remotecommand 18 19 import ( 20 "errors" 21 "io" 22 "net/http" 23 "strings" 24 "testing" 25 "time" 26 27 "k8s.io/api/core/v1" 28 "k8s.io/apimachinery/pkg/util/httpstream" 29 "k8s.io/apimachinery/pkg/util/wait" 30 ) 31 32 type fakeReader struct { 33 err error 34 } 35 36 func (r *fakeReader) Read([]byte) (int, error) { return 0, r.err } 37 38 type fakeWriter struct{} 39 40 func (*fakeWriter) Write([]byte) (int, error) { return 0, nil } 41 42 type fakeStreamCreator struct { 43 created map[string]bool 44 errors map[string]error 45 } 46 47 var _ streamCreator = &fakeStreamCreator{} 48 49 func (f *fakeStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) { 50 streamType := headers.Get(v1.StreamType) 51 f.created[streamType] = true 52 return nil, f.errors[streamType] 53 } 54 55 func TestV2CreateStreams(t *testing.T) { 56 tests := []struct { 57 name string 58 stdin bool 59 stdinError error 60 stdout bool 61 stdoutError error 62 stderr bool 63 stderrError error 64 errorError error 65 tty bool 66 expectError bool 67 }{ 68 { 69 name: "stdin error", 70 stdin: true, 71 stdinError: errors.New("stdin error"), 72 expectError: true, 73 }, 74 { 75 name: "stdout error", 76 stdout: true, 77 stdoutError: errors.New("stdout error"), 78 expectError: true, 79 }, 80 { 81 name: "stderr error", 82 stderr: true, 83 stderrError: errors.New("stderr error"), 84 expectError: true, 85 }, 86 { 87 name: "error stream error", 88 stdin: true, 89 stdout: true, 90 stderr: true, 91 errorError: errors.New("error stream error"), 92 expectError: true, 93 }, 94 { 95 name: "no errors", 96 stdin: true, 97 stdout: true, 98 stderr: true, 99 expectError: false, 100 }, 101 { 102 name: "no errors, stderr & tty set, don't expect stderr", 103 stdin: true, 104 stdout: true, 105 stderr: true, 106 tty: true, 107 expectError: false, 108 }, 109 } 110 for _, test := range tests { 111 conn := &fakeStreamCreator{ 112 created: make(map[string]bool), 113 errors: map[string]error{ 114 v1.StreamTypeStdin: test.stdinError, 115 v1.StreamTypeStdout: test.stdoutError, 116 v1.StreamTypeStderr: test.stderrError, 117 v1.StreamTypeError: test.errorError, 118 }, 119 } 120 121 opts := StreamOptions{Tty: test.tty} 122 if test.stdin { 123 opts.Stdin = &fakeReader{} 124 } 125 if test.stdout { 126 opts.Stdout = &fakeWriter{} 127 } 128 if test.stderr { 129 opts.Stderr = &fakeWriter{} 130 } 131 132 h := newStreamProtocolV2(opts).(*streamProtocolV2) 133 err := h.createStreams(conn) 134 135 if test.expectError { 136 if err == nil { 137 t.Errorf("%s: expected error", test.name) 138 continue 139 } 140 if e, a := test.stdinError, err; test.stdinError != nil && e != a { 141 t.Errorf("%s: expected %v, got %v", test.name, e, a) 142 } 143 if e, a := test.stdoutError, err; test.stdoutError != nil && e != a { 144 t.Errorf("%s: expected %v, got %v", test.name, e, a) 145 } 146 if e, a := test.stderrError, err; test.stderrError != nil && e != a { 147 t.Errorf("%s: expected %v, got %v", test.name, e, a) 148 } 149 if e, a := test.errorError, err; test.errorError != nil && e != a { 150 t.Errorf("%s: expected %v, got %v", test.name, e, a) 151 } 152 continue 153 } 154 155 if !test.expectError && err != nil { 156 t.Errorf("%s: unexpected error: %v", test.name, err) 157 continue 158 } 159 160 if test.stdin && !conn.created[v1.StreamTypeStdin] { 161 t.Errorf("%s: expected stdin stream", test.name) 162 } 163 if test.stdout && !conn.created[v1.StreamTypeStdout] { 164 t.Errorf("%s: expected stdout stream", test.name) 165 } 166 if test.stderr { 167 if test.tty && conn.created[v1.StreamTypeStderr] { 168 t.Errorf("%s: unexpected stderr stream because tty is set", test.name) 169 } else if !test.tty && !conn.created[v1.StreamTypeStderr] { 170 t.Errorf("%s: expected stderr stream", test.name) 171 } 172 } 173 if !conn.created[v1.StreamTypeError] { 174 t.Errorf("%s: expected error stream", test.name) 175 } 176 177 } 178 } 179 180 func TestV2ErrorStreamReading(t *testing.T) { 181 tests := []struct { 182 name string 183 stream io.Reader 184 expectedError error 185 }{ 186 { 187 name: "error reading from stream", 188 stream: &fakeReader{errors.New("foo")}, 189 expectedError: errors.New("error reading from error stream: foo"), 190 }, 191 { 192 name: "stream returns an error", 193 stream: strings.NewReader("some error"), 194 expectedError: errors.New("error executing remote command: some error"), 195 }, 196 } 197 198 for _, test := range tests { 199 h := newStreamProtocolV2(StreamOptions{}).(*streamProtocolV2) 200 h.errorStream = test.stream 201 202 ch := watchErrorStream(h.errorStream, &errorDecoderV2{}) 203 if ch == nil { 204 t.Fatalf("%s: unexpected nil channel", test.name) 205 } 206 207 var err error 208 select { 209 case err = <-ch: 210 case <-time.After(wait.ForeverTestTimeout): 211 t.Fatalf("%s: timed out", test.name) 212 } 213 214 if test.expectedError != nil { 215 if err == nil { 216 t.Errorf("%s: expected an error", test.name) 217 } else if e, a := test.expectedError, err; e.Error() != a.Error() { 218 t.Errorf("%s: expected %q, got %q", test.name, e, a) 219 } 220 continue 221 } 222 223 if test.expectedError == nil && err != nil { 224 t.Errorf("%s: unexpected error: %v", test.name, err) 225 continue 226 } 227 } 228 }