github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/viperutil/config_util.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package viperutil
     8  
     9  import (
    10  	"encoding/pem"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"math"
    15  	"os"
    16  	"path/filepath"
    17  	"reflect"
    18  	"regexp"
    19  	"strconv"
    20  	"strings"
    21  
    22  	"github.com/Shopify/sarama"
    23  	version "github.com/hashicorp/go-version"
    24  	"github.com/hechain20/hechain/bccsp/factory"
    25  	"github.com/hechain20/hechain/common/flogging"
    26  	"github.com/mitchellh/mapstructure"
    27  	"github.com/pkg/errors"
    28  	"gopkg.in/yaml.v2"
    29  )
    30  
// logger is the package-level logger, obtained from flogging under the name
// "viperutil".
var logger = flogging.MustGetLogger("viperutil")
    32  
    33  // ConfigPaths returns the paths from environment and
    34  // defaults which are CWD and /etc/hyperledger/fabric.
    35  func ConfigPaths() []string {
    36  	var paths []string
    37  	if p := os.Getenv("FABRIC_CFG_PATH"); p != "" {
    38  		paths = append(paths, p)
    39  	}
    40  	return append(paths, ".", "/etc/hyperledger/fabric")
    41  }
    42  
// ConfigParser locates a configuration file, parses it and keeps the parsed
// result so that it can later be unmarshalled into a struct.
// It records the directories to search and the config file name stem; the
// stem also acts as the prefix for environment variable overrides.
// Currently "yaml" is supported.
type ConfigParser struct {
	// configuration file to process
	configPaths []string // directories searched, in order, for the config file
	configName  string   // file name stem, without extension
	configFile  string   // path of the config file in use; "" until resolved

	// parsed config
	config map[string]interface{}
}
    56  
    57  // New creates a ConfigParser instance
    58  func New() *ConfigParser {
    59  	return &ConfigParser{
    60  		config: map[string]interface{}{},
    61  	}
    62  }
    63  
// AddConfigPaths appends cfgPaths to the list of directories that are
// searched for the config file. Multiple paths can be provided, and repeated
// calls accumulate; directories are searched in the order they were added.
func (c *ConfigParser) AddConfigPaths(cfgPaths ...string) {
	c.configPaths = append(c.configPaths, cfgPaths...)
}
    69  
// SetConfigName provides the configuration file name stem (the file name
// without its extension). The upper-cased version of this value also serves
// as the environment variable override prefix.
func (c *ConfigParser) SetConfigName(in string) {
	c.configName = in
}
    76  
// ConfigFileUsed returns the path of the config file in use. The value is
// empty until a config file has been resolved.
func (c *ConfigParser) ConfigFileUsed() string {
	return c.configFile
}
    81  
    82  // Search for the existence of filename for all supported extensions
    83  func (c *ConfigParser) searchInPath(in string) (filename string) {
    84  	var supportedExts []string = []string{"yaml", "yml"}
    85  	for _, ext := range supportedExts {
    86  		fullPath := filepath.Join(in, c.configName+"."+ext)
    87  		_, err := os.Stat(fullPath)
    88  		if err == nil {
    89  			return fullPath
    90  		}
    91  	}
    92  	return ""
    93  }
    94  
    95  // Search for the configName in all configPaths
    96  func (c *ConfigParser) findConfigFile() string {
    97  	paths := c.configPaths
    98  	if len(paths) == 0 {
    99  		paths = ConfigPaths()
   100  	}
   101  	for _, cp := range paths {
   102  		file := c.searchInPath(cp)
   103  		if file != "" {
   104  			return file
   105  		}
   106  	}
   107  	return ""
   108  }
   109  
   110  // Get the valid and present config file
   111  func (c *ConfigParser) getConfigFile() string {
   112  	// if explicitly set, then use it
   113  	if c.configFile != "" {
   114  		return c.configFile
   115  	}
   116  
   117  	c.configFile = c.findConfigFile()
   118  	return c.configFile
   119  }
   120  
   121  // ReadInConfig reads and unmarshals the config file.
   122  func (c *ConfigParser) ReadInConfig() error {
   123  	cf := c.getConfigFile()
   124  	logger.Debugf("Attempting to open the config file: %s", cf)
   125  	file, err := os.Open(cf)
   126  	if err != nil {
   127  		logger.Errorf("Unable to open the config file: %s", cf)
   128  		return err
   129  	}
   130  	defer file.Close()
   131  
   132  	return c.ReadConfig(file)
   133  }
   134  
// ReadConfig decodes the YAML read from in into the parser's existing config
// map and returns the decoder's error, if any.
func (c *ConfigParser) ReadConfig(in io.Reader) error {
	return yaml.NewDecoder(in).Decode(c.config)
}
   139  
   140  // Get value for the key by searching environment variables.
   141  func (c *ConfigParser) getFromEnv(key string) string {
   142  	envKey := key
   143  	if c.configName != "" {
   144  		envKey = c.configName + "_" + envKey
   145  	}
   146  	envKey = strings.ToUpper(envKey)
   147  	envKey = strings.ReplaceAll(envKey, ".", "_")
   148  	return os.Getenv(envKey)
   149  }
   150  
   151  // Prototype declaration for getFromEnv function.
   152  type envGetter func(key string) string
   153  
   154  func getKeysRecursively(base string, getenv envGetter, nodeKeys map[string]interface{}, oType reflect.Type) map[string]interface{} {
   155  	subTypes := map[string]reflect.Type{}
   156  
   157  	if oType != nil && oType.Kind() == reflect.Struct {
   158  	outer:
   159  		for i := 0; i < oType.NumField(); i++ {
   160  			fieldName := oType.Field(i).Name
   161  			fieldType := oType.Field(i).Type
   162  
   163  			for key := range nodeKeys {
   164  				if strings.EqualFold(fieldName, key) {
   165  					subTypes[key] = fieldType
   166  					continue outer
   167  				}
   168  			}
   169  
   170  			subTypes[fieldName] = fieldType
   171  			nodeKeys[fieldName] = nil
   172  		}
   173  	}
   174  
   175  	result := make(map[string]interface{})
   176  	for key, val := range nodeKeys {
   177  		fqKey := base + key
   178  
   179  		// overwrite val, if an environment is available
   180  		if override := getenv(fqKey); override != "" {
   181  			val = override
   182  		}
   183  
   184  		switch val := val.(type) {
   185  		case map[string]interface{}:
   186  			logger.Debugf("Found map[string]interface{} value for %s", fqKey)
   187  			result[key] = getKeysRecursively(fqKey+".", getenv, val, subTypes[key])
   188  
   189  		case map[interface{}]interface{}:
   190  			logger.Debugf("Found map[interface{}]interface{} value for %s", fqKey)
   191  			result[key] = getKeysRecursively(fqKey+".", getenv, toMapStringInterface(val), subTypes[key])
   192  
   193  		case nil:
   194  			if override := getenv(fqKey + ".File"); override != "" {
   195  				result[key] = map[string]interface{}{"File": override}
   196  			}
   197  
   198  		default:
   199  			result[key] = val
   200  		}
   201  	}
   202  	return result
   203  }
   204  
   205  func toMapStringInterface(m map[interface{}]interface{}) map[string]interface{} {
   206  	result := map[string]interface{}{}
   207  	for k, v := range m {
   208  		k, ok := k.(string)
   209  		if !ok {
   210  			panic(fmt.Sprintf("Non string %v, %v: key-entry: %v", k, v, k))
   211  		}
   212  		result[k] = v
   213  	}
   214  	return result
   215  }
   216  
   217  // customDecodeHook parses strings of the format "[thing1, thing2, thing3]"
   218  // into string slices. Note that whitespace around slice elements is removed.
   219  func customDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   220  	if f.Kind() != reflect.String {
   221  		return data, nil
   222  	}
   223  
   224  	raw := data.(string)
   225  	l := len(raw)
   226  	if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
   227  		slice := strings.Split(raw[1:l-1], ",")
   228  		for i, v := range slice {
   229  			slice[i] = strings.TrimSpace(v)
   230  		}
   231  		return slice, nil
   232  	}
   233  
   234  	return data, nil
   235  }
   236  
   237  func byteSizeDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   238  	if f != reflect.String || t != reflect.Uint32 {
   239  		return data, nil
   240  	}
   241  	raw := data.(string)
   242  	if raw == "" {
   243  		return data, nil
   244  	}
   245  	re := regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
   246  	if re.MatchString(raw) {
   247  		size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
   248  		if err != nil {
   249  			return data, nil
   250  		}
   251  		unit := re.ReplaceAllString(raw, "${unit}")
   252  		switch strings.ToLower(unit) {
   253  		case "g":
   254  			size = size << 10
   255  			fallthrough
   256  		case "m":
   257  			size = size << 10
   258  			fallthrough
   259  		case "k":
   260  			size = size << 10
   261  		}
   262  		if size > math.MaxUint32 {
   263  			return size, fmt.Errorf("value '%s' overflows uint32", raw)
   264  		}
   265  		return size, nil
   266  	}
   267  	return data, nil
   268  }
   269  
   270  func stringFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   271  	// "to" type should be string
   272  	if t != reflect.String {
   273  		return data, nil
   274  	}
   275  	// "from" type should be map
   276  	if f != reflect.Map {
   277  		return data, nil
   278  	}
   279  	v := reflect.ValueOf(data)
   280  	switch v.Kind() {
   281  	case reflect.String:
   282  		return data, nil
   283  	case reflect.Map:
   284  		d := data.(map[string]interface{})
   285  		fileName, ok := d["File"]
   286  		if !ok {
   287  			fileName, ok = d["file"]
   288  		}
   289  		switch {
   290  		case ok && fileName != nil:
   291  			bytes, err := ioutil.ReadFile(fileName.(string))
   292  			if err != nil {
   293  				return data, err
   294  			}
   295  			return string(bytes), nil
   296  		case ok:
   297  			// fileName was nil
   298  			return nil, fmt.Errorf("Value of File: was nil")
   299  		}
   300  	}
   301  	return data, nil
   302  }
   303  
   304  func pemBlocksFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
   305  	// "to" type should be string
   306  	if t != reflect.Slice {
   307  		return data, nil
   308  	}
   309  	// "from" type should be map
   310  	if f != reflect.Map {
   311  		return data, nil
   312  	}
   313  	v := reflect.ValueOf(data)
   314  	switch v.Kind() {
   315  	case reflect.String:
   316  		return data, nil
   317  	case reflect.Map:
   318  		var fileName string
   319  		var ok bool
   320  		switch d := data.(type) {
   321  		case map[string]string:
   322  			fileName, ok = d["File"]
   323  			if !ok {
   324  				fileName, ok = d["file"]
   325  			}
   326  		case map[string]interface{}:
   327  			var fileI interface{}
   328  			fileI, ok = d["File"]
   329  			if !ok {
   330  				fileI = d["file"]
   331  			}
   332  			fileName, ok = fileI.(string)
   333  		}
   334  
   335  		switch {
   336  		case ok && fileName != "":
   337  			var result []string
   338  			bytes, err := ioutil.ReadFile(fileName)
   339  			if err != nil {
   340  				return data, err
   341  			}
   342  			for len(bytes) > 0 {
   343  				var block *pem.Block
   344  				block, bytes = pem.Decode(bytes)
   345  				if block == nil {
   346  					break
   347  				}
   348  				if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   349  					continue
   350  				}
   351  				result = append(result, string(pem.EncodeToMemory(block)))
   352  			}
   353  			return result, nil
   354  		case ok:
   355  			// fileName was nil
   356  			return nil, fmt.Errorf("Value of File: was nil")
   357  		}
   358  	}
   359  	return data, nil
   360  }
   361  
   362  var kafkaVersionConstraints map[sarama.KafkaVersion]version.Constraints
   363  
   364  func init() {
   365  	kafkaVersionConstraints = make(map[sarama.KafkaVersion]version.Constraints)
   366  	kafkaVersionConstraints[sarama.V0_8_2_0], _ = version.NewConstraint(">=0.8.2,<0.8.2.1")
   367  	kafkaVersionConstraints[sarama.V0_8_2_1], _ = version.NewConstraint(">=0.8.2.1,<0.8.2.2")
   368  	kafkaVersionConstraints[sarama.V0_8_2_2], _ = version.NewConstraint(">=0.8.2.2,<0.9.0.0")
   369  	kafkaVersionConstraints[sarama.V0_9_0_0], _ = version.NewConstraint(">=0.9.0.0,<0.9.0.1")
   370  	kafkaVersionConstraints[sarama.V0_9_0_1], _ = version.NewConstraint(">=0.9.0.1,<0.10.0.0")
   371  	kafkaVersionConstraints[sarama.V0_10_0_0], _ = version.NewConstraint(">=0.10.0.0,<0.10.0.1")
   372  	kafkaVersionConstraints[sarama.V0_10_0_1], _ = version.NewConstraint(">=0.10.0.1,<0.10.1.0")
   373  	kafkaVersionConstraints[sarama.V0_10_1_0], _ = version.NewConstraint(">=0.10.1.0,<0.10.2.0")
   374  	kafkaVersionConstraints[sarama.V0_10_2_0], _ = version.NewConstraint(">=0.10.2.0,<0.11.0.0")
   375  	kafkaVersionConstraints[sarama.V0_11_0_0], _ = version.NewConstraint(">=0.11.0.0,<1.0.0")
   376  	kafkaVersionConstraints[sarama.V1_0_0_0], _ = version.NewConstraint(">=1.0.0")
   377  }
   378  
   379  func kafkaVersionDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   380  	if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) {
   381  		return data, nil
   382  	}
   383  
   384  	v, err := version.NewVersion(data.(string))
   385  	if err != nil {
   386  		return nil, fmt.Errorf("Unable to parse Kafka version: %s", err)
   387  	}
   388  
   389  	for kafkaVersion, constraints := range kafkaVersionConstraints {
   390  		if constraints.Check(v) {
   391  			return kafkaVersion, nil
   392  		}
   393  	}
   394  
   395  	return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data)
   396  }
   397  
   398  func bccspHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
   399  	if t != reflect.TypeOf(&factory.FactoryOpts{}) {
   400  		return data, nil
   401  	}
   402  
   403  	config := factory.GetDefaultOpts()
   404  
   405  	err := mapstructure.WeakDecode(data, config)
   406  	if err != nil {
   407  		return nil, errors.Wrap(err, "could not decode bccsp type")
   408  	}
   409  
   410  	return config, nil
   411  }
   412  
   413  // EnhancedExactUnmarshal is intended to unmarshal a config file into a structure
   414  // producing error when extraneous variables are introduced and supporting
   415  // the time.Duration type
   416  func (c *ConfigParser) EnhancedExactUnmarshal(output interface{}) error {
   417  	oType := reflect.TypeOf(output)
   418  	if oType.Kind() != reflect.Ptr {
   419  		return errors.Errorf("supplied output argument must be a pointer to a struct but is not pointer")
   420  	}
   421  	eType := oType.Elem()
   422  	if eType.Kind() != reflect.Struct {
   423  		return errors.Errorf("supplied output argument must be a pointer to a struct, but it is pointer to something else")
   424  	}
   425  
   426  	baseKeys := c.config
   427  	leafKeys := getKeysRecursively("", c.getFromEnv, baseKeys, eType)
   428  
   429  	logger.Debugf("%+v", leafKeys)
   430  	config := &mapstructure.DecoderConfig{
   431  		ErrorUnused:      true,
   432  		Metadata:         nil,
   433  		Result:           output,
   434  		WeaklyTypedInput: true,
   435  		DecodeHook: mapstructure.ComposeDecodeHookFunc(
   436  			bccspHook,
   437  			mapstructure.StringToTimeDurationHookFunc(),
   438  			customDecodeHook,
   439  			byteSizeDecodeHook,
   440  			stringFromFileDecodeHook,
   441  			pemBlocksFromFileDecodeHook,
   442  			kafkaVersionDecodeHook,
   443  		),
   444  	}
   445  
   446  	decoder, err := mapstructure.NewDecoder(config)
   447  	if err != nil {
   448  		return err
   449  	}
   450  	return decoder.Decode(leafKeys)
   451  }
   452  
   453  // YamlStringToStructHook is a hook for viper(viper.Unmarshal(*,*, here)), it is able to parse a string of minified yaml into a slice of structs
   454  func YamlStringToStructHook(m interface{}) func(rf reflect.Kind, rt reflect.Kind, data interface{}) (interface{}, error) {
   455  	return func(rf reflect.Kind, rt reflect.Kind, data interface{}) (interface{}, error) {
   456  		if rf != reflect.String || rt != reflect.Slice {
   457  			return data, nil
   458  		}
   459  
   460  		raw := data.(string)
   461  		if raw == "" {
   462  			return m, nil
   463  		}
   464  
   465  		return m, yaml.UnmarshalStrict([]byte(raw), &m)
   466  	}
   467  }