github.com/TaylorOno/viper@v1.1.1/overrides_test.go (about)

     1  package viper
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"testing"
     7  
     8  	"github.com/spf13/cast"
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  type layer int
    13  
    14  const (
    15  	defaultLayer layer = iota + 1
    16  	overrideLayer
    17  )
    18  
    19  func TestNestedOverrides(t *testing.T) {
    20  	assert := assert.New(t)
    21  	var v *Viper
    22  
    23  	// Case 0: value overridden by a value
    24  	overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20
    25  	override(assert, "tom", 10, "tom", 20)        // "tom" is first given value 10, then overridden by 20
    26  	overrideDefault(assert, "tom.age", 10, "tom.age", 20)
    27  	override(assert, "tom.age", 10, "tom.age", 20)
    28  	overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
    29  	override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
    30  
    31  	// Case 1: key:value overridden by a value
    32  	v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy"
    33  	assert.Nil(v.Get("tom.age"))                             // "tom.age" should not exist anymore
    34  	v = override(assert, "tom.age", 10, "tom", "boy")
    35  	assert.Nil(v.Get("tom.age"))
    36  
    37  	// Case 2: value overridden by a key:value
    38  	overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10}
    39  	override(assert, "tom.age", 10, "tom", "boy")
    40  
    41  	// Case 3: key:value overridden by a key:value
    42  	v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
    43  	assert.Equal(4, v.Get("tom.size")) // value should still be reachable
    44  	v = override(assert, "tom.size", 4, "tom.age", 10)
    45  	assert.Equal(4, v.Get("tom.size"))
    46  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
    47  
    48  	// Case 4: key:value overridden by a map
    49  	v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
    50  	assert.Equal(4, v.Get("tom.size"))                                                   // "tom.size" should still be reachable
    51  	assert.Equal(10, v.Get("tom.age"))                                                   // new value should be there
    52  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)                 // new value should be there
    53  	v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
    54  	assert.Nil(v.Get("tom.size"))
    55  	assert.Equal(10, v.Get("tom.age"))
    56  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
    57  
    58  	// Case 5: array overridden by a value
    59  	overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
    60  	override(assert, "tom", []int{10, 20}, "tom", 30)
    61  	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
    62  	override(assert, "tom.age", []int{10, 20}, "tom.age", 30)
    63  
    64  	// Case 6: array overridden by an array
    65  	overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
    66  	override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
    67  	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
    68  	v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
    69  	// explicit array merge:
    70  	s, ok := v.Get("tom.age").([]int)
    71  	if assert.True(ok, "tom[\"age\"] is not a slice") {
    72  		v.Set("tom.age", append(s, []int{50, 60}...))
    73  		assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
    74  		deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
    75  	}
    76  }
    77  
    78  func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    79  	return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
    80  }
    81  func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    82  	return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
    83  }
    84  
    85  // overrideFromLayer performs the sequential override and low-level checks.
    86  //
    87  // First assignment is made on layer l for path firstPath with value firstValue,
    88  // the second one on the override layer (i.e., with the Set() function)
    89  // for path secondPath with value secondValue.
    90  //
    91  // firstPath and secondPath can include an arbitrary number of dots to indicate
    92  // a nested element.
    93  //
    94  // After each assignment, the value is checked, retrieved both by its full path
    95  // and by its key sequence (successive maps).
    96  func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    97  	v := New()
    98  	firstKeys := strings.Split(firstPath, v.keyDelim)
    99  	if assert == nil ||
   100  		len(firstKeys) == 0 || len(firstKeys[0]) == 0 {
   101  		return v
   102  	}
   103  
   104  	// Set and check first value
   105  	switch l {
   106  	case defaultLayer:
   107  		v.SetDefault(firstPath, firstValue)
   108  	case overrideLayer:
   109  		v.Set(firstPath, firstValue)
   110  	default:
   111  		return v
   112  	}
   113  	assert.Equal(firstValue, v.Get(firstPath))
   114  	deepCheckValue(assert, v, l, firstKeys, firstValue)
   115  
   116  	// Override and check new value
   117  	secondKeys := strings.Split(secondPath, v.keyDelim)
   118  	if len(secondKeys) == 0 || len(secondKeys[0]) == 0 {
   119  		return v
   120  	}
   121  	v.Set(secondPath, secondValue)
   122  	assert.Equal(secondValue, v.Get(secondPath))
   123  	deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)
   124  
   125  	return v
   126  }
   127  
   128  // deepCheckValue checks that all given keys correspond to a valid path in the
   129  // configuration map of the given layer, and that the final value equals the one given
   130  func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
   131  	if assert == nil || v == nil ||
   132  		len(keys) == 0 || len(keys[0]) == 0 {
   133  		return
   134  	}
   135  
   136  	// init
   137  	var val interface{}
   138  	var ms string
   139  	switch l {
   140  	case defaultLayer:
   141  		val = v.defaults
   142  		ms = "v.defaults"
   143  	case overrideLayer:
   144  		val = v.override
   145  		ms = "v.override"
   146  	}
   147  
   148  	// loop through map
   149  	var m map[string]interface{}
   150  	err := false
   151  	for _, k := range keys {
   152  		if val == nil {
   153  			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
   154  			return
   155  		}
   156  
   157  		// deep scan of the map to get the final value
   158  		switch val.(type) {
   159  		case map[interface{}]interface{}:
   160  			m = cast.ToStringMap(val)
   161  		case map[string]interface{}:
   162  			m = val.(map[string]interface{})
   163  		default:
   164  			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
   165  			return
   166  		}
   167  		ms = ms + "[\"" + k + "\"]"
   168  		val = m[k]
   169  	}
   170  	if !err {
   171  		assert.Equal(value, val)
   172  	}
   173  }