github.com/jxskiss/gopkg@v0.17.3/confr/loader.go (about)

// Package confr provides a simple yet powerful configuration loader.
     2  //
     3  // Features:
     4  //
     5  // 1. Load from command line flags defined by field tag `flag`;
     6  //
// 2. Load by custom loader function for fields which have a `custom` tag;
// this is useful when configuration values are stored in a central
// repository or a remote config center;
    10  //
// 3. Load from environment variables named by an explicit `env` tag or,
// when enabled, by implicitly auto-generated names;
    13  //
// 4. Load from multiple configuration files with priority and overriding;
    15  //
// 5. Set default values by field tag `default` if a configuration field
// is not given by any of the higher priority sources;
    18  //
// 6. Minimal dependencies;
    20  //
    21  // You may check Config and Loader for more details.
    22  package confr
    23  
    24  import (
    25  	"bytes"
    26  	"encoding/json"
    27  	"errors"
    28  	"flag"
    29  	"fmt"
    30  	"io/ioutil"
    31  	"log"
    32  	"os"
    33  	"path"
    34  	"reflect"
    35  	"strings"
    36  	"unicode"
    37  
    38  	"github.com/BurntSushi/toml"
    39  	"github.com/spf13/cast"
    40  	"gopkg.in/yaml.v2"
    41  )
    42  
    43  const DefaultEnvPrefix = "Confr"
    44  
    45  const (
    46  	ConfrTag        = "confr"
    47  	CustomTag       = "custom"
    48  	DefaultValueTag = "default"
    49  	EnvTag          = "env"
    50  	FlagTag         = "flag"
    51  )
    52  
    53  // Config provides options to configure the behavior of Loader.
    54  type Config struct {
    55  
    56  	// Verbose tells the loader to output verbose logging messages.
    57  	Verbose bool
    58  
    59  	// DisallowUnknownFields causes the loader to return an error when
    60  	// the configuration files contain object keys which do not match
    61  	// the given destination struct.
    62  	DisallowUnknownFields bool
    63  
    64  	// EnableImplicitEnv enables the loader checking auto-generated names
    65  	// to find environment variables.
    66  	// The default is false, which means the loader will only check `env`
    67  	// tag, won't check auto-generated names.
    68  	EnableImplicitEnv bool
    69  
    70  	// EnvPrefix is used to prefix the auto-generated names to find
    71  	// environment variables. The default value is "Confr".
    72  	EnvPrefix string
    73  
    74  	// CustomLoader optionally loads fields which have a `custom` tag,
    75  	// the field's type and the tag value will be passed to the custom loader.
    76  	CustomLoader func(typ reflect.Type, tag string) (interface{}, error)
    77  
    78  	// FlagSet optionally specifies a flag set to lookup flag value
    79  	// for fields which have a `flag` tag. The tag value should be the
    80  	// flag name to lookup for.
    81  	FlagSet *flag.FlagSet
    82  }
    83  
    84  // Loader is used to load configuration from files (JSON/TOML/YAML),
    85  // environment variables, command line flags, or by custom loader function.
    86  //
    87  // The priority in descending order is:
    88  //
    89  // 1. command line flag defined by field tag `flag`;
    90  //
    91  // 2. custom loader function defined by field tag `custom`;
    92  //
    93  // 3. environment variables;
    94  //
    95  // 4. config files, if multiple files are given to Load, files appeared
    96  // first takes higher priority, if a config field appears in more
    97  // than one files, only the first has effect.
    98  //
    99  // 5. default values defined by field tag `default`;
   100  type Loader struct {
   101  	*Config
   102  }
   103  
   104  // New creates a new Loader.
   105  func New(config *Config) *Loader {
   106  	if config == nil {
   107  		config = &Config{}
   108  	}
   109  
   110  	return &Loader{Config: config}
   111  }
   112  
   113  // Load creates a Loader with nil config and loads configuration to dst,
   114  // it is a shortcut for New(nil).Load(dst, files...).
   115  func Load(dst interface{}, files ...string) error {
   116  	return New(nil).Load(dst, files...)
   117  }
   118  
   119  // Load loads configuration to dst using using the Loader's Config
   120  // and the given configuration files.
   121  //
   122  // See Loader and Config for detailed document.
   123  func (p *Loader) Load(dst interface{}, files ...string) error {
   124  	return p.load(dst, files...)
   125  }
   126  
   127  func (p *Loader) load(dst interface{}, files ...string) error {
   128  	dstTyp := reflect.TypeOf(dst)
   129  	if dstTyp.Kind() != reflect.Ptr || dstTyp.Elem().Kind() != reflect.Struct {
   130  		return errors.New("invalid destination, should be a struct pointer")
   131  	}
   132  
   133  	if err := p.loadFiles(dst, files...); err != nil {
   134  		return err
   135  	}
   136  	if err := p.processEnv(dst, ""); err != nil {
   137  		return err
   138  	}
   139  	if err := p.processCustom(dst); err != nil {
   140  		return err
   141  	}
   142  	if err := p.processDefaults(dst); err != nil {
   143  		return err
   144  	}
   145  	if err := p.processFlags(dst); err != nil {
   146  		return err
   147  	}
   148  	return nil
   149  }
   150  
   151  func (p *Loader) loadFiles(config interface{}, files ...string) error {
   152  	for i := len(files) - 1; i >= 0; i-- {
   153  		file := files[i]
   154  		if p.Verbose {
   155  			log.Printf("loading configuration from file %s", file)
   156  		}
   157  
   158  		err := p.processFile(config, file)
   159  		if err != nil {
   160  			return err
   161  		}
   162  	}
   163  	return nil
   164  }
   165  
   166  func (p *Loader) processFile(config interface{}, file string) error {
   167  	if info, err := os.Stat(file); err != nil || !info.Mode().IsRegular() {
   168  		return fmt.Errorf("invalid configuration file: %s", file)
   169  	}
   170  
   171  	var unmarshalFunc func(data []byte, v interface{}, disallowUnknownFields bool) error
   172  	extname := path.Ext(file)
   173  	switch strings.ToLower(extname) {
   174  	case ".json":
   175  		unmarshalFunc = unmarshalJSON
   176  	case ".yaml", ".yml":
   177  		unmarshalFunc = unmarshalYAML
   178  	case ".toml":
   179  		unmarshalFunc = unmarshalTOML
   180  	default:
   181  		return fmt.Errorf("unsupported file type: %v", extname)
   182  	}
   183  	data, err := ioutil.ReadFile(file)
   184  	if err != nil {
   185  		return fmt.Errorf("cannot read file %s: %w", file, err)
   186  	}
   187  	err = unmarshalFunc(data, config, p.DisallowUnknownFields)
   188  	if err != nil {
   189  		return fmt.Errorf("cannot unmarshal file %s: %v", file, err)
   190  	}
   191  	return nil
   192  }
   193  
   194  func unmarshalJSON(data []byte, v interface{}, disallowUnknownFields bool) error {
   195  	if disallowUnknownFields {
   196  		dec := json.NewDecoder(bytes.NewReader(data))
   197  		dec.DisallowUnknownFields()
   198  		return dec.Decode(v)
   199  	}
   200  	return json.Unmarshal(data, v)
   201  }
   202  
   203  func unmarshalYAML(data []byte, v interface{}, disallowUnknownFields bool) error {
   204  	if disallowUnknownFields {
   205  		return yaml.UnmarshalStrict(data, v)
   206  	}
   207  	return yaml.Unmarshal(data, v)
   208  }
   209  
   210  func unmarshalTOML(data []byte, v interface{}, disallowUnknownFields bool) error {
   211  	meta, err := toml.Decode(string(data), v)
   212  	if err == nil && len(meta.Undecoded()) > 0 && disallowUnknownFields {
   213  		return fmt.Errorf("toml: unknown fields %v", meta.Undecoded())
   214  	}
   215  	return err
   216  }
   217  
   218  func (p *Loader) processDefaults(config interface{}) error {
   219  	configVal := reflect.Indirect(reflect.ValueOf(config))
   220  	configTyp := configVal.Type()
   221  
   222  	for i := 0; i < configTyp.NumField(); i++ {
   223  		field := configTyp.Field(i)
   224  		fieldVal := configVal.Field(i)
   225  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   226  			continue
   227  		}
   228  		if field.Tag.Get(ConfrTag) == "-" {
   229  			continue
   230  		}
   231  
   232  		defaultValue := field.Tag.Get(DefaultValueTag)
   233  		if defaultValue != "" {
   234  			if p.Verbose {
   235  				log.Printf("processing default value for field %s.%s", configTyp.Name(), field.Name)
   236  			}
   237  
   238  			isBlank := reflect.DeepEqual(fieldVal.Interface(), reflect.Zero(field.Type).Interface())
   239  			if isBlank {
   240  				err := assignFieldValue(fieldVal, defaultValue)
   241  				if err != nil {
   242  					return fmt.Errorf("cannot assign default value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   243  				}
   244  			}
   245  		}
   246  
   247  		fieldVal = reflect.Indirect(fieldVal)
   248  		switch fieldVal.Kind() {
   249  		case reflect.Struct:
   250  			if err := p.processDefaults(fieldVal.Addr().Interface()); err != nil {
   251  				return err
   252  			}
   253  		case reflect.Slice:
   254  			for i := 0; i < fieldVal.Len(); i++ {
   255  				elemVal := reflect.Indirect(fieldVal.Index(i))
   256  				if elemVal.Kind() == reflect.Struct {
   257  					if err := p.processDefaults(elemVal.Addr().Interface()); err != nil {
   258  						return err
   259  					}
   260  				}
   261  			}
   262  		}
   263  	}
   264  	return nil
   265  }
   266  
   267  func (p *Loader) processFlags(config interface{}) error {
   268  	if p.FlagSet == nil {
   269  		return nil
   270  	}
   271  	if !p.FlagSet.Parsed() {
   272  		return errors.New("flag set is not parsed")
   273  	}
   274  
   275  	fs := p.FlagSet
   276  	configVal := reflect.Indirect(reflect.ValueOf(config))
   277  	configTyp := configVal.Type()
   278  
   279  	for i := 0; i < configTyp.NumField(); i++ {
   280  		field := configTyp.Field(i)
   281  		fieldVal := configVal.Field(i)
   282  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   283  			continue
   284  		}
   285  		if field.Tag.Get(ConfrTag) == "-" {
   286  			continue
   287  		}
   288  
   289  		flagName := field.Tag.Get(FlagTag)
   290  		if flagName != "" && flagName != "-" {
   291  			if p.Verbose {
   292  				log.Printf("processing flag for field %s.%s", configTyp.Name(), field.Name)
   293  			}
   294  
   295  			if flagVal, isSet := lookupFlag(fs, flagName); flagVal != nil {
   296  				err := assignFlagValue(fieldVal, flagVal, isSet)
   297  				if err != nil {
   298  					return fmt.Errorf("cannot assign flag value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   299  				}
   300  			}
   301  		}
   302  
   303  		fieldVal = reflect.Indirect(fieldVal)
   304  		switch fieldVal.Kind() {
   305  		case reflect.Struct:
   306  			if err := p.processFlags(fieldVal.Addr().Interface()); err != nil {
   307  				return err
   308  			}
   309  		}
   310  	}
   311  	return nil
   312  }
   313  
   314  // lookupFlag returns a flag and tells whether the flag is set.
   315  func lookupFlag(fs *flag.FlagSet, name string) (out *flag.Flag, isSet bool) {
   316  	fs.Visit(func(f *flag.Flag) {
   317  		if f.Name == name {
   318  			out = f
   319  			isSet = true
   320  		}
   321  	})
   322  	if out == nil {
   323  		out = fs.Lookup(name)
   324  	}
   325  	return
   326  }
   327  
   328  func (p *Loader) processEnv(config interface{}, prefix string) error {
   329  	configVal := reflect.Indirect(reflect.ValueOf(config))
   330  	configTyp := configVal.Type()
   331  
   332  	for i := 0; i < configTyp.NumField(); i++ {
   333  		field := configTyp.Field(i)
   334  		fieldVal := configVal.Field(i)
   335  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   336  			continue
   337  		}
   338  		if field.Tag.Get(ConfrTag) == "-" {
   339  			continue
   340  		}
   341  
   342  		var envNames []string
   343  		envTag := field.Tag.Get(EnvTag)
   344  		if envTag != "" {
   345  			envNames = append(envNames, envTag)
   346  		} else if p.EnableImplicitEnv {
   347  			tmp := p.getEnvName(prefix, field.Name)
   348  			envNames = append(envNames, tmp, strings.ToUpper(tmp))
   349  		}
   350  		if len(envNames) > 0 {
   351  			if p.Verbose {
   352  				log.Printf("loading env for field %s.%s from %v", configTyp.Name(), field.Name, envNames)
   353  			}
   354  
   355  			for _, envName := range envNames {
   356  				if value := os.Getenv(envName); value != "" {
   357  					err := assignFieldValue(fieldVal, value)
   358  					if err != nil {
   359  						return fmt.Errorf("cannot assign env value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   360  					}
   361  					break
   362  				}
   363  			}
   364  		}
   365  
   366  		fieldVal = reflect.Indirect(fieldVal)
   367  		switch fieldVal.Kind() {
   368  		case reflect.Struct:
   369  			fieldPrefix := p.getEnvName(prefix, field.Name)
   370  			if err := p.processEnv(fieldVal.Addr().Interface(), fieldPrefix); err != nil {
   371  				return err
   372  			}
   373  		}
   374  	}
   375  	return nil
   376  }
   377  
   378  func (p *Loader) getEnvName(prefix string, name string) string {
   379  	var envName []byte
   380  	for i := 0; i < len(name); i++ {
   381  		if i > 0 && unicode.IsUpper(rune(name[i])) &&
   382  			name[i-1] != '_' &&
   383  			((i+1 < len(name) && unicode.IsLower(rune(name[i+1]))) || unicode.IsLower(rune(name[i-1]))) {
   384  			envName = append(envName, '_')
   385  		}
   386  		envName = append(envName, name[i])
   387  	}
   388  	if prefix == "" {
   389  		prefix = p.EnvPrefix
   390  		if prefix == "" {
   391  			prefix = DefaultEnvPrefix
   392  		}
   393  	}
   394  	return prefix + "_" + string(envName)
   395  }
   396  
   397  func (p *Loader) processCustom(config interface{}) error {
   398  	if p.CustomLoader == nil {
   399  		return nil
   400  	}
   401  	configVal := reflect.Indirect(reflect.ValueOf(config))
   402  	configTyp := configVal.Type()
   403  
   404  	for i := 0; i < configTyp.NumField(); i++ {
   405  		field := configTyp.Field(i)
   406  		fieldVal := configVal.Field(i)
   407  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   408  			continue
   409  		}
   410  		if field.Tag.Get(ConfrTag) == "-" {
   411  			continue
   412  		}
   413  
   414  		customTag := field.Tag.Get(CustomTag)
   415  		if customTag != "" && customTag != "-" {
   416  			if p.Verbose {
   417  				log.Printf("processing custom loader for field %s.%s", configTyp.Name(), field.Name)
   418  			}
   419  
   420  			tmp, err := p.CustomLoader(fieldVal.Type(), customTag)
   421  			if err != nil {
   422  				return err
   423  			}
   424  			if err = assignFieldValue(fieldVal, tmp); err != nil {
   425  				return fmt.Errorf("cannot assign custom value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   426  			}
   427  		}
   428  
   429  		fieldVal = reflect.Indirect(fieldVal)
   430  		switch fieldVal.Kind() {
   431  		case reflect.Struct:
   432  			if err := p.processCustom(fieldVal.Addr().Interface()); err != nil {
   433  				return err
   434  			}
   435  		}
   436  	}
   437  	return nil
   438  }
   439  
   440  func assignFlagValue(dst reflect.Value, ff *flag.Flag, isSet bool) error {
   441  	if isSet {
   442  		if getter, ok := ff.Value.(flag.Getter); ok {
   443  			return assignFieldValue(dst, getter.Get())
   444  		}
   445  		return assignFieldValue(dst, ff.Value.String())
   446  	}
   447  
   448  	// default value
   449  	if dst.IsZero() && ff.DefValue != "" {
   450  		return assignFieldValue(dst, ff.DefValue)
   451  	}
   452  	return nil
   453  }
   454  
   455  func assignFieldValue(dst reflect.Value, value interface{}) error {
   456  	inputVal := reflect.ValueOf(value)
   457  	if dst.Type() == inputVal.Type() {
   458  		dst.Set(inputVal)
   459  		return nil
   460  	}
   461  
   462  	var ptrDest reflect.Value
   463  	if dst.Kind() == reflect.Ptr {
   464  		ptrDest = dst
   465  		dst = reflect.New(dst.Type().Elem()).Elem()
   466  		if dst.Type() == inputVal.Type() {
   467  			dst.Set(inputVal)
   468  			ptrDest.Set(dst.Addr())
   469  			return nil
   470  		}
   471  	}
   472  
   473  	var err error
   474  	var val interface{}
   475  	switch dst.Interface().(type) {
   476  	case bool:
   477  		val, err = toBooleanE(value)
   478  	case int:
   479  		val, err = cast.ToIntE(value)
   480  	case []int:
   481  		val, err = cast.ToIntSliceE(value)
   482  	case int64:
   483  		val, err = cast.ToInt64E(value)
   484  	case []int64:
   485  		val, err = toInt64SliceE(value)
   486  	case int32:
   487  		val, err = cast.ToInt32E(value)
   488  	case []int32:
   489  		val, err = toInt32SliceE(value)
   490  	case float64:
   491  		val, err = cast.ToFloat64E(value)
   492  	case float32:
   493  		val, err = cast.ToFloat32E(value)
   494  	case string:
   495  		val, err = cast.ToStringE(value)
   496  	case []string:
   497  		val, err = cast.ToStringSliceE(value)
   498  	case map[string]bool:
   499  		val, err = cast.ToStringMapBoolE(value)
   500  	case map[string]int:
   501  		val, err = cast.ToStringMapIntE(value)
   502  	case map[string]int64:
   503  		val, err = cast.ToStringMapInt64E(value)
   504  	case map[string]string:
   505  		val, err = cast.ToStringMapStringE(value)
   506  	case map[string][]string:
   507  		val, err = cast.ToStringMapStringSliceE(value)
   508  	case map[string]interface{}:
   509  		val, err = cast.ToStringMapE(value)
   510  	default:
   511  		err = errors.New("unsupported type")
   512  	}
   513  	if err != nil {
   514  		return err
   515  	}
   516  
   517  	dst.Set(reflect.ValueOf(val))
   518  	if ptrDest.IsValid() {
   519  		ptrDest.Set(dst.Addr())
   520  	}
   521  	return nil
   522  }
   523  
   524  func toBooleanE(v interface{}) (bool, error) {
   525  	if strval, ok := v.(string); ok {
   526  		switch strval {
   527  		case "", "0", "f", "false", "no", "off":
   528  			return false, nil
   529  		case "1", "t", "true", "yes", "on":
   530  			return true, nil
   531  		default:
   532  			return false, fmt.Errorf("invalid boolean value: %s", strval)
   533  		}
   534  	}
   535  	return cast.ToBoolE(v)
   536  }
   537  
   538  func toInt64SliceE(v interface{}) ([]int64, error) {
   539  	intValues, err := cast.ToIntSliceE(v)
   540  	if err != nil {
   541  		return nil, err
   542  	}
   543  	out := make([]int64, len(intValues))
   544  	for i, x := range intValues {
   545  		out[i] = int64(x)
   546  	}
   547  	return out, nil
   548  }
   549  
   550  func toInt32SliceE(v interface{}) ([]int32, error) {
   551  	intValues, err := cast.ToIntSliceE(v)
   552  	if err != nil {
   553  		return nil, err
   554  	}
   555  	out := make([]int32, len(intValues))
   556  	for i, x := range intValues {
   557  		out[i] = int32(x)
   558  	}
   559  	return out, nil
   560  }