github.com/myafeier/fabric@v1.0.1-0.20170722181825-3a4b1f2bce86/common/viperutil/config_util.go (about)

     1  /*
     2  Copyright IBM Corp. 2016 All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8                   http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package viperutil
    18  
    19  import (
    20  	"fmt"
    21  	"io/ioutil"
    22  	"math"
    23  	"reflect"
    24  	"regexp"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"encoding/json"
    30  	"encoding/pem"
    31  
    32  	"github.com/Shopify/sarama"
    33  	"github.com/hyperledger/fabric/common/flogging"
    34  	"github.com/mitchellh/mapstructure"
    35  	"github.com/spf13/viper"
    36  )
    37  
    38  var logger = flogging.MustGetLogger("viperutil")
    39  
    40  type viperGetter func(key string) interface{}
    41  
    42  func getKeysRecursively(base string, getKey viperGetter, nodeKeys map[string]interface{}) map[string]interface{} {
    43  	result := make(map[string]interface{})
    44  	for key := range nodeKeys {
    45  		fqKey := base + key
    46  		val := getKey(fqKey)
    47  		if m, ok := val.(map[interface{}]interface{}); ok {
    48  			logger.Debugf("Found map[interface{}]interface{} value for %s", fqKey)
    49  			tmp := make(map[string]interface{})
    50  			for ik, iv := range m {
    51  				cik, ok := ik.(string)
    52  				if !ok {
    53  					panic("Non string key-entry")
    54  				}
    55  				tmp[cik] = iv
    56  			}
    57  			result[key] = getKeysRecursively(fqKey+".", getKey, tmp)
    58  		} else if m, ok := val.(map[string]interface{}); ok {
    59  			logger.Debugf("Found map[string]interface{} value for %s", fqKey)
    60  			result[key] = getKeysRecursively(fqKey+".", getKey, m)
    61  		} else if m, ok := unmarshalJSON(val); ok {
    62  			logger.Debugf("Found real value for %s setting to map[string]string %v", fqKey, m)
    63  			result[key] = m
    64  		} else {
    65  			if val == nil {
    66  				fileSubKey := fqKey + ".File"
    67  				fileVal := getKey(fileSubKey)
    68  				if fileVal != nil {
    69  					result[key] = map[string]interface{}{"File": fileVal}
    70  					continue
    71  				}
    72  			}
    73  			logger.Debugf("Found real value for %s setting to %T %v", fqKey, val, val)
    74  			result[key] = val
    75  
    76  		}
    77  	}
    78  	return result
    79  }
    80  
    81  func unmarshalJSON(val interface{}) (map[string]string, bool) {
    82  	mp := map[string]string{}
    83  	s, ok := val.(string)
    84  	if !ok {
    85  		logger.Debugf("Unmarshal JSON: value is not a string: %v", val)
    86  		return nil, false
    87  	}
    88  	err := json.Unmarshal([]byte(s), &mp)
    89  	if err != nil {
    90  		logger.Debugf("Unmarshal JSON: value cannot be unmarshalled: %s", err)
    91  		return nil, false
    92  	}
    93  	return mp, true
    94  }
    95  
    96  // customDecodeHook adds the additional functions of parsing durations from strings
    97  // as well as parsing strings of the format "[thing1, thing2, thing3]" into string slices
    98  // Note that whitespace around slice elements is removed
    99  func customDecodeHook() mapstructure.DecodeHookFunc {
   100  	durationHook := mapstructure.StringToTimeDurationHookFunc()
   101  	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   102  		dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data)
   103  		if err == nil {
   104  			if _, ok := dur.(time.Duration); ok {
   105  				return dur, nil
   106  			}
   107  		}
   108  
   109  		if f.Kind() != reflect.String {
   110  			return data, nil
   111  		}
   112  
   113  		raw := data.(string)
   114  		l := len(raw)
   115  		if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
   116  			slice := strings.Split(raw[1:l-1], ",")
   117  			for i, v := range slice {
   118  				slice[i] = strings.TrimSpace(v)
   119  			}
   120  			return slice, nil
   121  		}
   122  
   123  		return data, nil
   124  	}
   125  }
   126  
   127  func byteSizeDecodeHook() mapstructure.DecodeHookFunc {
   128  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   129  		if f != reflect.String || t != reflect.Uint32 {
   130  			return data, nil
   131  		}
   132  		raw := data.(string)
   133  		if raw == "" {
   134  			return data, nil
   135  		}
   136  		var re = regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
   137  		if re.MatchString(raw) {
   138  			size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
   139  			if err != nil {
   140  				return data, nil
   141  			}
   142  			unit := re.ReplaceAllString(raw, "${unit}")
   143  			switch strings.ToLower(unit) {
   144  			case "g":
   145  				size = size << 10
   146  				fallthrough
   147  			case "m":
   148  				size = size << 10
   149  				fallthrough
   150  			case "k":
   151  				size = size << 10
   152  			}
   153  			if size > math.MaxUint32 {
   154  				return size, fmt.Errorf("value '%s' overflows uint32", raw)
   155  			}
   156  			return size, nil
   157  		}
   158  		return data, nil
   159  	}
   160  }
   161  
   162  func stringFromFileDecodeHook() mapstructure.DecodeHookFunc {
   163  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   164  		// "to" type should be string
   165  		if t != reflect.String {
   166  			return data, nil
   167  		}
   168  		// "from" type should be map
   169  		if f != reflect.Map {
   170  			return data, nil
   171  		}
   172  		v := reflect.ValueOf(data)
   173  		switch v.Kind() {
   174  		case reflect.String:
   175  			return data, nil
   176  		case reflect.Map:
   177  			d := data.(map[string]interface{})
   178  			fileName, ok := d["File"]
   179  			if !ok {
   180  				fileName, ok = d["file"]
   181  			}
   182  			switch {
   183  			case ok && fileName != nil:
   184  				bytes, err := ioutil.ReadFile(fileName.(string))
   185  				if err != nil {
   186  					return data, err
   187  				}
   188  				return string(bytes), nil
   189  			case ok:
   190  				// fileName was nil
   191  				return nil, fmt.Errorf("Value of File: was nil")
   192  			}
   193  		}
   194  		return data, nil
   195  	}
   196  }
   197  
   198  func pemBlocksFromFileDecodeHook() mapstructure.DecodeHookFunc {
   199  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   200  		// "to" type should be string
   201  		if t != reflect.Slice {
   202  			return data, nil
   203  		}
   204  		// "from" type should be map
   205  		if f != reflect.Map {
   206  			return data, nil
   207  		}
   208  		v := reflect.ValueOf(data)
   209  		switch v.Kind() {
   210  		case reflect.String:
   211  			return data, nil
   212  		case reflect.Map:
   213  			var fileName string
   214  			var ok bool
   215  			switch d := data.(type) {
   216  			case map[string]string:
   217  				fileName, ok = d["File"]
   218  				if !ok {
   219  					fileName, ok = d["file"]
   220  				}
   221  			case map[string]interface{}:
   222  				var fileI interface{}
   223  				fileI, ok = d["File"]
   224  				if !ok {
   225  					fileI, _ = d["file"]
   226  				}
   227  				fileName, ok = fileI.(string)
   228  			}
   229  
   230  			switch {
   231  			case ok && fileName != "":
   232  				var result []string
   233  				bytes, err := ioutil.ReadFile(fileName)
   234  				if err != nil {
   235  					return data, err
   236  				}
   237  				for len(bytes) > 0 {
   238  					var block *pem.Block
   239  					block, bytes = pem.Decode(bytes)
   240  					if block == nil {
   241  						break
   242  					}
   243  					if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   244  						continue
   245  					}
   246  					result = append(result, string(pem.EncodeToMemory(block)))
   247  				}
   248  				return result, nil
   249  			case ok:
   250  				// fileName was nil
   251  				return nil, fmt.Errorf("Value of File: was nil")
   252  			}
   253  		}
   254  		return data, nil
   255  	}
   256  }
   257  
   258  func kafkaVersionDecodeHook() mapstructure.DecodeHookFunc {
   259  	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   260  		if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) {
   261  			return data, nil
   262  		}
   263  		switch data {
   264  		case "0.8.2.0":
   265  			return sarama.V0_8_2_0, nil
   266  		case "0.8.2.1":
   267  			return sarama.V0_8_2_1, nil
   268  		case "0.8.2.2":
   269  			return sarama.V0_8_2_2, nil
   270  		case "0.9.0.0":
   271  			return sarama.V0_9_0_0, nil
   272  		case "0.9.0.1":
   273  			return sarama.V0_9_0_1, nil
   274  		case "0.10.0.0":
   275  			return sarama.V0_10_0_0, nil
   276  		case "0.10.0.1":
   277  			return sarama.V0_10_0_1, nil
   278  		case "0.10.1.0":
   279  			return sarama.V0_10_1_0, nil
   280  		default:
   281  			return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data)
   282  		}
   283  	}
   284  }
   285  
   286  // EnhancedExactUnmarshal is intended to unmarshal a config file into a structure
   287  // producing error when extraneous variables are introduced and supporting
   288  // the time.Duration type
   289  func EnhancedExactUnmarshal(v *viper.Viper, output interface{}) error {
   290  	// AllKeys doesn't actually return all keys, it only returns the base ones
   291  	baseKeys := v.AllSettings()
   292  	getterWithClass := func(key string) interface{} { return v.Get(key) } // hide receiver
   293  	leafKeys := getKeysRecursively("", getterWithClass, baseKeys)
   294  
   295  	logger.Debugf("%+v", leafKeys)
   296  	config := &mapstructure.DecoderConfig{
   297  		ErrorUnused:      true,
   298  		Metadata:         nil,
   299  		Result:           output,
   300  		WeaklyTypedInput: true,
   301  		DecodeHook: mapstructure.ComposeDecodeHookFunc(
   302  			customDecodeHook(),
   303  			byteSizeDecodeHook(),
   304  			stringFromFileDecodeHook(),
   305  			pemBlocksFromFileDecodeHook(),
   306  			kafkaVersionDecodeHook(),
   307  		),
   308  	}
   309  
   310  	decoder, err := mapstructure.NewDecoder(config)
   311  	if err != nil {
   312  		return err
   313  	}
   314  	return decoder.Decode(leafKeys)
   315  }
   316  
   317  // EnhancedExactUnmarshalKey is intended to unmarshal a config file subtreee into a structure
   318  func EnhancedExactUnmarshalKey(baseKey string, output interface{}) error {
   319  	m := make(map[string]interface{})
   320  	m[baseKey] = nil
   321  	leafKeys := getKeysRecursively("", viper.Get, m)
   322  
   323  	logger.Debugf("%+v", leafKeys)
   324  	return mapstructure.Decode(leafKeys[baseKey], output)
   325  }