github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/prompt/context_reader_test.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package prompt 18 19 import ( 20 "context" 21 "io" 22 "os" 23 "testing" 24 "time" 25 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "golang.org/x/term" 29 ) 30 31 func TestContextReader(t *testing.T) { 32 pr, pw := io.Pipe() 33 t.Cleanup(func() { pr.Close() }) 34 t.Cleanup(func() { pw.Close() }) 35 36 write := func(t *testing.T, s string) { 37 t.Helper() 38 _, err := pw.Write([]byte(s)) 39 assert.NoError(t, err, "Write failed") 40 } 41 42 ctx := context.Background() 43 cr := NewContextReader(pr) 44 45 t.Run("simple read", func(t *testing.T) { 46 go write(t, "hello") 47 buf, err := cr.ReadContext(ctx) 48 require.NoError(t, err) 49 require.Equal(t, "hello", string(buf)) 50 }) 51 52 t.Run("reclaim abandoned read", func(t *testing.T) { 53 done := make(chan struct{}) 54 cancelCtx, cancel := context.WithCancel(ctx) 55 go func() { 56 time.Sleep(1 * time.Millisecond) // give ReadContext time to block 57 cancel() 58 write(t, "after cancel") 59 close(done) 60 }() 61 buf, err := cr.ReadContext(cancelCtx) 62 require.ErrorIs(t, err, context.Canceled) 63 require.Empty(t, buf) 64 65 <-done // wait for write 66 buf, err = cr.ReadContext(ctx) 67 require.NoError(t, err) 68 require.Equal(t, "after cancel", string(buf)) 69 }) 70 71 t.Run("close ContextReader", func(t *testing.T) { 72 go func() { 73 time.Sleep(1 * time.Millisecond) // give ReadContext time to block 74 assert.NoError(t, cr.Close(), "Close errored") 75 }() 76 _, err := cr.ReadContext(ctx) 77 require.ErrorIs(t, err, ErrReaderClosed) 78 79 // Subsequent reads fail. 80 _, err = cr.ReadContext(ctx) 81 require.ErrorIs(t, err, ErrReaderClosed) 82 83 // Ongoing read after Close is dropped. 84 write(t, "unblock goroutine") 85 buf, err := cr.ReadContext(ctx) 86 assert.ErrorIs(t, err, ErrReaderClosed) 87 assert.Empty(t, buf, "buf not empty") 88 89 // Multiple closes are fine. 90 assert.NoError(t, cr.Close(), "2nd Close failed") 91 }) 92 93 // Re-creating is safe because the tests above leave no "pending" reads. 94 cr = NewContextReader(pr) 95 96 t.Run("close underlying reader", func(t *testing.T) { 97 go func() { 98 write(t, "before close") 99 pw.CloseWithError(io.EOF) 100 }() 101 102 // Read the last chunk of data successfully. 103 buf, err := cr.ReadContext(ctx) 104 require.NoError(t, err) 105 require.Equal(t, "before close", string(buf)) 106 107 // Next read fails because underlying reader is closed. 108 buf, err = cr.ReadContext(ctx) 109 require.ErrorIs(t, err, io.EOF) 110 require.Empty(t, buf) 111 }) 112 } 113 114 func TestContextReader_ReadPassword(t *testing.T) { 115 pr, pw := io.Pipe() 116 write := func(t *testing.T, s string) { 117 t.Helper() 118 _, err := pw.Write([]byte(s)) 119 assert.NoError(t, err, "Write failed") 120 } 121 122 devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666) 123 require.NoError(t, err, "Failed to open %v", os.DevNull) 124 defer devNull.Close() 125 126 term := &fakeTerm{reader: pr} 127 cr := NewContextReader(pr) 128 cr.term = term 129 cr.fd = int(devNull.Fd()) // arbitrary, doesn't matter because term functions are mocked. 130 131 ctx := context.Background() 132 t.Run("read password", func(t *testing.T) { 133 const want = "llama45" 134 go write(t, want) 135 136 got, err := cr.ReadPassword(ctx) 137 require.NoError(t, err, "ReadPassword failed") 138 assert.Equal(t, want, string(got), "ReadPassword mismatch") 139 }) 140 141 t.Run("intertwine reads", func(t *testing.T) { 142 const want1 = "hello, world" 143 go write(t, want1) 144 got, err := cr.ReadPassword(ctx) 145 require.NoError(t, err, "ReadPassword failed") 146 assert.Equal(t, want1, string(got), "ReadPassword mismatch") 147 148 const want2 = "goodbye, world" 149 go write(t, want2) 150 got, err = cr.ReadContext(ctx) 151 require.NoError(t, err, "ReadContext failed") 152 assert.Equal(t, want2, string(got), "ReadContext mismatch") 153 }) 154 155 t.Run("password read turned clean", func(t *testing.T) { 156 require.False(t, term.restoreCalled, "restoreCalled sanity check failed") 157 158 // Give ReadPassword time to block. 159 cancelCtx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) 160 defer cancel() 161 got, err := cr.ReadPassword(cancelCtx) 162 require.ErrorIs(t, err, context.DeadlineExceeded, "ReadPassword returned unexpected error") 163 require.Empty(t, got, "ReadPassword mismatch") 164 165 // Reclaim as clean read. 166 const want = "abandoned pwd read" 167 go func() { 168 // Once again, give ReadContext time to block. 169 // This way we force a restore. 170 time.Sleep(1 * time.Millisecond) 171 write(t, want) 172 }() 173 got, err = cr.ReadContext(ctx) 174 require.NoError(t, err, "ReadContext failed") 175 assert.Equal(t, want, string(got), "ReadContext mismatch") 176 assert.True(t, term.restoreCalled, "term.Restore not called") 177 }) 178 179 t.Run("Close", func(t *testing.T) { 180 require.NoError(t, cr.Close(), "Close errored") 181 182 _, err := cr.ReadPassword(ctx) 183 require.ErrorIs(t, err, ErrReaderClosed, "ReadPassword returned unexpected error") 184 }) 185 } 186 187 func TestNotifyExit_restoresTerminal(t *testing.T) { 188 oldStdin := Stdin() 189 t.Cleanup(func() { SetStdin(oldStdin) }) 190 191 pr, _ := io.Pipe() 192 193 devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666) 194 require.NoError(t, err, "Failed to open %v", os.DevNull) 195 defer devNull.Close() 196 197 term := &fakeTerm{reader: pr} 198 ctx := context.Background() 199 200 tests := []struct { 201 name string 202 doRead func(ctx context.Context, cr *ContextReader) error 203 wantRestore bool 204 }{ 205 { 206 name: "no pending read", 207 doRead: func(ctx context.Context, cr *ContextReader) error { 208 <-ctx.Done() 209 return ctx.Err() 210 }, 211 }, 212 { 213 name: "pending clean read", 214 doRead: func(ctx context.Context, cr *ContextReader) error { 215 _, err := cr.ReadContext(ctx) 216 return err 217 }, 218 }, 219 { 220 name: "pending password read", 221 doRead: func(ctx context.Context, cr *ContextReader) error { 222 _, err := cr.ReadPassword(ctx) 223 return err 224 }, 225 wantRestore: true, 226 }, 227 } 228 for _, test := range tests { 229 t.Run(test.name, func(t *testing.T) { 230 term.restoreCalled = false // reset state between tests 231 232 cr := NewContextReader(pr) 233 cr.term = term 234 cr.fd = int(devNull.Fd()) // arbitrary 235 SetStdin(cr) 236 237 // Give the read time to block. 238 ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) 239 defer cancel() 240 err := test.doRead(ctx, cr) 241 require.ErrorIs(t, err, context.DeadlineExceeded, "unexpected read error") 242 243 NotifyExit() // closes Stdin 244 assert.Equal(t, test.wantRestore, term.restoreCalled, "term.Restore mismatch") 245 }) 246 } 247 } 248 249 type fakeTerm struct { 250 reader io.Reader 251 restoreCalled bool 252 } 253 254 func (t *fakeTerm) GetState(fd int) (*term.State, error) { 255 return &term.State{}, nil 256 } 257 258 func (t *fakeTerm) IsTerminal(fd int) bool { 259 return true 260 } 261 262 func (t *fakeTerm) ReadPassword(fd int) ([]byte, error) { 263 const bufLen = 1024 // arbitrary, big enough for test data 264 data := make([]byte, bufLen) 265 n, err := t.reader.Read(data) 266 data = data[:n] 267 return data, err 268 } 269 270 func (t *fakeTerm) Restore(fd int, oldState *term.State) error { 271 t.restoreCalled = true 272 return nil 273 }