github.com/slackhq/nebula@v1.9.0/config/config.go (about) 1 package config 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math" 8 "os" 9 "os/signal" 10 "path/filepath" 11 "sort" 12 "strconv" 13 "strings" 14 "sync" 15 "syscall" 16 "time" 17 18 "dario.cat/mergo" 19 "github.com/sirupsen/logrus" 20 "gopkg.in/yaml.v2" 21 ) 22 23 type C struct { 24 path string 25 files []string 26 Settings map[interface{}]interface{} 27 oldSettings map[interface{}]interface{} 28 callbacks []func(*C) 29 l *logrus.Logger 30 reloadLock sync.Mutex 31 } 32 33 func NewC(l *logrus.Logger) *C { 34 return &C{ 35 Settings: make(map[interface{}]interface{}), 36 l: l, 37 } 38 } 39 40 // Load will find all yaml files within path and load them in lexical order 41 func (c *C) Load(path string) error { 42 c.path = path 43 c.files = make([]string, 0) 44 45 err := c.resolve(path, true) 46 if err != nil { 47 return err 48 } 49 50 if len(c.files) == 0 { 51 return fmt.Errorf("no config files found at %s", path) 52 } 53 54 sort.Strings(c.files) 55 56 err = c.parse() 57 if err != nil { 58 return err 59 } 60 61 return nil 62 } 63 64 func (c *C) LoadString(raw string) error { 65 if raw == "" { 66 return errors.New("Empty configuration") 67 } 68 return c.parseRaw([]byte(raw)) 69 } 70 71 // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered 72 // here should decide if they need to make a change to the current process before making the change. HasChanged can be 73 // used to help decide if a change is necessary. 74 // These functions should return quickly or spawn their own go routine if they will take a while 75 func (c *C) RegisterReloadCallback(f func(*C)) { 76 c.callbacks = append(c.callbacks, f) 77 } 78 79 // InitialLoad returns true if this is the first load of the config, and ReloadConfig has not been called yet. 80 func (c *C) InitialLoad() bool { 81 return c.oldSettings == nil 82 } 83 84 // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of 85 // k in both the old and new settings will be serialized, the result of the string comparison is returned. 86 // If k is an empty string the entire config is tested. 87 // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating 88 // there is change when there actually wasn't any. 89 func (c *C) HasChanged(k string) bool { 90 if c.oldSettings == nil { 91 return false 92 } 93 94 var ( 95 nv interface{} 96 ov interface{} 97 ) 98 99 if k == "" { 100 nv = c.Settings 101 ov = c.oldSettings 102 k = "all settings" 103 } else { 104 nv = c.get(k, c.Settings) 105 ov = c.get(k, c.oldSettings) 106 } 107 108 newVals, err := yaml.Marshal(nv) 109 if err != nil { 110 c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") 111 } 112 113 oldVals, err := yaml.Marshal(ov) 114 if err != nil { 115 c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") 116 } 117 118 return string(newVals) != string(oldVals) 119 } 120 121 // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the 122 // original path provided to Load. The old settings are shallow copied for change detection after the reload. 123 func (c *C) CatchHUP(ctx context.Context) { 124 if c.path == "" { 125 return 126 } 127 128 ch := make(chan os.Signal, 1) 129 signal.Notify(ch, syscall.SIGHUP) 130 131 go func() { 132 for { 133 select { 134 case <-ctx.Done(): 135 signal.Stop(ch) 136 close(ch) 137 return 138 case <-ch: 139 c.l.Info("Caught HUP, reloading config") 140 c.ReloadConfig() 141 } 142 } 143 }() 144 } 145 146 func (c *C) ReloadConfig() { 147 c.reloadLock.Lock() 148 defer c.reloadLock.Unlock() 149 150 c.oldSettings = make(map[interface{}]interface{}) 151 for k, v := range c.Settings { 152 c.oldSettings[k] = v 153 } 154 155 err := c.Load(c.path) 156 if err != nil { 157 c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") 158 return 159 } 160 161 for _, v := range c.callbacks { 162 v(c) 163 } 164 } 165 166 func (c *C) ReloadConfigString(raw string) error { 167 c.reloadLock.Lock() 168 defer c.reloadLock.Unlock() 169 170 c.oldSettings = make(map[interface{}]interface{}) 171 for k, v := range c.Settings { 172 c.oldSettings[k] = v 173 } 174 175 err := c.LoadString(raw) 176 if err != nil { 177 return err 178 } 179 180 for _, v := range c.callbacks { 181 v(c) 182 } 183 184 return nil 185 } 186 187 // GetString will get the string for k or return the default d if not found or invalid 188 func (c *C) GetString(k, d string) string { 189 r := c.Get(k) 190 if r == nil { 191 return d 192 } 193 194 return fmt.Sprintf("%v", r) 195 } 196 197 // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid 198 func (c *C) GetStringSlice(k string, d []string) []string { 199 r := c.Get(k) 200 if r == nil { 201 return d 202 } 203 204 rv, ok := r.([]interface{}) 205 if !ok { 206 return d 207 } 208 209 v := make([]string, len(rv)) 210 for i := 0; i < len(v); i++ { 211 v[i] = fmt.Sprintf("%v", rv[i]) 212 } 213 214 return v 215 } 216 217 // GetMap will get the map for k or return the default d if not found or invalid 218 func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { 219 r := c.Get(k) 220 if r == nil { 221 return d 222 } 223 224 v, ok := r.(map[interface{}]interface{}) 225 if !ok { 226 return d 227 } 228 229 return v 230 } 231 232 // GetInt will get the int for k or return the default d if not found or invalid 233 func (c *C) GetInt(k string, d int) int { 234 r := c.GetString(k, strconv.Itoa(d)) 235 v, err := strconv.Atoi(r) 236 if err != nil { 237 return d 238 } 239 240 return v 241 } 242 243 // GetUint32 will get the uint32 for k or return the default d if not found or invalid 244 func (c *C) GetUint32(k string, d uint32) uint32 { 245 r := c.GetInt(k, int(d)) 246 if uint64(r) > uint64(math.MaxUint32) { 247 return d 248 } 249 return uint32(r) 250 } 251 252 // GetBool will get the bool for k or return the default d if not found or invalid 253 func (c *C) GetBool(k string, d bool) bool { 254 r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) 255 v, err := strconv.ParseBool(r) 256 if err != nil { 257 switch r { 258 case "y", "yes": 259 return true 260 case "n", "no": 261 return false 262 } 263 return d 264 } 265 266 return v 267 } 268 269 // GetDuration will get the duration for k or return the default d if not found or invalid 270 func (c *C) GetDuration(k string, d time.Duration) time.Duration { 271 r := c.GetString(k, "") 272 v, err := time.ParseDuration(r) 273 if err != nil { 274 return d 275 } 276 return v 277 } 278 279 func (c *C) Get(k string) interface{} { 280 return c.get(k, c.Settings) 281 } 282 283 func (c *C) IsSet(k string) bool { 284 return c.get(k, c.Settings) != nil 285 } 286 287 func (c *C) get(k string, v interface{}) interface{} { 288 parts := strings.Split(k, ".") 289 for _, p := range parts { 290 m, ok := v.(map[interface{}]interface{}) 291 if !ok { 292 return nil 293 } 294 295 v, ok = m[p] 296 if !ok { 297 return nil 298 } 299 } 300 301 return v 302 } 303 304 // direct signifies if this is the config path directly specified by the user, 305 // versus a file/dir found by recursing into that path 306 func (c *C) resolve(path string, direct bool) error { 307 i, err := os.Stat(path) 308 if err != nil { 309 return nil 310 } 311 312 if !i.IsDir() { 313 c.addFile(path, direct) 314 return nil 315 } 316 317 paths, err := readDirNames(path) 318 if err != nil { 319 return fmt.Errorf("problem while reading directory %s: %s", path, err) 320 } 321 322 for _, p := range paths { 323 err := c.resolve(filepath.Join(path, p), false) 324 if err != nil { 325 return err 326 } 327 } 328 329 return nil 330 } 331 332 func (c *C) addFile(path string, direct bool) error { 333 ext := filepath.Ext(path) 334 335 if !direct && ext != ".yaml" && ext != ".yml" { 336 return nil 337 } 338 339 ap, err := filepath.Abs(path) 340 if err != nil { 341 return err 342 } 343 344 c.files = append(c.files, ap) 345 return nil 346 } 347 348 func (c *C) parseRaw(b []byte) error { 349 var m map[interface{}]interface{} 350 351 err := yaml.Unmarshal(b, &m) 352 if err != nil { 353 return err 354 } 355 356 c.Settings = m 357 return nil 358 } 359 360 func (c *C) parse() error { 361 var m map[interface{}]interface{} 362 363 for _, path := range c.files { 364 b, err := os.ReadFile(path) 365 if err != nil { 366 return err 367 } 368 369 var nm map[interface{}]interface{} 370 err = yaml.Unmarshal(b, &nm) 371 if err != nil { 372 return err 373 } 374 375 // We need to use WithAppendSlice so that firewall rules in separate 376 // files are appended together 377 err = mergo.Merge(&nm, m, mergo.WithAppendSlice) 378 m = nm 379 if err != nil { 380 return err 381 } 382 } 383 384 c.Settings = m 385 return nil 386 } 387 388 func readDirNames(path string) ([]string, error) { 389 f, err := os.Open(path) 390 if err != nil { 391 return nil, err 392 } 393 394 paths, err := f.Readdirnames(-1) 395 f.Close() 396 if err != nil { 397 return nil, err 398 } 399 400 sort.Strings(paths) 401 return paths, nil 402 }