github.com/slackhq/nebula@v1.9.0/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"os"
     9  	"os/signal"
    10  	"path/filepath"
    11  	"sort"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"syscall"
    16  	"time"
    17  
    18  	"dario.cat/mergo"
    19  	"github.com/sirupsen/logrus"
    20  	"gopkg.in/yaml.v2"
    21  )
    22  
    23  type C struct {
    24  	path        string
    25  	files       []string
    26  	Settings    map[interface{}]interface{}
    27  	oldSettings map[interface{}]interface{}
    28  	callbacks   []func(*C)
    29  	l           *logrus.Logger
    30  	reloadLock  sync.Mutex
    31  }
    32  
    33  func NewC(l *logrus.Logger) *C {
    34  	return &C{
    35  		Settings: make(map[interface{}]interface{}),
    36  		l:        l,
    37  	}
    38  }
    39  
    40  // Load will find all yaml files within path and load them in lexical order
    41  func (c *C) Load(path string) error {
    42  	c.path = path
    43  	c.files = make([]string, 0)
    44  
    45  	err := c.resolve(path, true)
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	if len(c.files) == 0 {
    51  		return fmt.Errorf("no config files found at %s", path)
    52  	}
    53  
    54  	sort.Strings(c.files)
    55  
    56  	err = c.parse()
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	return nil
    62  }
    63  
    64  func (c *C) LoadString(raw string) error {
    65  	if raw == "" {
    66  		return errors.New("Empty configuration")
    67  	}
    68  	return c.parseRaw([]byte(raw))
    69  }
    70  
    71  // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
    72  // here should decide if they need to make a change to the current process before making the change. HasChanged can be
    73  // used to help decide if a change is necessary.
    74  // These functions should return quickly or spawn their own go routine if they will take a while
    75  func (c *C) RegisterReloadCallback(f func(*C)) {
    76  	c.callbacks = append(c.callbacks, f)
    77  }
    78  
    79  // InitialLoad returns true if this is the first load of the config, and ReloadConfig has not been called yet.
    80  func (c *C) InitialLoad() bool {
    81  	return c.oldSettings == nil
    82  }
    83  
    84  // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
    85  // k in both the old and new settings will be serialized, the result of the string comparison is returned.
    86  // If k is an empty string the entire config is tested.
    87  // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
    88  // there is change when there actually wasn't any.
    89  func (c *C) HasChanged(k string) bool {
    90  	if c.oldSettings == nil {
    91  		return false
    92  	}
    93  
    94  	var (
    95  		nv interface{}
    96  		ov interface{}
    97  	)
    98  
    99  	if k == "" {
   100  		nv = c.Settings
   101  		ov = c.oldSettings
   102  		k = "all settings"
   103  	} else {
   104  		nv = c.get(k, c.Settings)
   105  		ov = c.get(k, c.oldSettings)
   106  	}
   107  
   108  	newVals, err := yaml.Marshal(nv)
   109  	if err != nil {
   110  		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
   111  	}
   112  
   113  	oldVals, err := yaml.Marshal(ov)
   114  	if err != nil {
   115  		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
   116  	}
   117  
   118  	return string(newVals) != string(oldVals)
   119  }
   120  
   121  // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
   122  // original path provided to Load. The old settings are shallow copied for change detection after the reload.
   123  func (c *C) CatchHUP(ctx context.Context) {
   124  	if c.path == "" {
   125  		return
   126  	}
   127  
   128  	ch := make(chan os.Signal, 1)
   129  	signal.Notify(ch, syscall.SIGHUP)
   130  
   131  	go func() {
   132  		for {
   133  			select {
   134  			case <-ctx.Done():
   135  				signal.Stop(ch)
   136  				close(ch)
   137  				return
   138  			case <-ch:
   139  				c.l.Info("Caught HUP, reloading config")
   140  				c.ReloadConfig()
   141  			}
   142  		}
   143  	}()
   144  }
   145  
   146  func (c *C) ReloadConfig() {
   147  	c.reloadLock.Lock()
   148  	defer c.reloadLock.Unlock()
   149  
   150  	c.oldSettings = make(map[interface{}]interface{})
   151  	for k, v := range c.Settings {
   152  		c.oldSettings[k] = v
   153  	}
   154  
   155  	err := c.Load(c.path)
   156  	if err != nil {
   157  		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
   158  		return
   159  	}
   160  
   161  	for _, v := range c.callbacks {
   162  		v(c)
   163  	}
   164  }
   165  
   166  func (c *C) ReloadConfigString(raw string) error {
   167  	c.reloadLock.Lock()
   168  	defer c.reloadLock.Unlock()
   169  
   170  	c.oldSettings = make(map[interface{}]interface{})
   171  	for k, v := range c.Settings {
   172  		c.oldSettings[k] = v
   173  	}
   174  
   175  	err := c.LoadString(raw)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	for _, v := range c.callbacks {
   181  		v(c)
   182  	}
   183  
   184  	return nil
   185  }
   186  
   187  // GetString will get the string for k or return the default d if not found or invalid
   188  func (c *C) GetString(k, d string) string {
   189  	r := c.Get(k)
   190  	if r == nil {
   191  		return d
   192  	}
   193  
   194  	return fmt.Sprintf("%v", r)
   195  }
   196  
   197  // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
   198  func (c *C) GetStringSlice(k string, d []string) []string {
   199  	r := c.Get(k)
   200  	if r == nil {
   201  		return d
   202  	}
   203  
   204  	rv, ok := r.([]interface{})
   205  	if !ok {
   206  		return d
   207  	}
   208  
   209  	v := make([]string, len(rv))
   210  	for i := 0; i < len(v); i++ {
   211  		v[i] = fmt.Sprintf("%v", rv[i])
   212  	}
   213  
   214  	return v
   215  }
   216  
   217  // GetMap will get the map for k or return the default d if not found or invalid
   218  func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
   219  	r := c.Get(k)
   220  	if r == nil {
   221  		return d
   222  	}
   223  
   224  	v, ok := r.(map[interface{}]interface{})
   225  	if !ok {
   226  		return d
   227  	}
   228  
   229  	return v
   230  }
   231  
   232  // GetInt will get the int for k or return the default d if not found or invalid
   233  func (c *C) GetInt(k string, d int) int {
   234  	r := c.GetString(k, strconv.Itoa(d))
   235  	v, err := strconv.Atoi(r)
   236  	if err != nil {
   237  		return d
   238  	}
   239  
   240  	return v
   241  }
   242  
   243  // GetUint32 will get the uint32 for k or return the default d if not found or invalid
   244  func (c *C) GetUint32(k string, d uint32) uint32 {
   245  	r := c.GetInt(k, int(d))
   246  	if uint64(r) > uint64(math.MaxUint32) {
   247  		return d
   248  	}
   249  	return uint32(r)
   250  }
   251  
   252  // GetBool will get the bool for k or return the default d if not found or invalid
   253  func (c *C) GetBool(k string, d bool) bool {
   254  	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
   255  	v, err := strconv.ParseBool(r)
   256  	if err != nil {
   257  		switch r {
   258  		case "y", "yes":
   259  			return true
   260  		case "n", "no":
   261  			return false
   262  		}
   263  		return d
   264  	}
   265  
   266  	return v
   267  }
   268  
   269  // GetDuration will get the duration for k or return the default d if not found or invalid
   270  func (c *C) GetDuration(k string, d time.Duration) time.Duration {
   271  	r := c.GetString(k, "")
   272  	v, err := time.ParseDuration(r)
   273  	if err != nil {
   274  		return d
   275  	}
   276  	return v
   277  }
   278  
   279  func (c *C) Get(k string) interface{} {
   280  	return c.get(k, c.Settings)
   281  }
   282  
   283  func (c *C) IsSet(k string) bool {
   284  	return c.get(k, c.Settings) != nil
   285  }
   286  
   287  func (c *C) get(k string, v interface{}) interface{} {
   288  	parts := strings.Split(k, ".")
   289  	for _, p := range parts {
   290  		m, ok := v.(map[interface{}]interface{})
   291  		if !ok {
   292  			return nil
   293  		}
   294  
   295  		v, ok = m[p]
   296  		if !ok {
   297  			return nil
   298  		}
   299  	}
   300  
   301  	return v
   302  }
   303  
   304  // direct signifies if this is the config path directly specified by the user,
   305  // versus a file/dir found by recursing into that path
   306  func (c *C) resolve(path string, direct bool) error {
   307  	i, err := os.Stat(path)
   308  	if err != nil {
   309  		return nil
   310  	}
   311  
   312  	if !i.IsDir() {
   313  		c.addFile(path, direct)
   314  		return nil
   315  	}
   316  
   317  	paths, err := readDirNames(path)
   318  	if err != nil {
   319  		return fmt.Errorf("problem while reading directory %s: %s", path, err)
   320  	}
   321  
   322  	for _, p := range paths {
   323  		err := c.resolve(filepath.Join(path, p), false)
   324  		if err != nil {
   325  			return err
   326  		}
   327  	}
   328  
   329  	return nil
   330  }
   331  
   332  func (c *C) addFile(path string, direct bool) error {
   333  	ext := filepath.Ext(path)
   334  
   335  	if !direct && ext != ".yaml" && ext != ".yml" {
   336  		return nil
   337  	}
   338  
   339  	ap, err := filepath.Abs(path)
   340  	if err != nil {
   341  		return err
   342  	}
   343  
   344  	c.files = append(c.files, ap)
   345  	return nil
   346  }
   347  
   348  func (c *C) parseRaw(b []byte) error {
   349  	var m map[interface{}]interface{}
   350  
   351  	err := yaml.Unmarshal(b, &m)
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	c.Settings = m
   357  	return nil
   358  }
   359  
   360  func (c *C) parse() error {
   361  	var m map[interface{}]interface{}
   362  
   363  	for _, path := range c.files {
   364  		b, err := os.ReadFile(path)
   365  		if err != nil {
   366  			return err
   367  		}
   368  
   369  		var nm map[interface{}]interface{}
   370  		err = yaml.Unmarshal(b, &nm)
   371  		if err != nil {
   372  			return err
   373  		}
   374  
   375  		// We need to use WithAppendSlice so that firewall rules in separate
   376  		// files are appended together
   377  		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
   378  		m = nm
   379  		if err != nil {
   380  			return err
   381  		}
   382  	}
   383  
   384  	c.Settings = m
   385  	return nil
   386  }
   387  
   388  func readDirNames(path string) ([]string, error) {
   389  	f, err := os.Open(path)
   390  	if err != nil {
   391  		return nil, err
   392  	}
   393  
   394  	paths, err := f.Readdirnames(-1)
   395  	f.Close()
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  
   400  	sort.Strings(paths)
   401  	return paths, nil
   402  }