github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/prompt/context_reader_test.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package prompt
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"os"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"golang.org/x/term"
    29  )
    30  
    31  func TestContextReader(t *testing.T) {
    32  	pr, pw := io.Pipe()
    33  	t.Cleanup(func() { pr.Close() })
    34  	t.Cleanup(func() { pw.Close() })
    35  
    36  	write := func(t *testing.T, s string) {
    37  		t.Helper()
    38  		_, err := pw.Write([]byte(s))
    39  		assert.NoError(t, err, "Write failed")
    40  	}
    41  
    42  	ctx := context.Background()
    43  	cr := NewContextReader(pr)
    44  
    45  	t.Run("simple read", func(t *testing.T) {
    46  		go write(t, "hello")
    47  		buf, err := cr.ReadContext(ctx)
    48  		require.NoError(t, err)
    49  		require.Equal(t, "hello", string(buf))
    50  	})
    51  
    52  	t.Run("reclaim abandoned read", func(t *testing.T) {
    53  		done := make(chan struct{})
    54  		cancelCtx, cancel := context.WithCancel(ctx)
    55  		go func() {
    56  			time.Sleep(1 * time.Millisecond) // give ReadContext time to block
    57  			cancel()
    58  			write(t, "after cancel")
    59  			close(done)
    60  		}()
    61  		buf, err := cr.ReadContext(cancelCtx)
    62  		require.ErrorIs(t, err, context.Canceled)
    63  		require.Empty(t, buf)
    64  
    65  		<-done // wait for write
    66  		buf, err = cr.ReadContext(ctx)
    67  		require.NoError(t, err)
    68  		require.Equal(t, "after cancel", string(buf))
    69  	})
    70  
    71  	t.Run("close ContextReader", func(t *testing.T) {
    72  		go func() {
    73  			time.Sleep(1 * time.Millisecond) // give ReadContext time to block
    74  			assert.NoError(t, cr.Close(), "Close errored")
    75  		}()
    76  		_, err := cr.ReadContext(ctx)
    77  		require.ErrorIs(t, err, ErrReaderClosed)
    78  
    79  		// Subsequent reads fail.
    80  		_, err = cr.ReadContext(ctx)
    81  		require.ErrorIs(t, err, ErrReaderClosed)
    82  
    83  		// Ongoing read after Close is dropped.
    84  		write(t, "unblock goroutine")
    85  		buf, err := cr.ReadContext(ctx)
    86  		assert.ErrorIs(t, err, ErrReaderClosed)
    87  		assert.Empty(t, buf, "buf not empty")
    88  
    89  		// Multiple closes are fine.
    90  		assert.NoError(t, cr.Close(), "2nd Close failed")
    91  	})
    92  
    93  	// Re-creating is safe because the tests above leave no "pending" reads.
    94  	cr = NewContextReader(pr)
    95  
    96  	t.Run("close underlying reader", func(t *testing.T) {
    97  		go func() {
    98  			write(t, "before close")
    99  			pw.CloseWithError(io.EOF)
   100  		}()
   101  
   102  		// Read the last chunk of data successfully.
   103  		buf, err := cr.ReadContext(ctx)
   104  		require.NoError(t, err)
   105  		require.Equal(t, "before close", string(buf))
   106  
   107  		// Next read fails because underlying reader is closed.
   108  		buf, err = cr.ReadContext(ctx)
   109  		require.ErrorIs(t, err, io.EOF)
   110  		require.Empty(t, buf)
   111  	})
   112  }
   113  
   114  func TestContextReader_ReadPassword(t *testing.T) {
   115  	pr, pw := io.Pipe()
   116  	write := func(t *testing.T, s string) {
   117  		t.Helper()
   118  		_, err := pw.Write([]byte(s))
   119  		assert.NoError(t, err, "Write failed")
   120  	}
   121  
   122  	devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666)
   123  	require.NoError(t, err, "Failed to open %v", os.DevNull)
   124  	defer devNull.Close()
   125  
   126  	term := &fakeTerm{reader: pr}
   127  	cr := NewContextReader(pr)
   128  	cr.term = term
   129  	cr.fd = int(devNull.Fd()) // arbitrary, doesn't matter because term functions are mocked.
   130  
   131  	ctx := context.Background()
   132  	t.Run("read password", func(t *testing.T) {
   133  		const want = "llama45"
   134  		go write(t, want)
   135  
   136  		got, err := cr.ReadPassword(ctx)
   137  		require.NoError(t, err, "ReadPassword failed")
   138  		assert.Equal(t, want, string(got), "ReadPassword mismatch")
   139  	})
   140  
   141  	t.Run("intertwine reads", func(t *testing.T) {
   142  		const want1 = "hello, world"
   143  		go write(t, want1)
   144  		got, err := cr.ReadPassword(ctx)
   145  		require.NoError(t, err, "ReadPassword failed")
   146  		assert.Equal(t, want1, string(got), "ReadPassword mismatch")
   147  
   148  		const want2 = "goodbye, world"
   149  		go write(t, want2)
   150  		got, err = cr.ReadContext(ctx)
   151  		require.NoError(t, err, "ReadContext failed")
   152  		assert.Equal(t, want2, string(got), "ReadContext mismatch")
   153  	})
   154  
   155  	t.Run("password read turned clean", func(t *testing.T) {
   156  		require.False(t, term.restoreCalled, "restoreCalled sanity check failed")
   157  
   158  		// Give ReadPassword time to block.
   159  		cancelCtx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
   160  		defer cancel()
   161  		got, err := cr.ReadPassword(cancelCtx)
   162  		require.ErrorIs(t, err, context.DeadlineExceeded, "ReadPassword returned unexpected error")
   163  		require.Empty(t, got, "ReadPassword mismatch")
   164  
   165  		// Reclaim as clean read.
   166  		const want = "abandoned pwd read"
   167  		go func() {
   168  			// Once again, give ReadContext time to block.
   169  			// This way we force a restore.
   170  			time.Sleep(1 * time.Millisecond)
   171  			write(t, want)
   172  		}()
   173  		got, err = cr.ReadContext(ctx)
   174  		require.NoError(t, err, "ReadContext failed")
   175  		assert.Equal(t, want, string(got), "ReadContext mismatch")
   176  		assert.True(t, term.restoreCalled, "term.Restore not called")
   177  	})
   178  
   179  	t.Run("Close", func(t *testing.T) {
   180  		require.NoError(t, cr.Close(), "Close errored")
   181  
   182  		_, err := cr.ReadPassword(ctx)
   183  		require.ErrorIs(t, err, ErrReaderClosed, "ReadPassword returned unexpected error")
   184  	})
   185  }
   186  
   187  func TestNotifyExit_restoresTerminal(t *testing.T) {
   188  	oldStdin := Stdin()
   189  	t.Cleanup(func() { SetStdin(oldStdin) })
   190  
   191  	pr, _ := io.Pipe()
   192  
   193  	devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666)
   194  	require.NoError(t, err, "Failed to open %v", os.DevNull)
   195  	defer devNull.Close()
   196  
   197  	term := &fakeTerm{reader: pr}
   198  	ctx := context.Background()
   199  
   200  	tests := []struct {
   201  		name        string
   202  		doRead      func(ctx context.Context, cr *ContextReader) error
   203  		wantRestore bool
   204  	}{
   205  		{
   206  			name: "no pending read",
   207  			doRead: func(ctx context.Context, cr *ContextReader) error {
   208  				<-ctx.Done()
   209  				return ctx.Err()
   210  			},
   211  		},
   212  		{
   213  			name: "pending clean read",
   214  			doRead: func(ctx context.Context, cr *ContextReader) error {
   215  				_, err := cr.ReadContext(ctx)
   216  				return err
   217  			},
   218  		},
   219  		{
   220  			name: "pending password read",
   221  			doRead: func(ctx context.Context, cr *ContextReader) error {
   222  				_, err := cr.ReadPassword(ctx)
   223  				return err
   224  			},
   225  			wantRestore: true,
   226  		},
   227  	}
   228  	for _, test := range tests {
   229  		t.Run(test.name, func(t *testing.T) {
   230  			term.restoreCalled = false // reset state between tests
   231  
   232  			cr := NewContextReader(pr)
   233  			cr.term = term
   234  			cr.fd = int(devNull.Fd()) // arbitrary
   235  			SetStdin(cr)
   236  
   237  			// Give the read time to block.
   238  			ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
   239  			defer cancel()
   240  			err := test.doRead(ctx, cr)
   241  			require.ErrorIs(t, err, context.DeadlineExceeded, "unexpected read error")
   242  
   243  			NotifyExit() // closes Stdin
   244  			assert.Equal(t, test.wantRestore, term.restoreCalled, "term.Restore mismatch")
   245  		})
   246  	}
   247  }
   248  
   249  type fakeTerm struct {
   250  	reader        io.Reader
   251  	restoreCalled bool
   252  }
   253  
   254  func (t *fakeTerm) GetState(fd int) (*term.State, error) {
   255  	return &term.State{}, nil
   256  }
   257  
   258  func (t *fakeTerm) IsTerminal(fd int) bool {
   259  	return true
   260  }
   261  
   262  func (t *fakeTerm) ReadPassword(fd int) ([]byte, error) {
   263  	const bufLen = 1024 // arbitrary, big enough for test data
   264  	data := make([]byte, bufLen)
   265  	n, err := t.reader.Read(data)
   266  	data = data[:n]
   267  	return data, err
   268  }
   269  
   270  func (t *fakeTerm) Restore(fd int, oldState *term.State) error {
   271  	t.restoreCalled = true
   272  	return nil
   273  }