github.com/ava-labs/viper@v1.7.2-0.20210125155433-65cc0421b384/overrides_test.go (about)

     1  package viper
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"testing"
     7  
     8  	"github.com/spf13/cast"
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  type layer int
    13  
    14  const (
    15  	defaultLayer layer = iota + 1
    16  	overrideLayer
    17  )
    18  
    19  func TestNestedOverrides(t *testing.T) {
    20  	assert := assert.New(t)
    21  	var v *Viper
    22  
    23  	// Case 0: value overridden by a value
    24  	overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20
    25  	override(assert, "tom", 10, "tom", 20)        // "tom" is first given value 10, then overridden by 20
    26  	overrideDefault(assert, "tom.age", 10, "tom.age", 20)
    27  	override(assert, "tom.age", 10, "tom.age", 20)
    28  	overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
    29  	override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
    30  
    31  	// Case 1: key:value overridden by a value
    32  	v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy"
    33  	assert.Nil(v.Get("tom.age"))                             // "tom.age" should not exist anymore
    34  	v = override(assert, "tom.age", 10, "tom", "boy")
    35  	assert.Nil(v.Get("tom.age"))
    36  
    37  	// Case 2: value overridden by a key:value
    38  	overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10}
    39  	override(assert, "tom.age", 10, "tom", "boy")
    40  
    41  	// Case 3: key:value overridden by a key:value
    42  	v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
    43  	assert.Equal(4, v.Get("tom.size")) // value should still be reachable
    44  	v = override(assert, "tom.size", 4, "tom.age", 10)
    45  	assert.Equal(4, v.Get("tom.size"))
    46  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
    47  
    48  	// Case 4: key:value overridden by a map
    49  	v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
    50  	assert.Equal(4, v.Get("tom.size"))                                                   // "tom.size" should still be reachable
    51  	assert.Equal(10, v.Get("tom.age"))                                                   // new value should be there
    52  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)                 // new value should be there
    53  	v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
    54  	assert.Nil(v.Get("tom.size"))
    55  	assert.Equal(10, v.Get("tom.age"))
    56  	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
    57  
    58  	// Case 5: array overridden by a value
    59  	overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
    60  	override(assert, "tom", []int{10, 20}, "tom", 30)
    61  	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
    62  	override(assert, "tom.age", []int{10, 20}, "tom.age", 30)
    63  
    64  	// Case 6: array overridden by an array
    65  	overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
    66  	override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
    67  	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
    68  	v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
    69  	// explicit array merge:
    70  	s, ok := v.Get("tom.age").([]int)
    71  	if assert.True(ok, "tom[\"age\"] is not a slice") {
    72  		v.Set("tom.age", append(s, []int{50, 60}...))
    73  		assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
    74  		deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
    75  	}
    76  }
    77  
    78  func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    79  	return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
    80  }
    81  
    82  func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    83  	return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
    84  }
    85  
    86  // overrideFromLayer performs the sequential override and low-level checks.
    87  //
    88  // First assignment is made on layer l for path firstPath with value firstValue,
    89  // the second one on the override layer (i.e., with the Set() function)
    90  // for path secondPath with value secondValue.
    91  //
    92  // firstPath and secondPath can include an arbitrary number of dots to indicate
    93  // a nested element.
    94  //
    95  // After each assignment, the value is checked, retrieved both by its full path
    96  // and by its key sequence (successive maps).
    97  func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
    98  	v := New()
    99  	firstKeys := strings.Split(firstPath, v.keyDelim)
   100  	if assert == nil ||
   101  		len(firstKeys) == 0 || len(firstKeys[0]) == 0 {
   102  		return v
   103  	}
   104  
   105  	// Set and check first value
   106  	switch l {
   107  	case defaultLayer:
   108  		v.SetDefault(firstPath, firstValue)
   109  	case overrideLayer:
   110  		v.Set(firstPath, firstValue)
   111  	default:
   112  		return v
   113  	}
   114  	assert.Equal(firstValue, v.Get(firstPath))
   115  	deepCheckValue(assert, v, l, firstKeys, firstValue)
   116  
   117  	// Override and check new value
   118  	secondKeys := strings.Split(secondPath, v.keyDelim)
   119  	if len(secondKeys) == 0 || len(secondKeys[0]) == 0 {
   120  		return v
   121  	}
   122  	v.Set(secondPath, secondValue)
   123  	assert.Equal(secondValue, v.Get(secondPath))
   124  	deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)
   125  
   126  	return v
   127  }
   128  
   129  // deepCheckValue checks that all given keys correspond to a valid path in the
   130  // configuration map of the given layer, and that the final value equals the one given
   131  func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
   132  	if assert == nil || v == nil ||
   133  		len(keys) == 0 || len(keys[0]) == 0 {
   134  		return
   135  	}
   136  
   137  	// init
   138  	var val interface{}
   139  	var ms string
   140  	switch l {
   141  	case defaultLayer:
   142  		val = v.defaults
   143  		ms = "v.defaults"
   144  	case overrideLayer:
   145  		val = v.override
   146  		ms = "v.override"
   147  	}
   148  
   149  	// loop through map
   150  	var m map[string]interface{}
   151  	err := false
   152  	for _, k := range keys {
   153  		if val == nil {
   154  			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
   155  			return
   156  		}
   157  
   158  		// deep scan of the map to get the final value
   159  		switch val.(type) {
   160  		case map[interface{}]interface{}:
   161  			m = cast.ToStringMap(val)
   162  		case map[string]interface{}:
   163  			m = val.(map[string]interface{})
   164  		default:
   165  			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
   166  			return
   167  		}
   168  		ms = ms + "[\"" + k + "\"]"
   169  		val = m[k]
   170  	}
   171  	if !err {
   172  		assert.Equal(value, val)
   173  	}
   174  }