github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/prompt/context_reader.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package prompt
    18  
    19  import (
    20  	"bufio"
    21  	"context"
    22  	"errors"
    23  	"io"
    24  	"log/slog"
    25  	"os"
    26  	"os/signal"
    27  	"sync"
    28  
    29  	"github.com/gravitational/trace"
    30  	"golang.org/x/term"
    31  )
    32  
    33  // ErrReaderClosed is returned from ContextReader.ReadContext after it is
    34  // closed.
    35  var ErrReaderClosed = errors.New("ContextReader has been closed")
    36  
    37  // ErrNotTerminal is returned by password reads attempted in non-terminal
    38  // readers.
    39  var ErrNotTerminal = errors.New("underlying reader is not a terminal")
    40  
    41  const bufferSize = 4096
    42  
    43  type readOutcome struct {
    44  	value []byte
    45  	err   error
    46  }
    47  
// readerState describes what the background read loop (processReads) has been
// asked to do. The zero value is readerStateIdle.
type readerState int

const (
	// readerStateIdle: no read requested; processReads waits on cond.
	readerStateIdle readerState = iota
	// readerStateClean: a plain read from ContextReader.reader was requested
	// or is in progress.
	readerStateClean
	// readerStatePassword: a terminal password read was requested or is in
	// progress.
	readerStatePassword
	// readerStateClosed: the reader was closed; no code path transitions out
	// of this state.
	readerStateClosed
)
    56  
    57  // termI aggregates methods from golang.org/x/term for easy mocking.
    58  type termI interface {
    59  	GetState(fd int) (*term.State, error)
    60  	IsTerminal(fd int) bool
    61  	ReadPassword(fd int) ([]byte, error)
    62  	Restore(fd int, oldState *term.State) error
    63  }
    64  
    65  // gxTerm delegates method calls to golang.org/x/term methods.
    66  type gxTerm struct{}
    67  
    68  func (gxTerm) GetState(fd int) (*term.State, error) {
    69  	return term.GetState(fd)
    70  }
    71  
    72  func (gxTerm) IsTerminal(fd int) bool {
    73  	return term.IsTerminal(fd)
    74  }
    75  
    76  func (gxTerm) ReadPassword(fd int) ([]byte, error) {
    77  	return term.ReadPassword(fd)
    78  }
    79  
    80  func (gxTerm) Restore(fd int, oldState *term.State) error {
    81  	return term.Restore(fd, oldState)
    82  }
    83  
// ContextReader is a wrapper around an underlying io.Reader or terminal that
// allows reads to be abandoned. An abandoned read may be reclaimed by future
// callers.
// ContextReader instances are not safe for concurrent use, callers may block
// indefinitely and reads may be lost.
type ContextReader struct {
	// term provides terminal operations (gxTerm when built by NewContextReader).
	term termI

	// reader is used for clean reads.
	reader io.Reader
	// fd is used for password reads.
	// Only present if the underlying reader is a terminal, otherwise set to -1.
	fd int

	// closed is closed by Close to unblock processReads, waitForRead and
	// handleInterrupt. reads carries outcomes from processReads to waitForRead
	// and is closed when processReads exits.
	closed chan struct{}
	reads  chan readOutcome

	// mu guards state and previousTermState; cond (bound to mu) wakes
	// processReads when state changes. previousTermState is the terminal state
	// saved before a password read, used by maybeRestoreTerm.
	mu                *sync.Mutex
	cond              *sync.Cond
	previousTermState *term.State
	state             readerState
}
   106  
   107  // NewContextReader creates a new ContextReader wrapping rd.
   108  // Callers should avoid reading from rd after the ContextReader is used, as
   109  // abandoned calls may be in progress. It is safe to read from rd if one can
   110  // guarantee that no calls where abandoned.
   111  // Calling ContextReader.Close attempts to release resources, but note that
   112  // ongoing reads cannot be interrupted.
   113  func NewContextReader(rd io.Reader) *ContextReader {
   114  	term := gxTerm{}
   115  
   116  	fd := -1
   117  	if f, ok := rd.(*os.File); ok {
   118  		val := int(f.Fd())
   119  		if term.IsTerminal(val) {
   120  			fd = val
   121  		}
   122  	}
   123  
   124  	mu := &sync.Mutex{}
   125  	cond := sync.NewCond(mu)
   126  	cr := &ContextReader{
   127  		term:   term,
   128  		reader: bufio.NewReader(rd),
   129  		fd:     fd,
   130  		closed: make(chan struct{}),
   131  		reads:  make(chan readOutcome), // unbuffered
   132  		mu:     mu,
   133  		cond:   cond,
   134  	}
   135  	go cr.processReads()
   136  	return cr
   137  }
   138  
// processReads is the background loop started by NewContextReader. It performs
// the reads requested through fireCleanRead and firePasswordRead.
// Each iteration does the following:
//   - It waits on cr.cond until the state leaves readerStateIdle.
//   - It performs one clean or password read, according to the state observed.
//   - It resets the state to idle, unless the reader is closed.
//   - It hands the outcome to a caller over the unbuffered cr.reads channel.
//
// The loop exits once the reader is closed, closing cr.reads on the way out.
// A read that is already blocked cannot be interrupted. If Close happens while
// the read is in progress, its outcome is dropped with a warning.
func (cr *ContextReader) processReads() {
	defer close(cr.reads)

	for {
		cr.mu.Lock()
		for cr.state == readerStateIdle {
			cr.cond.Wait()
		}
		// Stop the reading loop? Once closed, forever closed.
		if cr.state == readerStateClosed {
			cr.mu.Unlock()
			return
		}
		// React to the state that took us out of idleness.
		// We can't hold the lock during the entire read, so we obey the last state
		// observed.
		state := cr.state
		cr.mu.Unlock()

		// The read runs without cr.mu held. Meanwhile, callers may fire further
		// requests (which reuse this read) or call Close.
		var value []byte
		var err error
		switch state {
		case readerStateClean:
			value = make([]byte, bufferSize)
			var n int
			n, err = cr.reader.Read(value)
			value = value[:n]
		case readerStatePassword:
			value, err = cr.term.ReadPassword(cr.fd)
		}
		cr.mu.Lock()
		cr.previousTermState = nil // A finalized read resets the terminal.
		switch cr.state {
		case readerStateClosed: // Don't transition from closed.
		default:
			// Back to idle before delivery. A new request fired from now on is
			// observed by the next iteration, after the outcome below is taken.
			cr.state = readerStateIdle
		}
		cr.mu.Unlock()

		// Block until a caller takes the outcome or the reader is closed.
		select {
		case <-cr.closed:
			slog.WarnContext(context.Background(), "ContextReader closed during ongoing read,", "dropped_bytes", len(value))
			return
		case cr.reads <- readOutcome{value: value, err: err}:
		}
	}
}
   186  
   187  // IsTerminal returns whether the given reader is a terminal.
   188  func (cr *ContextReader) IsTerminal() bool {
   189  	return cr.term.IsTerminal(cr.fd)
   190  }
   191  
   192  // handleInterrupt restores terminal state on interrupts.
   193  // Called only on global ContextReaders, such as Stdin.
   194  func (cr *ContextReader) handleInterrupt() {
   195  	c := make(chan os.Signal, 1)
   196  	signal.Notify(c, os.Interrupt)
   197  	defer signal.Stop(c)
   198  
   199  	for {
   200  		select {
   201  		case sig := <-c:
   202  			slog.DebugContext(context.Background(), "Captured signal attempting to restore terminal state", "signal", sig)
   203  			cr.mu.Lock()
   204  			_ = cr.maybeRestoreTerm(iAmHoldingTheLock{})
   205  			cr.mu.Unlock()
   206  		case <-cr.closed:
   207  			return
   208  		}
   209  	}
   210  }
   211  
   212  // iAmHoldingTheLock exists only to draw attention to the need to hold the lock.
   213  type iAmHoldingTheLock struct{}
   214  
   215  // maybeRestoreTerm attempts to restore terminal state.
   216  // Lock must be held before calling.
   217  func (cr *ContextReader) maybeRestoreTerm(_ iAmHoldingTheLock) error {
   218  	if cr.state == readerStatePassword && cr.previousTermState != nil {
   219  		err := cr.term.Restore(cr.fd, cr.previousTermState)
   220  		cr.previousTermState = nil
   221  		return trace.Wrap(err)
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  // ReadContext returns the next chunk of output from the reader.
   228  // If ctx is canceled before the read completes, the current read is abandoned
   229  // and may be reclaimed by future callers.
   230  // It is not safe to read from the underlying reader after a read is abandoned,
   231  // nor is it safe to concurrently call ReadContext.
   232  func (cr *ContextReader) ReadContext(ctx context.Context) ([]byte, error) {
   233  	if err := cr.fireCleanRead(); err != nil {
   234  		return nil, trace.Wrap(err)
   235  	}
   236  
   237  	return cr.waitForRead(ctx)
   238  }
   239  
   240  func (cr *ContextReader) fireCleanRead() error {
   241  	cr.mu.Lock()
   242  	defer cr.mu.Unlock()
   243  
   244  	// Atempt to restore terminal state, so we transition to a clean read.
   245  	if err := cr.maybeRestoreTerm(iAmHoldingTheLock{}); err != nil {
   246  		return trace.Wrap(err)
   247  	}
   248  
   249  	switch cr.state {
   250  	case readerStateIdle: // OK, transition and broadcast.
   251  		cr.state = readerStateClean
   252  		cr.cond.Broadcast()
   253  	case readerStateClean: // OK, ongoing read.
   254  	case readerStatePassword: // OK, ongoing read.
   255  	case readerStateClosed:
   256  		return ErrReaderClosed
   257  	}
   258  	return nil
   259  }
   260  
   261  func (cr *ContextReader) waitForRead(ctx context.Context) ([]byte, error) {
   262  	select {
   263  	case <-ctx.Done():
   264  		return nil, trace.Wrap(ctx.Err())
   265  	case <-cr.closed:
   266  		return nil, ErrReaderClosed
   267  	case read := <-cr.reads:
   268  		return read.value, read.err
   269  	}
   270  }
   271  
   272  // ReadPassword reads a password from the underlying reader, provided that the
   273  // reader is a terminal.
   274  // It follows the semantics of ReadContext.
   275  func (cr *ContextReader) ReadPassword(ctx context.Context) ([]byte, error) {
   276  	if cr.fd == -1 {
   277  		return nil, ErrNotTerminal
   278  	}
   279  	if err := cr.firePasswordRead(ctx); err != nil {
   280  		return nil, trace.Wrap(err)
   281  	}
   282  
   283  	return cr.waitForRead(ctx)
   284  }
   285  
   286  func (cr *ContextReader) firePasswordRead(ctx context.Context) error {
   287  	cr.mu.Lock()
   288  	defer cr.mu.Unlock()
   289  
   290  	switch cr.state {
   291  	case readerStateIdle: // OK, transition and broadcast.
   292  		// Save present terminal state, so it may be restored in case the read goes
   293  		// from password to clean.
   294  		state, err := cr.term.GetState(cr.fd)
   295  		if err != nil {
   296  			return trace.Wrap(err)
   297  		}
   298  		cr.previousTermState = state
   299  		cr.state = readerStatePassword
   300  		cr.cond.Broadcast()
   301  	case readerStateClean: // OK, ongoing clean read.
   302  		// TODO(codingllama): Transition the terminal to password read?
   303  		slog.WarnContext(ctx, "prompt: Clean read reused by password read")
   304  	case readerStatePassword: // OK, ongoing password read.
   305  	case readerStateClosed:
   306  		return ErrReaderClosed
   307  	}
   308  	return nil
   309  }
   310  
   311  // Close closes the context reader, attempting to release resources and aborting
   312  // ongoing and future ReadContext calls.
   313  // Background reads that are already blocked cannot be interrupted, thus Close
   314  // doesn't guarantee a release of all resources.
   315  func (cr *ContextReader) Close() error {
   316  	cr.mu.Lock()
   317  	defer cr.mu.Unlock()
   318  
   319  	switch cr.state {
   320  	case readerStateClosed: // OK, already closed.
   321  	default:
   322  		// Attempt to restore terminal state on close.
   323  		_ = cr.maybeRestoreTerm(iAmHoldingTheLock{})
   324  
   325  		cr.state = readerStateClosed
   326  		close(cr.closed) // interrupt blocked sends.
   327  		cr.cond.Broadcast()
   328  	}
   329  
   330  	return nil
   331  }
   332  
   333  // PasswordReader is a ContextReader that reads passwords from the underlying
   334  // terminal.
   335  type PasswordReader ContextReader
   336  
   337  // Password returns a PasswordReader from a ContextReader.
   338  // The returned PasswordReader is only functional if the underlying reader is a
   339  // terminal.
   340  func (cr *ContextReader) Password() *PasswordReader {
   341  	return (*PasswordReader)(cr)
   342  }
   343  
   344  // ReadContext reads a password from the underlying reader, provided that the
   345  // reader is a terminal. It is equivalent to ContextReader.ReadPassword.
   346  // It follows the semantics of ReadContext.
   347  func (pr *PasswordReader) ReadContext(ctx context.Context) ([]byte, error) {
   348  	cr := (*ContextReader)(pr)
   349  	return cr.ReadPassword(ctx)
   350  }