github.com/yacovm/fabric@v2.0.0-alpha.0.20191128145320-c5d4087dc723+incompatible/common/viperutil/config_util.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package viperutil
     8  
     9  import (
    10  	"encoding/json"
    11  	"encoding/pem"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"math"
    15  	"reflect"
    16  	"regexp"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/Shopify/sarama"
    22  	version "github.com/hashicorp/go-version"
    23  	"github.com/hyperledger/fabric/bccsp/factory"
    24  	"github.com/hyperledger/fabric/common/flogging"
    25  	"github.com/mitchellh/mapstructure"
    26  	"github.com/pkg/errors"
    27  	"github.com/spf13/viper"
    28  )
    29  
    30  var logger = flogging.MustGetLogger("viperutil")
    31  
    32  type viperGetter func(key string) interface{}
    33  
    34  func getKeysRecursively(base string, getKey viperGetter, nodeKeys map[string]interface{}, oType reflect.Type) map[string]interface{} {
    35  	subTypes := map[string]reflect.Type{}
    36  
    37  	if oType != nil && oType.Kind() == reflect.Struct {
    38  	outer:
    39  		for i := 0; i < oType.NumField(); i++ {
    40  			fieldName := oType.Field(i).Name
    41  			fieldType := oType.Field(i).Type
    42  
    43  			for key := range nodeKeys {
    44  				if strings.EqualFold(fieldName, key) {
    45  					subTypes[key] = fieldType
    46  					continue outer
    47  				}
    48  			}
    49  
    50  			subTypes[fieldName] = fieldType
    51  			nodeKeys[fieldName] = nil
    52  		}
    53  	}
    54  
    55  	result := make(map[string]interface{})
    56  	for key := range nodeKeys {
    57  		fqKey := base + key
    58  
    59  		val := getKey(fqKey)
    60  		if m, ok := val.(map[interface{}]interface{}); ok {
    61  			logger.Debugf("Found map[interface{}]interface{} value for %s", fqKey)
    62  			tmp := make(map[string]interface{})
    63  			for ik, iv := range m {
    64  				cik, ok := ik.(string)
    65  				if !ok {
    66  					panic("Non string key-entry")
    67  				}
    68  				tmp[cik] = iv
    69  			}
    70  			result[key] = getKeysRecursively(fqKey+".", getKey, tmp, subTypes[key])
    71  		} else if m, ok := val.(map[string]interface{}); ok {
    72  			logger.Debugf("Found map[string]interface{} value for %s", fqKey)
    73  			result[key] = getKeysRecursively(fqKey+".", getKey, m, subTypes[key])
    74  		} else if m, ok := unmarshalJSON(val); ok {
    75  			logger.Debugf("Found real value for %s setting to map[string]string %v", fqKey, m)
    76  			result[key] = m
    77  		} else {
    78  			if val == nil {
    79  				fileSubKey := fqKey + ".File"
    80  				fileVal := getKey(fileSubKey)
    81  				if fileVal != nil {
    82  					result[key] = map[string]interface{}{"File": fileVal}
    83  					continue
    84  				}
    85  			}
    86  			logger.Debugf("Found real value for %s setting to %T %v", fqKey, val, val)
    87  			result[key] = val
    88  
    89  		}
    90  	}
    91  	return result
    92  }
    93  
    94  func unmarshalJSON(val interface{}) (map[string]string, bool) {
    95  	mp := map[string]string{}
    96  	s, ok := val.(string)
    97  	if !ok {
    98  		logger.Debugf("Unmarshal JSON: value is not a string: %v", val)
    99  		return nil, false
   100  	}
   101  	err := json.Unmarshal([]byte(s), &mp)
   102  	if err != nil {
   103  		logger.Debugf("Unmarshal JSON: value cannot be unmarshalled: %s", err)
   104  		return nil, false
   105  	}
   106  	return mp, true
   107  }
   108  
   109  // customDecodeHook adds the additional functions of parsing durations from strings
   110  // as well as parsing strings of the format "[thing1, thing2, thing3]" into string slices
   111  // Note that whitespace around slice elements is removed
   112  func customDecodeHook() mapstructure.DecodeHookFunc {
   113  	durationHook := mapstructure.StringToTimeDurationHookFunc()
   114  	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   115  		dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data)
   116  		if err == nil {
   117  			if _, ok := dur.(time.Duration); ok {
   118  				return dur, nil
   119  			}
   120  		}
   121  
   122  		if f.Kind() != reflect.String {
   123  			return data, nil
   124  		}
   125  
   126  		raw := data.(string)
   127  		l := len(raw)
   128  		if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
   129  			slice := strings.Split(raw[1:l-1], ",")
   130  			for i, v := range slice {
   131  				slice[i] = strings.TrimSpace(v)
   132  			}
   133  			return slice, nil
   134  		}
   135  
   136  		return data, nil
   137  	}
   138  }
   139  
   140  func byteSizeDecodeHook() mapstructure.DecodeHookFunc {
   141  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   142  		if f != reflect.String || t != reflect.Uint32 {
   143  			return data, nil
   144  		}
   145  		raw := data.(string)
   146  		if raw == "" {
   147  			return data, nil
   148  		}
   149  		var re = regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
   150  		if re.MatchString(raw) {
   151  			size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
   152  			if err != nil {
   153  				return data, nil
   154  			}
   155  			unit := re.ReplaceAllString(raw, "${unit}")
   156  			switch strings.ToLower(unit) {
   157  			case "g":
   158  				size = size << 10
   159  				fallthrough
   160  			case "m":
   161  				size = size << 10
   162  				fallthrough
   163  			case "k":
   164  				size = size << 10
   165  			}
   166  			if size > math.MaxUint32 {
   167  				return size, fmt.Errorf("value '%s' overflows uint32", raw)
   168  			}
   169  			return size, nil
   170  		}
   171  		return data, nil
   172  	}
   173  }
   174  
   175  func stringFromFileDecodeHook() mapstructure.DecodeHookFunc {
   176  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   177  		// "to" type should be string
   178  		if t != reflect.String {
   179  			return data, nil
   180  		}
   181  		// "from" type should be map
   182  		if f != reflect.Map {
   183  			return data, nil
   184  		}
   185  		v := reflect.ValueOf(data)
   186  		switch v.Kind() {
   187  		case reflect.String:
   188  			return data, nil
   189  		case reflect.Map:
   190  			d := data.(map[string]interface{})
   191  			fileName, ok := d["File"]
   192  			if !ok {
   193  				fileName, ok = d["file"]
   194  			}
   195  			switch {
   196  			case ok && fileName != nil:
   197  				bytes, err := ioutil.ReadFile(fileName.(string))
   198  				if err != nil {
   199  					return data, err
   200  				}
   201  				return string(bytes), nil
   202  			case ok:
   203  				// fileName was nil
   204  				return nil, fmt.Errorf("Value of File: was nil")
   205  			}
   206  		}
   207  		return data, nil
   208  	}
   209  }
   210  
   211  func pemBlocksFromFileDecodeHook() mapstructure.DecodeHookFunc {
   212  	return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   213  		// "to" type should be string
   214  		if t != reflect.Slice {
   215  			return data, nil
   216  		}
   217  		// "from" type should be map
   218  		if f != reflect.Map {
   219  			return data, nil
   220  		}
   221  		v := reflect.ValueOf(data)
   222  		switch v.Kind() {
   223  		case reflect.String:
   224  			return data, nil
   225  		case reflect.Map:
   226  			var fileName string
   227  			var ok bool
   228  			switch d := data.(type) {
   229  			case map[string]string:
   230  				fileName, ok = d["File"]
   231  				if !ok {
   232  					fileName, ok = d["file"]
   233  				}
   234  			case map[string]interface{}:
   235  				var fileI interface{}
   236  				fileI, ok = d["File"]
   237  				if !ok {
   238  					fileI = d["file"]
   239  				}
   240  				fileName, ok = fileI.(string)
   241  			}
   242  
   243  			switch {
   244  			case ok && fileName != "":
   245  				var result []string
   246  				bytes, err := ioutil.ReadFile(fileName)
   247  				if err != nil {
   248  					return data, err
   249  				}
   250  				for len(bytes) > 0 {
   251  					var block *pem.Block
   252  					block, bytes = pem.Decode(bytes)
   253  					if block == nil {
   254  						break
   255  					}
   256  					if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   257  						continue
   258  					}
   259  					result = append(result, string(pem.EncodeToMemory(block)))
   260  				}
   261  				return result, nil
   262  			case ok:
   263  				// fileName was nil
   264  				return nil, fmt.Errorf("Value of File: was nil")
   265  			}
   266  		}
   267  		return data, nil
   268  	}
   269  }
   270  
   271  var kafkaVersionConstraints map[sarama.KafkaVersion]version.Constraints
   272  
   273  func init() {
   274  	kafkaVersionConstraints = make(map[sarama.KafkaVersion]version.Constraints)
   275  	kafkaVersionConstraints[sarama.V0_8_2_0], _ = version.NewConstraint(">=0.8.2,<0.8.2.1")
   276  	kafkaVersionConstraints[sarama.V0_8_2_1], _ = version.NewConstraint(">=0.8.2.1,<0.8.2.2")
   277  	kafkaVersionConstraints[sarama.V0_8_2_2], _ = version.NewConstraint(">=0.8.2.2,<0.9.0.0")
   278  	kafkaVersionConstraints[sarama.V0_9_0_0], _ = version.NewConstraint(">=0.9.0.0,<0.9.0.1")
   279  	kafkaVersionConstraints[sarama.V0_9_0_1], _ = version.NewConstraint(">=0.9.0.1,<0.10.0.0")
   280  	kafkaVersionConstraints[sarama.V0_10_0_0], _ = version.NewConstraint(">=0.10.0.0,<0.10.0.1")
   281  	kafkaVersionConstraints[sarama.V0_10_0_1], _ = version.NewConstraint(">=0.10.0.1,<0.10.1.0")
   282  	kafkaVersionConstraints[sarama.V0_10_1_0], _ = version.NewConstraint(">=0.10.1.0,<0.10.2.0")
   283  	kafkaVersionConstraints[sarama.V0_10_2_0], _ = version.NewConstraint(">=0.10.2.0,<0.11.0.0")
   284  	kafkaVersionConstraints[sarama.V0_11_0_0], _ = version.NewConstraint(">=0.11.0.0,<1.0.0")
   285  	kafkaVersionConstraints[sarama.V1_0_0_0], _ = version.NewConstraint(">=1.0.0")
   286  }
   287  
   288  func kafkaVersionDecodeHook() mapstructure.DecodeHookFunc {
   289  	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   290  		if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) {
   291  			return data, nil
   292  		}
   293  
   294  		v, err := version.NewVersion(data.(string))
   295  		if err != nil {
   296  			return nil, fmt.Errorf("Unable to parse Kafka version: %s", err)
   297  		}
   298  
   299  		for kafkaVersion, constraints := range kafkaVersionConstraints {
   300  			if constraints.Check(v) {
   301  				return kafkaVersion, nil
   302  			}
   303  		}
   304  
   305  		return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data)
   306  	}
   307  }
   308  
   309  func bccspHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   310  	if t != reflect.TypeOf(&factory.FactoryOpts{}) {
   311  		return data, nil
   312  	}
   313  
   314  	config := factory.GetDefaultOpts()
   315  
   316  	err := mapstructure.Decode(data, config)
   317  	if err != nil {
   318  		return nil, errors.Wrap(err, "could not decode bcssp type")
   319  	}
   320  
   321  	return config, nil
   322  }
   323  
   324  // EnhancedExactUnmarshal is intended to unmarshal a config file into a structure
   325  // producing error when extraneous variables are introduced and supporting
   326  // the time.Duration type
   327  func EnhancedExactUnmarshal(v *viper.Viper, output interface{}) error {
   328  	oType := reflect.TypeOf(output)
   329  	if oType.Kind() != reflect.Ptr {
   330  		return errors.Errorf("supplied output argument must be a pointer to a struct but is not pointer")
   331  	}
   332  	eType := oType.Elem()
   333  	if eType.Kind() != reflect.Struct {
   334  		return errors.Errorf("supplied output argument must be a pointer to a struct, but it is pointer to something else")
   335  	}
   336  
   337  	baseKeys := v.AllSettings()
   338  
   339  	getterWithClass := func(key string) interface{} { return v.Get(key) } // hide receiver
   340  	leafKeys := getKeysRecursively("", getterWithClass, baseKeys, eType)
   341  
   342  	logger.Debugf("%+v", leafKeys)
   343  	config := &mapstructure.DecoderConfig{
   344  		ErrorUnused:      true,
   345  		Metadata:         nil,
   346  		Result:           output,
   347  		WeaklyTypedInput: true,
   348  		DecodeHook: mapstructure.ComposeDecodeHookFunc(
   349  			bccspHook,
   350  			customDecodeHook(),
   351  			byteSizeDecodeHook(),
   352  			stringFromFileDecodeHook(),
   353  			pemBlocksFromFileDecodeHook(),
   354  			kafkaVersionDecodeHook(),
   355  		),
   356  	}
   357  
   358  	decoder, err := mapstructure.NewDecoder(config)
   359  	if err != nil {
   360  		return err
   361  	}
   362  	return decoder.Decode(leafKeys)
   363  }