github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli/command/utils_test.go (about)

     1  package command_test
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"syscall"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/khulnasoft/cli/cli/command"
    17  	"github.com/khulnasoft/cli/internal/test"
    18  	"github.com/pkg/errors"
    19  	"gotest.tools/v3/assert"
    20  )
    21  
    22  func TestStringSliceReplaceAt(t *testing.T) {
    23  	out, ok := command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, -1)
    24  	assert.Assert(t, ok)
    25  	assert.DeepEqual(t, []string{"abc", "baz", "bax"}, out)
    26  
    27  	out, ok = command.StringSliceReplaceAt([]string{"foo"}, []string{"foo", "bar"}, []string{"baz"}, -1)
    28  	assert.Assert(t, !ok)
    29  	assert.DeepEqual(t, []string{"foo"}, out)
    30  
    31  	out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0)
    32  	assert.Assert(t, !ok)
    33  	assert.DeepEqual(t, []string{"abc", "foo", "bar", "bax"}, out)
    34  
    35  	out, ok = command.StringSliceReplaceAt([]string{"foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0)
    36  	assert.Assert(t, ok)
    37  	assert.DeepEqual(t, []string{"baz", "bax"}, out)
    38  
    39  	out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "baz"}, []string{"foo", "bar"}, nil, -1)
    40  	assert.Assert(t, ok)
    41  	assert.DeepEqual(t, []string{"abc", "baz"}, out)
    42  
    43  	out, ok = command.StringSliceReplaceAt([]string{"foo"}, nil, []string{"baz"}, -1)
    44  	assert.Assert(t, !ok)
    45  	assert.DeepEqual(t, []string{"foo"}, out)
    46  }
    47  
    48  func TestValidateOutputPath(t *testing.T) {
    49  	basedir := t.TempDir()
    50  	dir := filepath.Join(basedir, "dir")
    51  	notexist := filepath.Join(basedir, "notexist")
    52  	err := os.MkdirAll(dir, 0o755)
    53  	assert.NilError(t, err)
    54  	file := filepath.Join(dir, "file")
    55  	err = os.WriteFile(file, []byte("hi"), 0o644)
    56  	assert.NilError(t, err)
    57  	testcases := []struct {
    58  		path string
    59  		err  error
    60  	}{
    61  		{basedir, nil},
    62  		{file, nil},
    63  		{dir, nil},
    64  		{dir + string(os.PathSeparator), nil},
    65  		{notexist, nil},
    66  		{notexist + string(os.PathSeparator), nil},
    67  		{filepath.Join(notexist, "file"), errors.New("does not exist")},
    68  	}
    69  
    70  	for _, testcase := range testcases {
    71  		t.Run(testcase.path, func(t *testing.T) {
    72  			err := command.ValidateOutputPath(testcase.path)
    73  			if testcase.err == nil {
    74  				assert.NilError(t, err)
    75  			} else {
    76  				assert.ErrorContains(t, err, testcase.err.Error())
    77  			}
    78  		})
    79  	}
    80  }
    81  
    82  func TestPromptForConfirmation(t *testing.T) {
    83  	ctx, cancel := context.WithCancel(context.Background())
    84  	t.Cleanup(cancel)
    85  
    86  	type promptResult struct {
    87  		result bool
    88  		err    error
    89  	}
    90  
    91  	buf := new(bytes.Buffer)
    92  	bufioWriter := bufio.NewWriter(buf)
    93  
    94  	var (
    95  		promptWriter *io.PipeWriter
    96  		promptReader *io.PipeReader
    97  	)
    98  
    99  	defer func() {
   100  		if promptWriter != nil {
   101  			promptWriter.Close()
   102  		}
   103  		if promptReader != nil {
   104  			promptReader.Close()
   105  		}
   106  	}()
   107  
   108  	for _, tc := range []struct {
   109  		desc     string
   110  		f        func() error
   111  		expected promptResult
   112  	}{
   113  		{"SIGINT", func() error {
   114  			syscall.Kill(syscall.Getpid(), syscall.SIGINT)
   115  			return nil
   116  		}, promptResult{false, command.ErrPromptTerminated}},
   117  		{"no", func() error {
   118  			_, err := fmt.Fprint(promptWriter, "n\n")
   119  			return err
   120  		}, promptResult{false, nil}},
   121  		{"yes", func() error {
   122  			_, err := fmt.Fprint(promptWriter, "y\n")
   123  			return err
   124  		}, promptResult{true, nil}},
   125  		{"any", func() error {
   126  			_, err := fmt.Fprint(promptWriter, "a\n")
   127  			return err
   128  		}, promptResult{false, nil}},
   129  		{"with space", func() error {
   130  			_, err := fmt.Fprint(promptWriter, " y\n")
   131  			return err
   132  		}, promptResult{true, nil}},
   133  		{"reader closed", func() error {
   134  			return promptReader.Close()
   135  		}, promptResult{false, nil}},
   136  	} {
   137  		t.Run("case="+tc.desc, func(t *testing.T) {
   138  			buf.Reset()
   139  			promptReader, promptWriter = io.Pipe()
   140  
   141  			wroteHook := make(chan struct{}, 1)
   142  			promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
   143  				wroteHook <- struct{}{}
   144  			})
   145  
   146  			result := make(chan promptResult, 1)
   147  			go func() {
   148  				r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
   149  				result <- promptResult{r, err}
   150  			}()
   151  
   152  			select {
   153  			case <-time.After(100 * time.Millisecond):
   154  			case <-wroteHook:
   155  			}
   156  
   157  			assert.NilError(t, bufioWriter.Flush())
   158  			assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]")
   159  
   160  			// wait for the Prompt to write to the buffer
   161  			drainChannel(ctx, wroteHook)
   162  
   163  			assert.NilError(t, tc.f())
   164  
   165  			select {
   166  			case <-time.After(500 * time.Millisecond):
   167  				t.Fatal("timeout waiting for prompt result")
   168  			case r := <-result:
   169  				assert.Equal(t, r, tc.expected)
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  func drainChannel(ctx context.Context, ch <-chan struct{}) {
   176  	go func() {
   177  		for {
   178  			select {
   179  			case <-ctx.Done():
   180  				return
   181  			case <-ch:
   182  			}
   183  		}
   184  	}()
   185  }