github.com/toolvox/utilgo@v0.0.5/pkg/cli/flagutil/file_test.go (about)

     1  package flagutil_test
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/toolvox/utilgo/pkg/cli/flagutil"
    13  	"github.com/toolvox/utilgo/pkg/errs"
    14  )
    15  
    16  func TestFileValue(t *testing.T) {
    17  	tempDir := prepTemp(t)
    18  
    19  	tests := []struct {
    20  		name      string
    21  		filename  string
    22  		content   string
    23  		expectErr error
    24  	}{
    25  		{
    26  			name:     "read file",
    27  			filename: "prefix",
    28  			content:  "<prefix>\n",
    29  		},
    30  		{
    31  			name:     "no file",
    32  			filename: "",
    33  			content:  "",
    34  		},
    35  		{
    36  			name:      "read not existing file",
    37  			filename:  "x",
    38  			expectErr: errs.New(`invalid value "x" for flag -test: unable to read file 'x': open x: The system cannot find the file specified.`),
    39  		},
    40  	}
    41  	for ti, tt := range tests {
    42  		t.Run(fmt.Sprintf("%d_%s", ti, tt.name), func(t *testing.T) {
    43  			os.Chdir(tempDir)
    44  			defer os.Chdir("../..")
    45  
    46  			testSet := flag.NewFlagSet(tt.name, flag.ContinueOnError)
    47  			testSet.SetOutput(io.Discard)
    48  
    49  			var testFlag flagutil.FileValue
    50  			testSet.Var(&testFlag, "test", "input file")
    51  			must := require.New(t)
    52  			must.NotPanics(func() {
    53  				err := testSet.Parse([]string{"-test", tt.filename})
    54  				if tt.expectErr != nil {
    55  					must.Error(err)
    56  					must.Equal(tt.expectErr.Error(), err.Error())
    57  					return
    58  				}
    59  				must.NoError(err)
    60  				must.Equal(fmt.Sprintf(`"%s"`, tt.filename), testFlag.String())
    61  				must.Equal(tt.content, string(testFlag.Get().([]byte)))
    62  				reader := testFlag.Reader()
    63  				must.NotNil(reader)
    64  				bs, err := io.ReadAll(reader)
    65  				must.NoError(err)
    66  				must.Equal(tt.content, string(bs))
    67  			})
    68  		})
    69  	}
    70  }