github.com/gk008/viper@v1.0.2/overrides_test.go (about) 1 package viper 2 3 import ( 4 "fmt" 5 "strings" 6 "testing" 7 8 "github.com/spf13/cast" 9 "github.com/stretchr/testify/assert" 10 ) 11 12 type layer int 13 14 const ( 15 defaultLayer layer = iota + 1 16 overrideLayer 17 ) 18 19 func TestNestedOverrides(t *testing.T) { 20 assert := assert.New(t) 21 var v *Viper 22 23 // Case 0: value overridden by a value 24 overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20 25 override(assert, "tom", 10, "tom", 20) // "tom" is first given value 10, then overridden by 20 26 overrideDefault(assert, "tom.age", 10, "tom.age", 20) 27 override(assert, "tom.age", 10, "tom.age", 20) 28 overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20) 29 override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20) 30 31 // Case 1: key:value overridden by a value 32 v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy" 33 assert.Nil(v.Get("tom.age")) // "tom.age" should not exist anymore 34 v = override(assert, "tom.age", 10, "tom", "boy") 35 assert.Nil(v.Get("tom.age")) 36 37 // Case 2: value overridden by a key:value 38 overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10} 39 override(assert, "tom.age", 10, "tom", "boy") 40 41 // Case 3: key:value overridden by a key:value 42 v = overrideDefault(assert, "tom.size", 4, "tom.age", 10) 43 assert.Equal(4, v.Get("tom.size")) // value should still be reachable 44 v = override(assert, "tom.size", 4, "tom.age", 10) 45 assert.Equal(4, v.Get("tom.size")) 46 deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) 47 48 // Case 4: key:value overridden by a map 49 v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} 50 assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable 51 assert.Equal(10, v.Get("tom.age")) // new value should be there 52 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there 53 v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) 54 assert.Nil(v.Get("tom.size")) 55 assert.Equal(10, v.Get("tom.age")) 56 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) 57 58 // Case 5: array overridden by a value 59 overrideDefault(assert, "tom", []int{10, 20}, "tom", 30) 60 override(assert, "tom", []int{10, 20}, "tom", 30) 61 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30) 62 override(assert, "tom.age", []int{10, 20}, "tom.age", 30) 63 64 // Case 6: array overridden by an array 65 overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40}) 66 override(assert, "tom", []int{10, 20}, "tom", []int{30, 40}) 67 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40}) 68 v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40}) 69 // explicit array merge: 70 s, ok := v.Get("tom.age").([]int) 71 if assert.True(ok, "tom[\"age\"] is not a slice") { 72 v.Set("tom.age", append(s, []int{50, 60}...)) 73 assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age")) 74 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60}) 75 } 76 } 77 78 func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 79 return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue) 80 } 81 func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 82 return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue) 83 } 84 85 // overrideFromLayer performs the sequential override and low-level checks. 86 // 87 // First assignment is made on layer l for path firstPath with value firstValue, 88 // the second one on the override layer (i.e., with the Set() function) 89 // for path secondPath with value secondValue. 90 // 91 // firstPath and secondPath can include an arbitrary number of dots to indicate 92 // a nested element. 93 // 94 // After each assignment, the value is checked, retrieved both by its full path 95 // and by its key sequence (successive maps). 96 func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 97 v := New() 98 firstKeys := strings.Split(firstPath, v.keyDelim) 99 if assert == nil || 100 len(firstKeys) == 0 || len(firstKeys[0]) == 0 { 101 return v 102 } 103 104 // Set and check first value 105 switch l { 106 case defaultLayer: 107 v.SetDefault(firstPath, firstValue) 108 case overrideLayer: 109 v.Set(firstPath, firstValue) 110 default: 111 return v 112 } 113 assert.Equal(firstValue, v.Get(firstPath)) 114 deepCheckValue(assert, v, l, firstKeys, firstValue) 115 116 // Override and check new value 117 secondKeys := strings.Split(secondPath, v.keyDelim) 118 if len(secondKeys) == 0 || len(secondKeys[0]) == 0 { 119 return v 120 } 121 v.Set(secondPath, secondValue) 122 assert.Equal(secondValue, v.Get(secondPath)) 123 deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue) 124 125 return v 126 } 127 128 // deepCheckValue checks that all given keys correspond to a valid path in the 129 // configuration map of the given layer, and that the final value equals the one given 130 func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) { 131 if assert == nil || v == nil || 132 len(keys) == 0 || len(keys[0]) == 0 { 133 return 134 } 135 136 // init 137 var val interface{} 138 var ms string 139 switch l { 140 case defaultLayer: 141 val = v.defaults 142 ms = "v.defaults" 143 case overrideLayer: 144 val = v.override 145 ms = "v.override" 146 } 147 148 // loop through map 149 var m map[string]interface{} 150 err := false 151 for _, k := range keys { 152 if val == nil { 153 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms)) 154 return 155 } 156 157 // deep scan of the map to get the final value 158 switch val.(type) { 159 case map[interface{}]interface{}: 160 m = cast.ToStringMap(val) 161 case map[string]interface{}: 162 m = val.(map[string]interface{}) 163 default: 164 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms)) 165 return 166 } 167 ms = ms + "[\"" + k + "\"]" 168 val = m[k] 169 } 170 if !err { 171 assert.Equal(value, val) 172 } 173 }