github.com/tenywen/fabric@v1.0.0-beta.0.20170620030522-a5b1ed380643/common/viperutil/config_util.go (about)

     1  /*
     2  Copyright IBM Corp. 2016 All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8                   http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package viperutil
    18  
    19  import (
    20  	"fmt"
    21  	"io/ioutil"
    22  	"math"
    23  	"reflect"
    24  	"regexp"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"encoding/json"
    30  	"encoding/pem"
    31  
    32  	"github.com/hyperledger/fabric/common/flogging"
    33  	"github.com/mitchellh/mapstructure"
    34  	"github.com/spf13/viper"
    35  )
    36  
    37  var logger = flogging.MustGetLogger("viperutil")
    38  
    39  type viperGetter func(key string) interface{}
    40  
    41  func getKeysRecursively(base string, getKey viperGetter, nodeKeys map[string]interface{}) map[string]interface{} {
    42  	result := make(map[string]interface{})
    43  	for key := range nodeKeys {
    44  		fqKey := base + key
    45  		val := getKey(fqKey)
    46  		if m, ok := val.(map[interface{}]interface{}); ok {
    47  			logger.Debugf("Found map[interface{}]interface{} value for %s", fqKey)
    48  			tmp := make(map[string]interface{})
    49  			for ik, iv := range m {
    50  				cik, ok := ik.(string)
    51  				if !ok {
    52  					panic("Non string key-entry")
    53  				}
    54  				tmp[cik] = iv
    55  			}
    56  			result[key] = getKeysRecursively(fqKey+".", getKey, tmp)
    57  		} else if m, ok := val.(map[string]interface{}); ok {
    58  			logger.Debugf("Found map[string]interface{} value for %s", fqKey)
    59  			result[key] = getKeysRecursively(fqKey+".", getKey, m)
    60  		} else if m, ok := unmarshalJSON(val); ok {
    61  			logger.Debugf("Found real value for %s setting to map[string]string %v", fqKey, m)
    62  			result[key] = m
    63  		} else {
    64  			if val == nil {
    65  				fileSubKey := fqKey + ".File"
    66  				fileVal := getKey(fileSubKey)
    67  				if fileVal != nil {
    68  					result[key] = map[string]interface{}{"File": fileVal}
    69  					continue
    70  				}
    71  			}
    72  			logger.Debugf("Found real value for %s setting to %T %v", fqKey, val, val)
    73  			result[key] = val
    74  
    75  		}
    76  	}
    77  	return result
    78  }
    79  
    80  func unmarshalJSON(val interface{}) (map[string]string, bool) {
    81  	mp := map[string]string{}
    82  	s, ok := val.(string)
    83  	if !ok {
    84  		logger.Debugf("Unmarshal JSON: value is not a string: %v", val)
    85  		return nil, false
    86  	}
    87  	err := json.Unmarshal([]byte(s), &mp)
    88  	if err != nil {
    89  		logger.Debugf("Unmarshal JSON: value cannot be unmarshalled: %s", err)
    90  		return nil, false
    91  	}
    92  	return mp, true
    93  }
    94  
    95  // customDecodeHook adds the additional functions of parsing durations from strings
    96  // as well as parsing strings of the format "[thing1, thing2, thing3]" into string slices
    97  // Note that whitespace around slice elements is removed
    98  func customDecodeHook() mapstructure.DecodeHookFunc {
    99  	durationHook := mapstructure.StringToTimeDurationHookFunc()
   100  	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   101  		dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data)
   102  		if err == nil {
   103  			if _, ok := dur.(time.Duration); ok {
   104  				return dur, nil
   105  			}
   106  		}
   107  
   108  		if f.Kind() != reflect.String {
   109  			return data, nil
   110  		}
   111  
   112  		raw := data.(string)
   113  		l := len(raw)
   114  		if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
   115  			slice := strings.Split(raw[1:l-1], ",")
   116  			for i, v := range slice {
   117  				slice[i] = strings.TrimSpace(v)
   118  			}
   119  			return slice, nil
   120  		}
   121  
   122  		return data, nil
   123  	}
   124  }
   125  
   126  func byteSizeDecodeHook() mapstructure.DecodeHookFunc {
   127  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   128  		if f != reflect.String || t != reflect.Uint32 {
   129  			return data, nil
   130  		}
   131  		raw := data.(string)
   132  		if raw == "" {
   133  			return data, nil
   134  		}
   135  		var re = regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
   136  		if re.MatchString(raw) {
   137  			size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
   138  			if err != nil {
   139  				return data, nil
   140  			}
   141  			unit := re.ReplaceAllString(raw, "${unit}")
   142  			switch strings.ToLower(unit) {
   143  			case "g":
   144  				size = size << 10
   145  				fallthrough
   146  			case "m":
   147  				size = size << 10
   148  				fallthrough
   149  			case "k":
   150  				size = size << 10
   151  			}
   152  			if size > math.MaxUint32 {
   153  				return size, fmt.Errorf("value '%s' overflows uint32", raw)
   154  			}
   155  			return size, nil
   156  		}
   157  		return data, nil
   158  	}
   159  }
   160  
   161  func stringFromFileDecodeHook() mapstructure.DecodeHookFunc {
   162  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   163  		// "to" type should be string
   164  		if t != reflect.String {
   165  			return data, nil
   166  		}
   167  		// "from" type should be map
   168  		if f != reflect.Map {
   169  			return data, nil
   170  		}
   171  		v := reflect.ValueOf(data)
   172  		switch v.Kind() {
   173  		case reflect.String:
   174  			return data, nil
   175  		case reflect.Map:
   176  			d := data.(map[string]interface{})
   177  			fileName, ok := d["File"]
   178  			if !ok {
   179  				fileName, ok = d["file"]
   180  			}
   181  			switch {
   182  			case ok && fileName != nil:
   183  				bytes, err := ioutil.ReadFile(fileName.(string))
   184  				if err != nil {
   185  					return data, err
   186  				}
   187  				return string(bytes), nil
   188  			case ok:
   189  				// fileName was nil
   190  				return nil, fmt.Errorf("Value of File: was nil")
   191  			}
   192  		}
   193  		return data, nil
   194  	}
   195  }
   196  
   197  func pemBlocksFromFileDecodeHook() mapstructure.DecodeHookFunc {
   198  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   199  		// "to" type should be string
   200  		if t != reflect.Slice {
   201  			return data, nil
   202  		}
   203  		// "from" type should be map
   204  		if f != reflect.Map {
   205  			return data, nil
   206  		}
   207  		v := reflect.ValueOf(data)
   208  		switch v.Kind() {
   209  		case reflect.String:
   210  			return data, nil
   211  		case reflect.Map:
   212  			var fileName string
   213  			var ok bool
   214  			switch d := data.(type) {
   215  			case map[string]string:
   216  				fileName, ok = d["File"]
   217  				if !ok {
   218  					fileName, ok = d["file"]
   219  				}
   220  			case map[string]interface{}:
   221  				var fileI interface{}
   222  				fileI, ok = d["File"]
   223  				if !ok {
   224  					fileI, _ = d["file"]
   225  				}
   226  				fileName, ok = fileI.(string)
   227  			}
   228  
   229  			switch {
   230  			case ok && fileName != "":
   231  				var result []string
   232  				bytes, err := ioutil.ReadFile(fileName)
   233  				if err != nil {
   234  					return data, err
   235  				}
   236  				for len(bytes) > 0 {
   237  					var block *pem.Block
   238  					block, bytes = pem.Decode(bytes)
   239  					if block == nil {
   240  						break
   241  					}
   242  					if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   243  						continue
   244  					}
   245  					result = append(result, string(pem.EncodeToMemory(block)))
   246  				}
   247  				return result, nil
   248  			case ok:
   249  				// fileName was nil
   250  				return nil, fmt.Errorf("Value of File: was nil")
   251  			}
   252  		}
   253  		return data, nil
   254  	}
   255  }
   256  
   257  // EnhancedExactUnmarshal is intended to unmarshal a config file into a structure
   258  // producing error when extraneous variables are introduced and supporting
   259  // the time.Duration type
   260  func EnhancedExactUnmarshal(v *viper.Viper, output interface{}) error {
   261  	// AllKeys doesn't actually return all keys, it only returns the base ones
   262  	baseKeys := v.AllSettings()
   263  	getterWithClass := func(key string) interface{} { return v.Get(key) } // hide receiver
   264  	leafKeys := getKeysRecursively("", getterWithClass, baseKeys)
   265  
   266  	logger.Debugf("%+v", leafKeys)
   267  	config := &mapstructure.DecoderConfig{
   268  		ErrorUnused:      true,
   269  		Metadata:         nil,
   270  		Result:           output,
   271  		WeaklyTypedInput: true,
   272  		DecodeHook: mapstructure.ComposeDecodeHookFunc(
   273  			customDecodeHook(),
   274  			byteSizeDecodeHook(),
   275  			stringFromFileDecodeHook(),
   276  			pemBlocksFromFileDecodeHook(),
   277  		),
   278  	}
   279  
   280  	decoder, err := mapstructure.NewDecoder(config)
   281  	if err != nil {
   282  		return err
   283  	}
   284  	return decoder.Decode(leafKeys)
   285  }
   286  
   287  // EnhancedExactUnmarshalKey is intended to unmarshal a config file subtreee into a structure
   288  func EnhancedExactUnmarshalKey(baseKey string, output interface{}) error {
   289  	m := make(map[string]interface{})
   290  	m[baseKey] = nil
   291  	leafKeys := getKeysRecursively("", viper.Get, m)
   292  
   293  	logger.Debugf("%+v", leafKeys)
   294  	return mapstructure.Decode(leafKeys[baseKey], output)
   295  }