github.com/slackhq/nebula@v1.9.0/config/config_test.go (about)

     1  package config
     2  
     3  import (
     4  	"os"
     5  	"path/filepath"
     6  	"testing"
     7  	"time"
     8  
     9  	"dario.cat/mergo"
    10  	"github.com/slackhq/nebula/test"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  	"gopkg.in/yaml.v2"
    14  )
    15  
    16  func TestConfig_Load(t *testing.T) {
    17  	l := test.NewLogger()
    18  	dir, err := os.MkdirTemp("", "config-test")
    19  	// invalid yaml
    20  	c := NewC(l)
    21  	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
    22  	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
    23  
    24  	// simple multi config merge
    25  	c = NewC(l)
    26  	os.RemoveAll(dir)
    27  	os.Mkdir(dir, 0755)
    28  
    29  	assert.Nil(t, err)
    30  
    31  	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
    32  	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
    33  	assert.Nil(t, c.Load(dir))
    34  	expected := map[interface{}]interface{}{
    35  		"outer": map[interface{}]interface{}{
    36  			"inner": "override",
    37  		},
    38  		"new": "hi",
    39  	}
    40  	assert.Equal(t, expected, c.Settings)
    41  
    42  	//TODO: test symlinked file
    43  	//TODO: test symlinked directory
    44  }
    45  
    46  func TestConfig_Get(t *testing.T) {
    47  	l := test.NewLogger()
    48  	// test simple type
    49  	c := NewC(l)
    50  	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
    51  	assert.Equal(t, "hi", c.Get("firewall.outbound"))
    52  
    53  	// test complex type
    54  	inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
    55  	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
    56  	assert.EqualValues(t, inner, c.Get("firewall.outbound"))
    57  
    58  	// test missing
    59  	assert.Nil(t, c.Get("firewall.nope"))
    60  }
    61  
    62  func TestConfig_GetStringSlice(t *testing.T) {
    63  	l := test.NewLogger()
    64  	c := NewC(l)
    65  	c.Settings["slice"] = []interface{}{"one", "two"}
    66  	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
    67  }
    68  
    69  func TestConfig_GetBool(t *testing.T) {
    70  	l := test.NewLogger()
    71  	c := NewC(l)
    72  	c.Settings["bool"] = true
    73  	assert.Equal(t, true, c.GetBool("bool", false))
    74  
    75  	c.Settings["bool"] = "true"
    76  	assert.Equal(t, true, c.GetBool("bool", false))
    77  
    78  	c.Settings["bool"] = false
    79  	assert.Equal(t, false, c.GetBool("bool", true))
    80  
    81  	c.Settings["bool"] = "false"
    82  	assert.Equal(t, false, c.GetBool("bool", true))
    83  
    84  	c.Settings["bool"] = "Y"
    85  	assert.Equal(t, true, c.GetBool("bool", false))
    86  
    87  	c.Settings["bool"] = "yEs"
    88  	assert.Equal(t, true, c.GetBool("bool", false))
    89  
    90  	c.Settings["bool"] = "N"
    91  	assert.Equal(t, false, c.GetBool("bool", true))
    92  
    93  	c.Settings["bool"] = "nO"
    94  	assert.Equal(t, false, c.GetBool("bool", true))
    95  }
    96  
    97  func TestConfig_HasChanged(t *testing.T) {
    98  	l := test.NewLogger()
    99  	// No reload has occurred, return false
   100  	c := NewC(l)
   101  	c.Settings["test"] = "hi"
   102  	assert.False(t, c.HasChanged(""))
   103  
   104  	// Test key change
   105  	c = NewC(l)
   106  	c.Settings["test"] = "hi"
   107  	c.oldSettings = map[interface{}]interface{}{"test": "no"}
   108  	assert.True(t, c.HasChanged("test"))
   109  	assert.True(t, c.HasChanged(""))
   110  
   111  	// No key change
   112  	c = NewC(l)
   113  	c.Settings["test"] = "hi"
   114  	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
   115  	assert.False(t, c.HasChanged("test"))
   116  	assert.False(t, c.HasChanged(""))
   117  }
   118  
   119  func TestConfig_ReloadConfig(t *testing.T) {
   120  	l := test.NewLogger()
   121  	done := make(chan bool, 1)
   122  	dir, err := os.MkdirTemp("", "config-test")
   123  	assert.Nil(t, err)
   124  	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
   125  
   126  	c := NewC(l)
   127  	assert.Nil(t, c.Load(dir))
   128  
   129  	assert.False(t, c.HasChanged("outer.inner"))
   130  	assert.False(t, c.HasChanged("outer"))
   131  	assert.False(t, c.HasChanged(""))
   132  
   133  	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
   134  
   135  	c.RegisterReloadCallback(func(c *C) {
   136  		done <- true
   137  	})
   138  
   139  	c.ReloadConfig()
   140  	assert.True(t, c.HasChanged("outer.inner"))
   141  	assert.True(t, c.HasChanged("outer"))
   142  	assert.True(t, c.HasChanged(""))
   143  
   144  	// Make sure we call the callbacks
   145  	select {
   146  	case <-done:
   147  	case <-time.After(1 * time.Second):
   148  		panic("timeout")
   149  	}
   150  
   151  }
   152  
   153  // Ensure mergo merges are done the way we expect.
   154  // This is needed to test for potential regressions, like:
   155  // - https://github.com/imdario/mergo/issues/187
   156  func TestConfig_MergoMerge(t *testing.T) {
   157  	configs := [][]byte{
   158  		[]byte(`
   159  listen:
   160    port: 1234
   161  `),
   162  		[]byte(`
   163  firewall:
   164    inbound:
   165      - port: 443
   166        proto: tcp
   167        groups:
   168          - server
   169      - port: 443
   170        proto: tcp
   171        groups:
   172          - webapp
   173  `),
   174  		[]byte(`
   175  listen:
   176    host: 0.0.0.0
   177    port: 4242
   178  firewall:
   179    outbound:
   180      - port: any
   181        proto: any
   182        host: any
   183    inbound:
   184      - port: any
   185        proto: icmp
   186        host: any
   187  `),
   188  	}
   189  
   190  	var m map[any]any
   191  
   192  	// merge the same way config.parse() merges
   193  	for _, b := range configs {
   194  		var nm map[any]any
   195  		err := yaml.Unmarshal(b, &nm)
   196  		require.NoError(t, err)
   197  
   198  		// We need to use WithAppendSlice so that firewall rules in separate
   199  		// files are appended together
   200  		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
   201  		m = nm
   202  		require.NoError(t, err)
   203  	}
   204  
   205  	t.Logf("Merged Config: %#v", m)
   206  	mYaml, err := yaml.Marshal(m)
   207  	require.NoError(t, err)
   208  	t.Logf("Merged Config as YAML:\n%s", mYaml)
   209  
   210  	// If a bug is present, some items might be replaced instead of merged like we expect
   211  	expected := map[any]any{
   212  		"firewall": map[any]any{
   213  			"inbound": []any{
   214  				map[any]any{"host": "any", "port": "any", "proto": "icmp"},
   215  				map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
   216  				map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
   217  			"outbound": []any{
   218  				map[any]any{"host": "any", "port": "any", "proto": "any"}}},
   219  		"listen": map[any]any{
   220  			"host": "0.0.0.0",
   221  			"port": 4242,
   222  		},
   223  	}
   224  	assert.Equal(t, expected, m)
   225  }