github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/cmd/cometbft/commands/root_test.go (about)

     1  package commands
     2  
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"testing"

	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	cfg "github.com/badrootd/celestia-core/config"
	"github.com/badrootd/celestia-core/libs/cli"
	cmtos "github.com/badrootd/celestia-core/libs/os"
)
    19  
// defaultRoot is the root directory the tests in this file use as the default
// home. It is computed once, at package initialization, by expanding $HOME.
var defaultRoot = os.ExpandEnv("$HOME/.some/test/dir")
    21  
    22  // clearConfig clears env vars, the given root dir, and resets viper.
    23  func clearConfig(dir string) {
    24  	if err := os.Unsetenv("CMTHOME"); err != nil {
    25  		panic(err)
    26  	}
    27  	if err := os.Unsetenv("CMT_HOME"); err != nil {
    28  		panic(err)
    29  	}
    30  	if err := os.Unsetenv("TMHOME"); err != nil {
    31  		//XXX: Deprecated.
    32  		panic(err)
    33  	}
    34  	if err := os.Unsetenv("TM_HOME"); err != nil {
    35  		//XXX: Deprecated.
    36  		panic(err)
    37  	}
    38  
    39  	if err := os.RemoveAll(dir); err != nil {
    40  		panic(err)
    41  	}
    42  	viper.Reset()
    43  	config = cfg.DefaultConfig()
    44  }
    45  
    46  // prepare new rootCmd
    47  func testRootCmd() *cobra.Command {
    48  	rootCmd := &cobra.Command{
    49  		Use:               RootCmd.Use,
    50  		PersistentPreRunE: RootCmd.PersistentPreRunE,
    51  		Run:               func(cmd *cobra.Command, args []string) {},
    52  	}
    53  	registerFlagsRootCmd(rootCmd)
    54  	var l string
    55  	rootCmd.PersistentFlags().String("log", l, "Log")
    56  	return rootCmd
    57  }
    58  
    59  func testSetup(rootDir string, args []string, env map[string]string) error {
    60  	clearConfig(defaultRoot)
    61  
    62  	rootCmd := testRootCmd()
    63  	cmd := cli.PrepareBaseCmd(rootCmd, "CMT", defaultRoot)
    64  
    65  	// run with the args and env
    66  	args = append([]string{rootCmd.Use}, args...)
    67  	return cli.RunWithArgs(cmd, args, env)
    68  }
    69  
    70  func TestRootHome(t *testing.T) {
    71  	newRoot := filepath.Join(defaultRoot, "something-else")
    72  	cases := []struct {
    73  		args []string
    74  		env  map[string]string
    75  		root string
    76  	}{
    77  		{nil, nil, defaultRoot},
    78  		{[]string{"--home", newRoot}, nil, newRoot},
    79  		{nil, map[string]string{"TMHOME": newRoot}, newRoot}, //XXX: Deprecated.
    80  		{nil, map[string]string{"CMTHOME": newRoot}, newRoot},
    81  	}
    82  
    83  	for i, tc := range cases {
    84  		idxString := strconv.Itoa(i)
    85  
    86  		err := testSetup(defaultRoot, tc.args, tc.env)
    87  		require.Nil(t, err, idxString)
    88  
    89  		assert.Equal(t, tc.root, config.RootDir, idxString)
    90  		assert.Equal(t, tc.root, config.P2P.RootDir, idxString)
    91  		assert.Equal(t, tc.root, config.Consensus.RootDir, idxString)
    92  		assert.Equal(t, tc.root, config.Mempool.RootDir, idxString)
    93  	}
    94  }
    95  
    96  func TestRootFlagsEnv(t *testing.T) {
    97  	// defaults
    98  	defaults := cfg.DefaultConfig()
    99  	defaultLogLvl := defaults.LogLevel
   100  
   101  	cases := []struct {
   102  		args     []string
   103  		env      map[string]string
   104  		logLevel string
   105  	}{
   106  		{[]string{"--log", "debug"}, nil, defaultLogLvl},                  // wrong flag
   107  		{[]string{"--log_level", "debug"}, nil, "debug"},                  // right flag
   108  		{nil, map[string]string{"TM_LOW": "debug"}, defaultLogLvl},        // wrong env flag
   109  		{nil, map[string]string{"MT_LOG_LEVEL": "debug"}, defaultLogLvl},  // wrong env prefix
   110  		{nil, map[string]string{"TM_LOG_LEVEL": "debug"}, defaultLogLvl},  // right, but deprecated env
   111  		{nil, map[string]string{"CMT_LOW": "debug"}, defaultLogLvl},       // wrong env flag
   112  		{nil, map[string]string{"TMC_LOG_LEVEL": "debug"}, defaultLogLvl}, // wrong env prefix
   113  		{nil, map[string]string{"CMT_LOG_LEVEL": "debug"}, "debug"},       // right env
   114  	}
   115  
   116  	for i, tc := range cases {
   117  		idxString := strconv.Itoa(i)
   118  
   119  		err := testSetup(defaultRoot, tc.args, tc.env)
   120  		require.Nil(t, err, idxString)
   121  
   122  		assert.Equal(t, tc.logLevel, config.LogLevel, idxString)
   123  	}
   124  }
   125  
   126  func TestRootConfig(t *testing.T) {
   127  	// write non-default config
   128  	nonDefaultLogLvl := "abc:debug"
   129  	cvals := map[string]string{
   130  		"log_level": nonDefaultLogLvl,
   131  	}
   132  
   133  	cases := []struct {
   134  		args []string
   135  		env  map[string]string
   136  
   137  		logLvl string
   138  	}{
   139  		{nil, nil, nonDefaultLogLvl},                                           // should load config
   140  		{[]string{"--log_level=abc:info"}, nil, "abc:info"},                    // flag over rides
   141  		{nil, map[string]string{"TM_LOG_LEVEL": "abc:info"}, nonDefaultLogLvl}, // env over rides //XXX: Deprecated
   142  		{nil, map[string]string{"CMT_LOG_LEVEL": "abc:info"}, "abc:info"},      // env over rides
   143  	}
   144  
   145  	for i, tc := range cases {
   146  		idxString := strconv.Itoa(i)
   147  		clearConfig(defaultRoot)
   148  
   149  		// XXX: path must match cfg.defaultConfigPath
   150  		configFilePath := filepath.Join(defaultRoot, "config")
   151  		err := cmtos.EnsureDir(configFilePath, 0o700)
   152  		require.Nil(t, err)
   153  
   154  		// write the non-defaults to a different path
   155  		// TODO: support writing sub configs so we can test that too
   156  		err = WriteConfigVals(configFilePath, cvals)
   157  		require.Nil(t, err)
   158  
   159  		rootCmd := testRootCmd()
   160  		cmd := cli.PrepareBaseCmd(rootCmd, "CMT", defaultRoot)
   161  
   162  		// run with the args and env
   163  		tc.args = append([]string{rootCmd.Use}, tc.args...)
   164  		err = cli.RunWithArgs(cmd, tc.args, tc.env)
   165  		require.Nil(t, err, idxString)
   166  
   167  		assert.Equal(t, tc.logLvl, config.LogLevel, idxString)
   168  	}
   169  }
   170  
   171  // WriteConfigVals writes a toml file with the given values.
   172  // It returns an error if writing was impossible.
   173  func WriteConfigVals(dir string, vals map[string]string) error {
   174  	data := ""
   175  	for k, v := range vals {
   176  		data += fmt.Sprintf("%s = \"%s\"\n", k, v)
   177  	}
   178  	cfile := filepath.Join(dir, "config.toml")
   179  	return os.WriteFile(cfile, []byte(data), 0o600)
   180  }