github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/prompt/context_reader.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package prompt 18 19 import ( 20 "bufio" 21 "context" 22 "errors" 23 "io" 24 "log/slog" 25 "os" 26 "os/signal" 27 "sync" 28 29 "github.com/gravitational/trace" 30 "golang.org/x/term" 31 ) 32 33 // ErrReaderClosed is returned from ContextReader.ReadContext after it is 34 // closed. 35 var ErrReaderClosed = errors.New("ContextReader has been closed") 36 37 // ErrNotTerminal is returned by password reads attempted in non-terminal 38 // readers. 39 var ErrNotTerminal = errors.New("underlying reader is not a terminal") 40 41 const bufferSize = 4096 42 43 type readOutcome struct { 44 value []byte 45 err error 46 } 47 48 type readerState int 49 50 const ( 51 readerStateIdle readerState = iota 52 readerStateClean 53 readerStatePassword 54 readerStateClosed 55 ) 56 57 // termI aggregates methods from golang.org/x/term for easy mocking. 58 type termI interface { 59 GetState(fd int) (*term.State, error) 60 IsTerminal(fd int) bool 61 ReadPassword(fd int) ([]byte, error) 62 Restore(fd int, oldState *term.State) error 63 } 64 65 // gxTerm delegates method calls to golang.org/x/term methods. 
66 type gxTerm struct{} 67 68 func (gxTerm) GetState(fd int) (*term.State, error) { 69 return term.GetState(fd) 70 } 71 72 func (gxTerm) IsTerminal(fd int) bool { 73 return term.IsTerminal(fd) 74 } 75 76 func (gxTerm) ReadPassword(fd int) ([]byte, error) { 77 return term.ReadPassword(fd) 78 } 79 80 func (gxTerm) Restore(fd int, oldState *term.State) error { 81 return term.Restore(fd, oldState) 82 } 83 84 // ContextReader is a wrapper around an underlying io.Reader or terminal that 85 // allows reads to be abandoned. An abandoned read may be reclaimed by future 86 // callers. 87 // ContextReader instances are not safe for concurrent use, callers may block 88 // indefinitely and reads may be lost. 89 type ContextReader struct { 90 term termI 91 92 // reader is used for clean reads. 93 reader io.Reader 94 // fd is used for password reads. 95 // Only present if the underlying reader is a terminal, otherwise set to -1. 96 fd int 97 98 closed chan struct{} 99 reads chan readOutcome 100 101 mu *sync.Mutex 102 cond *sync.Cond 103 previousTermState *term.State 104 state readerState 105 } 106 107 // NewContextReader creates a new ContextReader wrapping rd. 108 // Callers should avoid reading from rd after the ContextReader is used, as 109 // abandoned calls may be in progress. It is safe to read from rd if one can 110 // guarantee that no calls where abandoned. 111 // Calling ContextReader.Close attempts to release resources, but note that 112 // ongoing reads cannot be interrupted. 
113 func NewContextReader(rd io.Reader) *ContextReader { 114 term := gxTerm{} 115 116 fd := -1 117 if f, ok := rd.(*os.File); ok { 118 val := int(f.Fd()) 119 if term.IsTerminal(val) { 120 fd = val 121 } 122 } 123 124 mu := &sync.Mutex{} 125 cond := sync.NewCond(mu) 126 cr := &ContextReader{ 127 term: term, 128 reader: bufio.NewReader(rd), 129 fd: fd, 130 closed: make(chan struct{}), 131 reads: make(chan readOutcome), // unbuffered 132 mu: mu, 133 cond: cond, 134 } 135 go cr.processReads() 136 return cr 137 } 138 139 func (cr *ContextReader) processReads() { 140 defer close(cr.reads) 141 142 for { 143 cr.mu.Lock() 144 for cr.state == readerStateIdle { 145 cr.cond.Wait() 146 } 147 // Stop the reading loop? Once closed, forever closed. 148 if cr.state == readerStateClosed { 149 cr.mu.Unlock() 150 return 151 } 152 // React to the state that took us out of idleness. 153 // We can't hold the lock during the entire read, so we obey the last state 154 // observed. 155 state := cr.state 156 cr.mu.Unlock() 157 158 var value []byte 159 var err error 160 switch state { 161 case readerStateClean: 162 value = make([]byte, bufferSize) 163 var n int 164 n, err = cr.reader.Read(value) 165 value = value[:n] 166 case readerStatePassword: 167 value, err = cr.term.ReadPassword(cr.fd) 168 } 169 cr.mu.Lock() 170 cr.previousTermState = nil // A finalized read resets the terminal. 171 switch cr.state { 172 case readerStateClosed: // Don't transition from closed. 173 default: 174 cr.state = readerStateIdle 175 } 176 cr.mu.Unlock() 177 178 select { 179 case <-cr.closed: 180 slog.WarnContext(context.Background(), "ContextReader closed during ongoing read,", "dropped_bytes", len(value)) 181 return 182 case cr.reads <- readOutcome{value: value, err: err}: 183 } 184 } 185 } 186 187 // IsTerminal returns whether the given reader is a terminal. 
func (cr *ContextReader) IsTerminal() bool {
	return cr.term.IsTerminal(cr.fd)
}

// handleInterrupt restores terminal state on interrupts.
// Called only on global ContextReaders, such as Stdin.
// It keeps listening for os.Interrupt until the reader is closed.
func (cr *ContextReader) handleInterrupt() {
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)
	defer signal.Stop(c)

	for {
		select {
		case sig := <-c:
			slog.DebugContext(context.Background(), "Captured signal attempting to restore terminal state", "signal", sig)
			cr.mu.Lock()
			// Best effort: there is no caller to report a restore failure to.
			_ = cr.maybeRestoreTerm(iAmHoldingTheLock{})
			cr.mu.Unlock()
		case <-cr.closed:
			return
		}
	}
}

// iAmHoldingTheLock exists only to draw attention to the need to hold the lock.
type iAmHoldingTheLock struct{}

// maybeRestoreTerm restores the terminal state saved by firePasswordRead, but
// only while a password read is in progress and a saved state exists. The
// saved state is cleared afterwards, even if restoring fails.
// Lock must be held before calling.
func (cr *ContextReader) maybeRestoreTerm(_ iAmHoldingTheLock) error {
	if cr.state == readerStatePassword && cr.previousTermState != nil {
		err := cr.term.Restore(cr.fd, cr.previousTermState)
		cr.previousTermState = nil
		return trace.Wrap(err)
	}

	return nil
}

// ReadContext returns the next chunk of output from the reader.
// If ctx is canceled before the read completes, the current read is abandoned
// and may be reclaimed by future callers.
// It is not safe to read from the underlying reader after a read is abandoned,
// nor is it safe to concurrently call ReadContext.
func (cr *ContextReader) ReadContext(ctx context.Context) ([]byte, error) {
	if err := cr.fireCleanRead(); err != nil {
		return nil, trace.Wrap(err)
	}

	return cr.waitForRead(ctx)
}

// fireCleanRead asks the reading loop for a clean read, or reuses the read
// already in flight. Any terminal state saved for an ongoing password read is
// restored first. Returns ErrReaderClosed if the reader is closed.
func (cr *ContextReader) fireCleanRead() error {
	cr.mu.Lock()
	defer cr.mu.Unlock()

	// Attempt to restore terminal state, so we transition to a clean read.
	if err := cr.maybeRestoreTerm(iAmHoldingTheLock{}); err != nil {
		return trace.Wrap(err)
	}

	switch cr.state {
	case readerStateIdle: // OK, transition and broadcast.
		cr.state = readerStateClean
		cr.cond.Broadcast()
	case readerStateClean: // OK, ongoing read.
	case readerStatePassword: // OK, ongoing read.
	case readerStateClosed:
		return ErrReaderClosed
	}
	return nil
}

// waitForRead blocks until a read outcome is available, ctx is done, or the
// reader is closed, whichever happens first.
func (cr *ContextReader) waitForRead(ctx context.Context) ([]byte, error) {
	select {
	case <-ctx.Done():
		return nil, trace.Wrap(ctx.Err())
	case <-cr.closed:
		return nil, ErrReaderClosed
	case read, ok := <-cr.reads:
		if !ok {
			// processReads closes cr.reads only after cr.closed is closed. When
			// both are ready, select picks one at random, so a closed cr.reads
			// must also map to ErrReaderClosed rather than a zero outcome
			// (nil value, nil error).
			return nil, ErrReaderClosed
		}
		return read.value, read.err
	}
}

// ReadPassword reads a password from the underlying reader, provided that the
// reader is a terminal.
// It follows the semantics of ReadContext.
func (cr *ContextReader) ReadPassword(ctx context.Context) ([]byte, error) {
	if cr.fd == -1 {
		return nil, ErrNotTerminal
	}
	if err := cr.firePasswordRead(ctx); err != nil {
		return nil, trace.Wrap(err)
	}

	return cr.waitForRead(ctx)
}

// firePasswordRead asks the reading loop for a password read, saving the
// current terminal state first so it can be restored later. A read already in
// flight, of either kind, is reused. Returns ErrReaderClosed if the reader is
// closed.
func (cr *ContextReader) firePasswordRead(ctx context.Context) error {
	cr.mu.Lock()
	defer cr.mu.Unlock()

	switch cr.state {
	case readerStateIdle: // OK, transition and broadcast.
		// Save present terminal state, so it may be restored in case the read goes
		// from password to clean.
		state, err := cr.term.GetState(cr.fd)
		if err != nil {
			return trace.Wrap(err)
		}
		cr.previousTermState = state
		cr.state = readerStatePassword
		cr.cond.Broadcast()
	case readerStateClean: // OK, ongoing clean read.
		// TODO(codingllama): Transition the terminal to password read?
		slog.WarnContext(ctx, "prompt: Clean read reused by password read")
	case readerStatePassword: // OK, ongoing password read.
	case readerStateClosed:
		return ErrReaderClosed
	}
	return nil
}

// Close closes the context reader, attempting to release resources and aborting
// ongoing and future ReadContext calls.
// Background reads that are already blocked cannot be interrupted, thus Close
// doesn't guarantee a release of all resources.
func (cr *ContextReader) Close() error {
	cr.mu.Lock()
	defer cr.mu.Unlock()

	switch cr.state {
	case readerStateClosed: // OK, already closed.
	default:
		// Attempt to restore terminal state on close (best effort).
		_ = cr.maybeRestoreTerm(iAmHoldingTheLock{})

		cr.state = readerStateClosed
		close(cr.closed) // interrupt blocked sends.
		cr.cond.Broadcast()
	}

	return nil
}

// PasswordReader is a ContextReader that reads passwords from the underlying
// terminal.
type PasswordReader ContextReader

// Password returns a PasswordReader from a ContextReader.
// The returned PasswordReader is only functional if the underlying reader is a
// terminal.
func (cr *ContextReader) Password() *PasswordReader {
	return (*PasswordReader)(cr)
}

// ReadContext reads a password from the underlying reader, provided that the
// reader is a terminal. It is equivalent to ContextReader.ReadPassword.
// It follows the semantics of ReadContext.
func (pr *PasswordReader) ReadContext(ctx context.Context) ([]byte, error) {
	cr := (*ContextReader)(pr)
	return cr.ReadPassword(ctx)
}