github.com/gozelle/viper@v1.14.0/overrides_test.go (about) 1 package viper 2 3 import ( 4 "fmt" 5 "strings" 6 "testing" 7 8 "github.com/spf13/cast" 9 "github.com/stretchr/testify/assert" 10 ) 11 12 type layer int 13 14 const ( 15 defaultLayer layer = iota + 1 16 overrideLayer 17 ) 18 19 func TestNestedOverrides(t *testing.T) { 20 assert := assert.New(t) 21 var v *Viper 22 23 // Case 0: value overridden by a value 24 overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20 25 override(assert, "tom", 10, "tom", 20) // "tom" is first given value 10, then overridden by 20 26 overrideDefault(assert, "tom.age", 10, "tom.age", 20) 27 override(assert, "tom.age", 10, "tom.age", 20) 28 overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20) 29 override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20) 30 31 // Case 1: key:value overridden by a value 32 v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy" 33 assert.Nil(v.Get("tom.age")) // "tom.age" should not exist anymore 34 v = override(assert, "tom.age", 10, "tom", "boy") 35 assert.Nil(v.Get("tom.age")) 36 37 // Case 2: value overridden by a key:value 38 overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10} 39 override(assert, "tom.age", 10, "tom", "boy") 40 41 // Case 3: key:value overridden by a key:value 42 v = overrideDefault(assert, "tom.size", 4, "tom.age", 10) 43 assert.Equal(4, v.Get("tom.size")) // value should still be reachable 44 v = override(assert, "tom.size", 4, "tom.age", 10) 45 assert.Equal(4, v.Get("tom.size")) 46 deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) 47 48 // Case 4: key:value overridden by a map 49 v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} 50 assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable 51 assert.Equal(10, v.Get("tom.age")) // new value should be there 52 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there 53 v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) 54 assert.Nil(v.Get("tom.size")) 55 assert.Equal(10, v.Get("tom.age")) 56 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) 57 58 // Case 5: array overridden by a value 59 overrideDefault(assert, "tom", []int{10, 20}, "tom", 30) 60 override(assert, "tom", []int{10, 20}, "tom", 30) 61 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30) 62 override(assert, "tom.age", []int{10, 20}, "tom.age", 30) 63 64 // Case 6: array overridden by an array 65 overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40}) 66 override(assert, "tom", []int{10, 20}, "tom", []int{30, 40}) 67 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40}) 68 v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40}) 69 // explicit array merge: 70 s, ok := v.Get("tom.age").([]int) 71 if assert.True(ok, "tom[\"age\"] is not a slice") { 72 v.Set("tom.age", append(s, []int{50, 60}...)) 73 assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age")) 74 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60}) 75 } 76 } 77 78 func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 79 return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue) 80 } 81 82 func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 83 return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue) 84 } 85 86 // overrideFromLayer performs the sequential override and low-level checks. 87 // 88 // First assignment is made on layer l for path firstPath with value firstValue, 89 // the second one on the override layer (i.e., with the Set() function) 90 // for path secondPath with value secondValue. 91 // 92 // firstPath and secondPath can include an arbitrary number of dots to indicate 93 // a nested element. 94 // 95 // After each assignment, the value is checked, retrieved both by its full path 96 // and by its key sequence (successive maps). 97 func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper { 98 v := New() 99 firstKeys := strings.Split(firstPath, v.keyDelim) 100 if assert == nil || 101 len(firstKeys) == 0 || len(firstKeys[0]) == 0 { 102 return v 103 } 104 105 // Set and check first value 106 switch l { 107 case defaultLayer: 108 v.SetDefault(firstPath, firstValue) 109 case overrideLayer: 110 v.Set(firstPath, firstValue) 111 default: 112 return v 113 } 114 assert.Equal(firstValue, v.Get(firstPath)) 115 deepCheckValue(assert, v, l, firstKeys, firstValue) 116 117 // Override and check new value 118 secondKeys := strings.Split(secondPath, v.keyDelim) 119 if len(secondKeys) == 0 || len(secondKeys[0]) == 0 { 120 return v 121 } 122 v.Set(secondPath, secondValue) 123 assert.Equal(secondValue, v.Get(secondPath)) 124 deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue) 125 126 return v 127 } 128 129 // deepCheckValue checks that all given keys correspond to a valid path in the 130 // configuration map of the given layer, and that the final value equals the one given 131 func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) { 132 if assert == nil || v == nil || 133 len(keys) == 0 || len(keys[0]) == 0 { 134 return 135 } 136 137 // init 138 var val interface{} 139 var ms string 140 switch l { 141 case defaultLayer: 142 val = v.defaults 143 ms = "v.defaults" 144 case overrideLayer: 145 val = v.override 146 ms = "v.override" 147 } 148 149 // loop through map 150 var m map[string]interface{} 151 err := false 152 for _, k := range keys { 153 if val == nil { 154 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms)) 155 return 156 } 157 158 // deep scan of the map to get the final value 159 switch val.(type) { 160 case map[interface{}]interface{}: 161 m = cast.ToStringMap(val) 162 case map[string]interface{}: 163 m = val.(map[string]interface{}) 164 default: 165 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms)) 166 return 167 } 168 ms = ms + "[\"" + k + "\"]" 169 val = m[k] 170 } 171 if !err { 172 assert.Equal(value, val) 173 } 174 }