github.com/influxdata/influxdb/v2@v2.7.6/kit/cli/viper_test.go (about)

     1  package cli
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"path"
     9  	"path/filepath"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/BurntSushi/toml"
    14  	"github.com/spf13/viper"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  	"go.uber.org/zap/zapcore"
    18  	"gopkg.in/yaml.v3"
    19  )
    20  
    21  type customFlag bool
    22  
    23  func (c customFlag) String() string {
    24  	if c == true {
    25  		return "on"
    26  	}
    27  	return "off"
    28  }
    29  
    30  func (c *customFlag) Set(s string) error {
    31  	if s == "on" {
    32  		*c = true
    33  	} else {
    34  		*c = false
    35  	}
    36  
    37  	return nil
    38  }
    39  
    40  func (c *customFlag) Type() string {
    41  	return "fancy-bool"
    42  }
    43  
    44  func ExampleNewCommand() {
    45  	var monitorHost string
    46  	var number int
    47  	var smallerNumber int32
    48  	var longerNumber int64
    49  	var sleep bool
    50  	var duration time.Duration
    51  	var stringSlice []string
    52  	var fancyBool customFlag
    53  	var logLevel zapcore.Level
    54  	cmd, err := NewCommand(viper.New(), &Program{
    55  		Run: func() error {
    56  			fmt.Println(monitorHost)
    57  			for i := 0; i < number; i++ {
    58  				fmt.Printf("%d\n", i)
    59  			}
    60  			fmt.Println(longerNumber - int64(smallerNumber))
    61  			fmt.Println(sleep)
    62  			fmt.Println(duration)
    63  			fmt.Println(stringSlice)
    64  			fmt.Println(fancyBool)
    65  			fmt.Println(logLevel.String())
    66  			return nil
    67  		},
    68  		Name: "myprogram",
    69  		Opts: []Opt{
    70  			{
    71  				DestP:   &monitorHost,
    72  				Flag:    "monitor-host",
    73  				Default: "http://localhost:8086",
    74  				Desc:    "host to send influxdb metrics",
    75  			},
    76  			{
    77  				DestP:   &number,
    78  				Flag:    "number",
    79  				Default: 2,
    80  				Desc:    "number of times to loop",
    81  			},
    82  			{
    83  				DestP:   &smallerNumber,
    84  				Flag:    "smaller-number",
    85  				Default: math.MaxInt32,
    86  				Desc:    "limited size number",
    87  			},
    88  			{
    89  				DestP:   &longerNumber,
    90  				Flag:    "longer-number",
    91  				Default: math.MaxInt64,
    92  				Desc:    "explicitly expanded-size number",
    93  			},
    94  			{
    95  				DestP:   &sleep,
    96  				Flag:    "sleep",
    97  				Default: true,
    98  				Desc:    "whether to sleep",
    99  			},
   100  			{
   101  				DestP:   &duration,
   102  				Flag:    "duration",
   103  				Default: time.Minute,
   104  				Desc:    "how long to sleep",
   105  			},
   106  			{
   107  				DestP:   &stringSlice,
   108  				Flag:    "string-slice",
   109  				Default: []string{"foo", "bar"},
   110  				Desc:    "things come in lists",
   111  			},
   112  			{
   113  				DestP:   &fancyBool,
   114  				Flag:    "fancy-bool",
   115  				Default: "on",
   116  				Desc:    "things that implement pflag.Value",
   117  			},
   118  			{
   119  				DestP:   &logLevel,
   120  				Flag:    "log-level",
   121  				Default: zapcore.WarnLevel,
   122  			},
   123  		},
   124  	})
   125  	if err != nil {
   126  		_, _ = fmt.Fprintln(os.Stderr, err)
   127  		return
   128  	}
   129  
   130  	cmd.SetArgs([]string{})
   131  	if err := cmd.Execute(); err != nil {
   132  		_, _ = fmt.Fprintln(os.Stderr, err)
   133  	}
   134  	// Output:
   135  	// http://localhost:8086
   136  	// 0
   137  	// 1
   138  	// 9223372034707292160
   139  	// true
   140  	// 1m0s
   141  	// [foo bar]
   142  	// on
   143  	// warn
   144  }
   145  
   146  func Test_NewProgram(t *testing.T) {
   147  	config := map[string]string{
   148  		// config values should be same as flags
   149  		"foo":         "bar",
   150  		"shoe-fly":    "yadon",
   151  		"number":      "2147483647",
   152  		"long-number": "9223372036854775807",
   153  		"log-level":   "debug",
   154  	}
   155  
   156  	tests := []struct {
   157  		name      string
   158  		envVarVal string
   159  		args      []string
   160  		expected  string
   161  	}{
   162  		{
   163  			name:     "no vals reads from config",
   164  			expected: "bar",
   165  		},
   166  		{
   167  			name:      "reads from env var",
   168  			envVarVal: "foobar",
   169  			expected:  "foobar",
   170  		},
   171  		{
   172  			name:     "reads from flag",
   173  			args:     []string{"--foo=baz"},
   174  			expected: "baz",
   175  		},
   176  		{
   177  			name:      "flag has highest precedence",
   178  			envVarVal: "foobar",
   179  			args:      []string{"--foo=baz"},
   180  			expected:  "baz",
   181  		},
   182  	}
   183  
   184  	for _, tt := range tests {
   185  		for _, writer := range configWriters {
   186  			fn := func(t *testing.T) {
   187  				testDir := t.TempDir()
   188  
   189  				confFile, err := writer.writeFn(testDir, config)
   190  				require.NoError(t, err)
   191  
   192  				defer setEnvVar("TEST_CONFIG_PATH", confFile)()
   193  
   194  				if tt.envVarVal != "" {
   195  					defer setEnvVar("TEST_FOO", tt.envVarVal)()
   196  				}
   197  
   198  				var testVar string
   199  				var testFly string
   200  				var testNumber int32
   201  				var testLongNumber int64
   202  				var logLevel zapcore.Level
   203  				program := &Program{
   204  					Name: "test",
   205  					Opts: []Opt{
   206  						{
   207  							DestP:    &testVar,
   208  							Flag:     "foo",
   209  							Required: true,
   210  						},
   211  						{
   212  							DestP: &testFly,
   213  							Flag:  "shoe-fly",
   214  						},
   215  						{
   216  							DestP: &testNumber,
   217  							Flag:  "number",
   218  						},
   219  						{
   220  							DestP: &testLongNumber,
   221  							Flag:  "long-number",
   222  						},
   223  						{
   224  							DestP: &logLevel,
   225  							Flag:  "log-level",
   226  						},
   227  					},
   228  					Run: func() error { return nil },
   229  				}
   230  
   231  				cmd, err := NewCommand(viper.New(), program)
   232  				require.NoError(t, err)
   233  				cmd.SetArgs(append([]string{}, tt.args...))
   234  				require.NoError(t, cmd.Execute())
   235  
   236  				require.Equal(t, tt.expected, testVar)
   237  				assert.Equal(t, "yadon", testFly)
   238  				assert.Equal(t, int32(math.MaxInt32), testNumber)
   239  				assert.Equal(t, int64(math.MaxInt64), testLongNumber)
   240  				assert.Equal(t, zapcore.DebugLevel, logLevel)
   241  			}
   242  
   243  			t.Run(fmt.Sprintf("%s_%s", tt.name, writer.ext), fn)
   244  		}
   245  	}
   246  }
   247  
   248  func setEnvVar(key, val string) func() {
   249  	old := os.Getenv(key)
   250  	os.Setenv(key, val)
   251  	return func() {
   252  		os.Setenv(key, old)
   253  	}
   254  }
   255  
   256  type configWriter func(dir string, config interface{}) (string, error)
   257  type labeledWriter struct {
   258  	ext     string
   259  	writeFn configWriter
   260  }
   261  
   262  var configWriters = []labeledWriter{
   263  	{ext: "json", writeFn: writeJsonConfig},
   264  	{ext: "toml", writeFn: writeTomlConfig},
   265  	{ext: "yml", writeFn: yamlConfigWriter(true)},
   266  	{ext: "yaml", writeFn: yamlConfigWriter(false)},
   267  }
   268  
   269  func writeJsonConfig(dir string, config interface{}) (string, error) {
   270  	b, err := json.Marshal(config)
   271  	if err != nil {
   272  		return "", err
   273  	}
   274  	confFile := path.Join(dir, "config.json")
   275  	if err := os.WriteFile(confFile, b, os.ModePerm); err != nil {
   276  		return "", err
   277  	}
   278  	return confFile, nil
   279  }
   280  
   281  func writeTomlConfig(dir string, config interface{}) (string, error) {
   282  	confFile := path.Join(dir, "config.toml")
   283  	w, err := os.OpenFile(confFile, os.O_CREATE|os.O_EXCL|os.O_WRONLY, os.ModePerm)
   284  	if err != nil {
   285  		return "", err
   286  	}
   287  	defer w.Close()
   288  
   289  	if err := toml.NewEncoder(w).Encode(config); err != nil {
   290  		return "", err
   291  	}
   292  
   293  	return confFile, nil
   294  }
   295  
   296  func yamlConfigWriter(shortExt bool) configWriter {
   297  	fileName := "config.yaml"
   298  	if shortExt {
   299  		fileName = "config.yml"
   300  	}
   301  
   302  	return func(dir string, config interface{}) (string, error) {
   303  		confFile := path.Join(dir, fileName)
   304  		w, err := os.OpenFile(confFile, os.O_CREATE|os.O_EXCL|os.O_WRONLY, os.ModePerm)
   305  		if err != nil {
   306  			return "", err
   307  		}
   308  		defer w.Close()
   309  
   310  		if err := yaml.NewEncoder(w).Encode(config); err != nil {
   311  			return "", err
   312  		}
   313  
   314  		return confFile, nil
   315  	}
   316  }
   317  
   318  func Test_RequiredFlag(t *testing.T) {
   319  	var testVar string
   320  	program := &Program{
   321  		Name: "test",
   322  		Opts: []Opt{
   323  			{
   324  				DestP:    &testVar,
   325  				Flag:     "foo",
   326  				Required: true,
   327  			},
   328  		},
   329  	}
   330  
   331  	cmd, err := NewCommand(viper.New(), program)
   332  	require.NoError(t, err)
   333  	cmd.SetArgs([]string{})
   334  	err = cmd.Execute()
   335  	require.Error(t, err)
   336  	require.Equal(t, `required flag(s) "foo" not set`, err.Error())
   337  }
   338  
   339  func Test_ConfigPrecedence(t *testing.T) {
   340  	jsonConfig := map[string]interface{}{"log-level": zapcore.DebugLevel}
   341  	tomlConfig := map[string]interface{}{"log-level": zapcore.InfoLevel}
   342  	yamlConfig := map[string]interface{}{"log-level": zapcore.WarnLevel}
   343  	ymlConfig := map[string]interface{}{"log-level": zapcore.ErrorLevel}
   344  
   345  	tests := []struct {
   346  		name          string
   347  		writeJson     bool
   348  		writeToml     bool
   349  		writeYaml     bool
   350  		writeYml      bool
   351  		expectedLevel zapcore.Level
   352  	}{
   353  		{
   354  			name:          "JSON is used if present",
   355  			writeJson:     true,
   356  			writeToml:     true,
   357  			writeYaml:     true,
   358  			writeYml:      true,
   359  			expectedLevel: zapcore.DebugLevel,
   360  		},
   361  		{
   362  			name:          "TOML is used if no JSON present",
   363  			writeJson:     false,
   364  			writeToml:     true,
   365  			writeYaml:     true,
   366  			writeYml:      true,
   367  			expectedLevel: zapcore.InfoLevel,
   368  		},
   369  		{
   370  			name:          "YAML is used if no JSON or TOML present",
   371  			writeJson:     false,
   372  			writeToml:     false,
   373  			writeYaml:     true,
   374  			writeYml:      true,
   375  			expectedLevel: zapcore.WarnLevel,
   376  		},
   377  		{
   378  			name:          "YML is used if no other option present",
   379  			writeJson:     false,
   380  			writeToml:     false,
   381  			writeYaml:     false,
   382  			writeYml:      true,
   383  			expectedLevel: zapcore.ErrorLevel,
   384  		},
   385  	}
   386  
   387  	for _, tt := range tests {
   388  		fn := func(t *testing.T) {
   389  			testDir := t.TempDir()
   390  			defer setEnvVar("TEST_CONFIG_PATH", testDir)()
   391  
   392  			if tt.writeJson {
   393  				_, err := writeJsonConfig(testDir, jsonConfig)
   394  				require.NoError(t, err)
   395  			}
   396  			if tt.writeToml {
   397  				_, err := writeTomlConfig(testDir, tomlConfig)
   398  				require.NoError(t, err)
   399  			}
   400  			if tt.writeYaml {
   401  				_, err := yamlConfigWriter(false)(testDir, yamlConfig)
   402  				require.NoError(t, err)
   403  			}
   404  			if tt.writeYml {
   405  				_, err := yamlConfigWriter(true)(testDir, ymlConfig)
   406  				require.NoError(t, err)
   407  			}
   408  
   409  			var logLevel zapcore.Level
   410  			program := &Program{
   411  				Name: "test",
   412  				Opts: []Opt{
   413  					{
   414  						DestP: &logLevel,
   415  						Flag:  "log-level",
   416  					},
   417  				},
   418  				Run: func() error { return nil },
   419  			}
   420  
   421  			cmd, err := NewCommand(viper.New(), program)
   422  			require.NoError(t, err)
   423  			cmd.SetArgs([]string{})
   424  			require.NoError(t, cmd.Execute())
   425  
   426  			require.Equal(t, tt.expectedLevel, logLevel)
   427  		}
   428  
   429  		t.Run(tt.name, fn)
   430  	}
   431  }
   432  
   433  func Test_ConfigPathDotDirectory(t *testing.T) {
   434  	testDir := t.TempDir()
   435  
   436  	tests := []struct {
   437  		name string
   438  		dir  string
   439  	}{
   440  		{
   441  			name: "dot at start",
   442  			dir:  ".directory",
   443  		},
   444  		{
   445  			name: "dot in middle",
   446  			dir:  "config.d",
   447  		},
   448  		{
   449  			name: "dot at end",
   450  			dir:  "forgotmyextension.",
   451  		},
   452  	}
   453  
   454  	config := map[string]string{
   455  		"foo": "bar",
   456  	}
   457  
   458  	for _, tc := range tests {
   459  		t.Run(tc.name, func(t *testing.T) {
   460  			configDir := filepath.Join(testDir, tc.dir)
   461  			require.NoError(t, os.Mkdir(configDir, 0700))
   462  
   463  			_, err := writeTomlConfig(configDir, config)
   464  			require.NoError(t, err)
   465  			defer setEnvVar("TEST_CONFIG_PATH", configDir)()
   466  
   467  			var foo string
   468  			program := &Program{
   469  				Name: "test",
   470  				Opts: []Opt{
   471  					{
   472  						DestP: &foo,
   473  						Flag:  "foo",
   474  					},
   475  				},
   476  				Run: func() error { return nil },
   477  			}
   478  
   479  			cmd, err := NewCommand(viper.New(), program)
   480  			require.NoError(t, err)
   481  			cmd.SetArgs([]string{})
   482  			require.NoError(t, cmd.Execute())
   483  
   484  			require.Equal(t, "bar", foo)
   485  		})
   486  	}
   487  }
   488  
   489  func Test_LoadConfigCwd(t *testing.T) {
   490  	testDir := t.TempDir()
   491  
   492  	pwd, err := os.Getwd()
   493  	require.NoError(t, err)
   494  	defer os.Chdir(pwd)
   495  
   496  	require.NoError(t, os.Chdir(testDir))
   497  
   498  	config := map[string]string{
   499  		"foo": "bar",
   500  	}
   501  	_, err = writeJsonConfig(testDir, config)
   502  	require.NoError(t, err)
   503  
   504  	var foo string
   505  	program := &Program{
   506  		Name: "test",
   507  		Opts: []Opt{
   508  			{
   509  				DestP: &foo,
   510  				Flag:  "foo",
   511  			},
   512  		},
   513  		Run: func() error { return nil },
   514  	}
   515  
   516  	cmd, err := NewCommand(viper.New(), program)
   517  	require.NoError(t, err)
   518  	cmd.SetArgs([]string{})
   519  	require.NoError(t, cmd.Execute())
   520  
   521  	require.Equal(t, "bar", foo)
   522  }