github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/confr/loader.go (about)

// Package confr provides a simple yet powerful configuration loader.
     2  //
     3  // Features:
     4  //
     5  // 1. Load from command line flags defined by field tag `flag`;
     6  //
     7  // 2. Load by custom loader function for fields which have a `custom` tag,
     8  // this is useful where you may have configuration values stored in a
     9  // centric repository or a remote config center;
    10  //
    11  // 3. Load from environment variables by explicitly defined `env` tag or
    12  // auto-generated names implicitly;
    13  //
// 4. Load from multiple configuration files with priority and overriding;
    15  //
// 5. Set default values by field tag `default` if a configuration field
// is not given by any of the higher priority sources;
    18  //
// 6. Minimal dependencies;
    20  //
    21  // You may check Config and Loader for more details.
    22  package confr
    23  
    24  import (
    25  	"bytes"
    26  	"encoding/json"
    27  	"errors"
    28  	"flag"
    29  	"fmt"
    30  	"log"
    31  	"os"
    32  	"path"
    33  	"reflect"
    34  	"strings"
    35  	"time"
    36  	"unicode"
    37  
    38  	"github.com/BurntSushi/toml"
    39  	"github.com/spf13/cast"
    40  	"gopkg.in/yaml.v3"
    41  )
    42  
    43  const DefaultEnvPrefix = "Confr"
    44  
    45  const (
    46  	ConfrTag        = "confr"
    47  	CustomTag       = "custom"
    48  	DefaultValueTag = "default"
    49  	EnvTag          = "env"
    50  	FlagTag         = "flag"
    51  )
    52  
    53  // Config provides options to configure the behavior of Loader.
    54  type Config struct {
    55  
    56  	// LogFunc specifies a log function to use instead of [log.Printf].
    57  	LogFunc func(format string, v ...any)
    58  
    59  	// Verbose tells the loader to output verbose logging messages.
    60  	Verbose bool
    61  
    62  	// DisallowUnknownFields causes the loader to return an error when
    63  	// the configuration files contain object keys which do not match
    64  	// the given destination struct.
    65  	DisallowUnknownFields bool
    66  
    67  	// EnableImplicitEnv enables the loader checking auto-generated names
    68  	// to find environment variables.
    69  	// The default is false, which means the loader will only check `env`
    70  	// tag, won't check auto-generated names.
    71  	EnableImplicitEnv bool
    72  
    73  	// EnvPrefix is used to prefix the auto-generated names to find
    74  	// environment variables. The default value is "Confr".
    75  	EnvPrefix string
    76  
    77  	// CustomLoader optionally loads fields which have a `custom` tag,
    78  	// the field's type and the tag value will be passed to the custom loader.
    79  	CustomLoader func(typ reflect.Type, tag string) (any, error)
    80  
    81  	// UnmarshalFunc optionally specifies an unmarshal function
    82  	// to use instead of the default function.
    83  	UnmarshalFunc func(data []byte, v any, disallowUnknownFields bool) error
    84  
    85  	// FlagSet optionally specifies a flag set to lookup flag value
    86  	// for fields which have a `flag` tag. The tag value should be the
    87  	// flag name to lookup for.
    88  	FlagSet *flag.FlagSet
    89  }
    90  
    91  // Loader is used to load configuration from files (JSON/TOML/YAML),
    92  // environment variables, command line flags, or by custom loader function.
    93  //
    94  // The priority in descending order is:
    95  //
    96  // 1. command line flag defined by field tag `flag`;
    97  //
    98  // 2. custom loader function defined by field tag `custom`;
    99  //
   100  // 3. environment variables;
   101  //
   102  // 4. config files, if multiple files are given to Load, files appeared
   103  // first takes higher priority, if a config field appears in more
   104  // than one files, only the first has effect.
   105  //
   106  // 5. default values defined by field tag `default`;
   107  type Loader struct {
   108  	*Config
   109  }
   110  
   111  // New creates a new Loader.
   112  func New(config *Config) *Loader {
   113  	if config == nil {
   114  		config = &Config{}
   115  	}
   116  
   117  	return &Loader{Config: config}
   118  }
   119  
   120  // Load creates a Loader with nil config and loads configuration to dst,
   121  // it is a shortcut for New(nil).Load(dst, files...).
   122  func Load(dst any, files ...string) error {
   123  	return New(nil).Load(dst, files...)
   124  }
   125  
   126  // Load loads configuration to dst using the Loader's Config
   127  // and the given configuration files.
   128  //
   129  // See Loader and Config for detailed document.
   130  func (p *Loader) Load(dst any, files ...string) error {
   131  	return p.load(dst, files...)
   132  }
   133  
   134  func (p *Loader) load(dst any, files ...string) error {
   135  	dstTyp := reflect.TypeOf(dst)
   136  	if dstTyp.Kind() != reflect.Ptr || dstTyp.Elem().Kind() != reflect.Struct {
   137  		return errors.New("invalid destination, should be a struct pointer")
   138  	}
   139  
   140  	if err := p.loadFiles(dst, files...); err != nil {
   141  		return err
   142  	}
   143  	if err := p.processEnv(dst, ""); err != nil {
   144  		return err
   145  	}
   146  	if err := p.processCustom(dst); err != nil {
   147  		return err
   148  	}
   149  	if err := p.processDefaults(dst); err != nil {
   150  		return err
   151  	}
   152  	if err := p.processFlags(dst); err != nil {
   153  		return err
   154  	}
   155  	return nil
   156  }
   157  
   158  func (p *Loader) getLogFunc() func(format string, v ...any) {
   159  	if p.LogFunc != nil {
   160  		return p.LogFunc
   161  	}
   162  	return log.Printf
   163  }
   164  
   165  func (p *Loader) loadFiles(config any, files ...string) error {
   166  	for i := len(files) - 1; i >= 0; i-- {
   167  		file := files[i]
   168  		err := p.processFile(config, file)
   169  		if err != nil {
   170  			return err
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  func (p *Loader) processFile(config any, file string) error {
   177  	if info, err := os.Stat(file); err != nil || !info.Mode().IsRegular() {
   178  		return fmt.Errorf("invalid configuration file: %s", file)
   179  	}
   180  
   181  	p.getLogFunc()("loading configuration from file: %v", file)
   182  	var unmarshalFunc = p.Config.UnmarshalFunc
   183  	if unmarshalFunc == nil {
   184  		extname := path.Ext(file)
   185  		switch strings.ToLower(extname) {
   186  		case ".json":
   187  			unmarshalFunc = unmarshalJSON
   188  		case ".yaml", ".yml":
   189  			unmarshalFunc = unmarshalYAML
   190  		case ".toml":
   191  			unmarshalFunc = unmarshalTOML
   192  		default:
   193  			return fmt.Errorf("unsupported file type: %v", extname)
   194  		}
   195  	}
   196  	data, err := os.ReadFile(file)
   197  	if err != nil {
   198  		return fmt.Errorf("cannot read file %s: %w", file, err)
   199  	}
   200  	err = unmarshalFunc(data, config, p.DisallowUnknownFields)
   201  	if err != nil {
   202  		return fmt.Errorf("cannot unmarshal file %s: %w", file, err)
   203  	}
   204  	return nil
   205  }
   206  
   207  func unmarshalJSON(data []byte, v any, disallowUnknownFields bool) error {
   208  	if disallowUnknownFields {
   209  		dec := json.NewDecoder(bytes.NewReader(data))
   210  		dec.DisallowUnknownFields()
   211  		return dec.Decode(v)
   212  	}
   213  	return json.Unmarshal(data, v)
   214  }
   215  
   216  func unmarshalYAML(data []byte, v any, disallowUnknownFields bool) error {
   217  	dec := yaml.NewDecoder(bytes.NewReader(data))
   218  	if disallowUnknownFields {
   219  		dec.KnownFields(true)
   220  	}
   221  	return dec.Decode(v)
   222  }
   223  
   224  func unmarshalTOML(data []byte, v any, disallowUnknownFields bool) error {
   225  	meta, err := toml.Decode(string(data), v)
   226  	if err == nil && len(meta.Undecoded()) > 0 && disallowUnknownFields {
   227  		return fmt.Errorf("toml: unknown fields %v", meta.Undecoded())
   228  	}
   229  	return err
   230  }
   231  
   232  func (p *Loader) processDefaults(config any) error {
   233  	configVal := reflect.Indirect(reflect.ValueOf(config))
   234  	configTyp := configVal.Type()
   235  
   236  	for i := 0; i < configTyp.NumField(); i++ {
   237  		field := configTyp.Field(i)
   238  		fieldVal := configVal.Field(i)
   239  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   240  			continue
   241  		}
   242  		if field.Tag.Get(ConfrTag) == "-" {
   243  			continue
   244  		}
   245  
   246  		defaultValue := field.Tag.Get(DefaultValueTag)
   247  		if defaultValue != "" {
   248  			if p.Verbose {
   249  				p.getLogFunc()("processing default value for field %s.%s", configTyp.Name(), field.Name)
   250  			}
   251  
   252  			isBlank := reflect.DeepEqual(fieldVal.Interface(), reflect.Zero(field.Type).Interface())
   253  			if isBlank {
   254  				err := assignFieldValue(fieldVal, defaultValue)
   255  				if err != nil {
   256  					return fmt.Errorf("cannot assign default value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   257  				}
   258  			}
   259  		}
   260  
   261  		fieldVal = reflect.Indirect(fieldVal)
   262  		switch fieldVal.Kind() {
   263  		case reflect.Struct:
   264  			if err := p.processDefaults(fieldVal.Addr().Interface()); err != nil {
   265  				return err
   266  			}
   267  		case reflect.Slice:
   268  			for i := 0; i < fieldVal.Len(); i++ {
   269  				elemVal := reflect.Indirect(fieldVal.Index(i))
   270  				if elemVal.Kind() == reflect.Struct {
   271  					if err := p.processDefaults(elemVal.Addr().Interface()); err != nil {
   272  						return err
   273  					}
   274  				}
   275  			}
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  func (p *Loader) processFlags(config any) error {
   282  	if p.FlagSet == nil {
   283  		return nil
   284  	}
   285  	if !p.FlagSet.Parsed() {
   286  		return errors.New("flag set is not parsed")
   287  	}
   288  
   289  	fs := p.FlagSet
   290  	configVal := reflect.Indirect(reflect.ValueOf(config))
   291  	configTyp := configVal.Type()
   292  
   293  	for i := 0; i < configTyp.NumField(); i++ {
   294  		field := configTyp.Field(i)
   295  		fieldVal := configVal.Field(i)
   296  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   297  			continue
   298  		}
   299  		if field.Tag.Get(ConfrTag) == "-" {
   300  			continue
   301  		}
   302  
   303  		flagName := field.Tag.Get(FlagTag)
   304  		if flagName != "" && flagName != "-" {
   305  			if p.Verbose {
   306  				p.getLogFunc()("processing flag for field %s.%s", configTyp.Name(), field.Name)
   307  			}
   308  
   309  			if flagVal, isSet := lookupFlag(fs, flagName); flagVal != nil {
   310  				err := assignFlagValue(fieldVal, flagVal, isSet)
   311  				if err != nil {
   312  					return fmt.Errorf("cannot assign flag value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   313  				}
   314  			}
   315  		}
   316  
   317  		fieldVal = reflect.Indirect(fieldVal)
   318  		switch fieldVal.Kind() {
   319  		case reflect.Struct:
   320  			if err := p.processFlags(fieldVal.Addr().Interface()); err != nil {
   321  				return err
   322  			}
   323  		}
   324  	}
   325  	return nil
   326  }
   327  
   328  // lookupFlag returns a flag and tells whether the flag is set.
   329  func lookupFlag(fs *flag.FlagSet, name string) (out *flag.Flag, isSet bool) {
   330  	fs.Visit(func(f *flag.Flag) {
   331  		if f.Name == name {
   332  			out = f
   333  			isSet = true
   334  		}
   335  	})
   336  	if out == nil {
   337  		out = fs.Lookup(name)
   338  	}
   339  	return
   340  }
   341  
   342  func (p *Loader) processEnv(config any, prefix string) error {
   343  	configVal := reflect.Indirect(reflect.ValueOf(config))
   344  	configTyp := configVal.Type()
   345  
   346  	for i := 0; i < configTyp.NumField(); i++ {
   347  		field := configTyp.Field(i)
   348  		fieldVal := configVal.Field(i)
   349  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   350  			continue
   351  		}
   352  		if field.Tag.Get(ConfrTag) == "-" {
   353  			continue
   354  		}
   355  
   356  		var envNames []string
   357  		envTag := field.Tag.Get(EnvTag)
   358  		if envTag != "" {
   359  			for _, name := range strings.Split(envTag, ",") {
   360  				name = strings.TrimSpace(name)
   361  				if name != "" {
   362  					envNames = append(envNames, name)
   363  				}
   364  			}
   365  		} else if p.EnableImplicitEnv {
   366  			tmp := p.getEnvName(prefix, field.Name)
   367  			envNames = append(envNames, tmp, strings.ToUpper(tmp))
   368  		}
   369  		if len(envNames) > 0 {
   370  			if p.Verbose {
   371  				p.getLogFunc()("loading env for field %s.%s from %v", configTyp.Name(), field.Name, envNames)
   372  			}
   373  
   374  			for _, envName := range envNames {
   375  				if value := os.Getenv(envName); value != "" {
   376  					err := assignFieldValue(fieldVal, value)
   377  					if err != nil {
   378  						return fmt.Errorf("cannot assign env value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   379  					}
   380  					break
   381  				}
   382  			}
   383  		}
   384  
   385  		fieldVal = reflect.Indirect(fieldVal)
   386  		switch fieldVal.Kind() {
   387  		case reflect.Struct:
   388  			fieldPrefix := p.getEnvName(prefix, field.Name)
   389  			if err := p.processEnv(fieldVal.Addr().Interface(), fieldPrefix); err != nil {
   390  				return err
   391  			}
   392  		}
   393  	}
   394  	return nil
   395  }
   396  
   397  func (p *Loader) getEnvName(prefix string, name string) string {
   398  	var envName []byte
   399  	for i := 0; i < len(name); i++ {
   400  		if i > 0 && unicode.IsUpper(rune(name[i])) &&
   401  			name[i-1] != '_' &&
   402  			((i+1 < len(name) && unicode.IsLower(rune(name[i+1]))) || unicode.IsLower(rune(name[i-1]))) {
   403  			envName = append(envName, '_')
   404  		}
   405  		envName = append(envName, name[i])
   406  	}
   407  	if prefix == "" {
   408  		prefix = p.EnvPrefix
   409  		if prefix == "" {
   410  			prefix = DefaultEnvPrefix
   411  		}
   412  	}
   413  	return prefix + "_" + string(envName)
   414  }
   415  
   416  func (p *Loader) processCustom(config any) error {
   417  	if p.CustomLoader == nil {
   418  		return nil
   419  	}
   420  	configVal := reflect.Indirect(reflect.ValueOf(config))
   421  	configTyp := configVal.Type()
   422  
   423  	for i := 0; i < configTyp.NumField(); i++ {
   424  		field := configTyp.Field(i)
   425  		fieldVal := configVal.Field(i)
   426  		if !fieldVal.CanAddr() || !fieldVal.CanInterface() {
   427  			continue
   428  		}
   429  		if field.Tag.Get(ConfrTag) == "-" {
   430  			continue
   431  		}
   432  
   433  		customTag := field.Tag.Get(CustomTag)
   434  		if customTag != "" && customTag != "-" {
   435  			if p.Verbose {
   436  				p.getLogFunc()("processing custom loader for field %s.%s", configTyp.Name(), field.Name)
   437  			}
   438  
   439  			tmp, err := p.CustomLoader(fieldVal.Type(), customTag)
   440  			if err != nil {
   441  				return err
   442  			}
   443  			if err = assignFieldValue(fieldVal, tmp); err != nil {
   444  				return fmt.Errorf("cannot assign custom value to field %s.%s: %w", configTyp.Name(), field.Name, err)
   445  			}
   446  		}
   447  
   448  		fieldVal = reflect.Indirect(fieldVal)
   449  		switch fieldVal.Kind() {
   450  		case reflect.Struct:
   451  			if err := p.processCustom(fieldVal.Addr().Interface()); err != nil {
   452  				return err
   453  			}
   454  		}
   455  	}
   456  	return nil
   457  }
   458  
   459  func assignFlagValue(dst reflect.Value, ff *flag.Flag, isSet bool) error {
   460  	if isSet {
   461  		if getter, ok := ff.Value.(flag.Getter); ok {
   462  			return assignFieldValue(dst, getter.Get())
   463  		}
   464  		return assignFieldValue(dst, ff.Value.String())
   465  	}
   466  
   467  	// default value
   468  	if dst.IsZero() && ff.DefValue != "" {
   469  		return assignFieldValue(dst, ff.DefValue)
   470  	}
   471  	return nil
   472  }
   473  
   474  func assignFieldValue(dst reflect.Value, value any) error {
   475  	inputVal := reflect.ValueOf(value)
   476  	if dst.Type() == inputVal.Type() {
   477  		dst.Set(inputVal)
   478  		return nil
   479  	}
   480  
   481  	var ptrDest reflect.Value
   482  	if dst.Kind() == reflect.Ptr {
   483  		ptrDest = dst
   484  		dst = reflect.New(dst.Type().Elem()).Elem()
   485  		if dst.Type() == inputVal.Type() {
   486  			dst.Set(inputVal)
   487  			ptrDest.Set(dst.Addr())
   488  			return nil
   489  		}
   490  	}
   491  
   492  	var err error
   493  	var val any
   494  	switch dst.Interface().(type) {
   495  	case bool:
   496  		val, err = toBooleanE(value)
   497  	case int:
   498  		val, err = cast.ToIntE(value)
   499  	case []int:
   500  		val, err = cast.ToIntSliceE(value)
   501  	case int64:
   502  		val, err = cast.ToInt64E(value)
   503  	case []int64:
   504  		val, err = toIntSlice[int64](value)
   505  	case int32:
   506  		val, err = cast.ToInt32E(value)
   507  	case []int32:
   508  		val, err = toIntSlice[int32](value)
   509  	case float64:
   510  		val, err = cast.ToFloat64E(value)
   511  	case float32:
   512  		val, err = cast.ToFloat32E(value)
   513  	case string:
   514  		val, err = cast.ToStringE(value)
   515  	case []string:
   516  		val, err = cast.ToStringSliceE(value)
   517  	case map[string]bool:
   518  		val, err = cast.ToStringMapBoolE(value)
   519  	case map[string]int:
   520  		val, err = cast.ToStringMapIntE(value)
   521  	case map[string]int64:
   522  		val, err = cast.ToStringMapInt64E(value)
   523  	case map[string]string:
   524  		val, err = cast.ToStringMapStringE(value)
   525  	case map[string][]string:
   526  		val, err = cast.ToStringMapStringSliceE(value)
   527  	case map[string]any:
   528  		val, err = cast.ToStringMapE(value)
   529  	case time.Duration:
   530  		val, err = cast.ToDurationE(value)
   531  	case []time.Duration:
   532  		val, err = cast.ToDurationSliceE(value)
   533  	default:
   534  		err = errors.New("unsupported type")
   535  	}
   536  	if err != nil {
   537  		return err
   538  	}
   539  
   540  	dst.Set(reflect.ValueOf(val))
   541  	if ptrDest.IsValid() {
   542  		ptrDest.Set(dst.Addr())
   543  	}
   544  	return nil
   545  }
   546  
   547  func toBooleanE(v any) (bool, error) {
   548  	if strval, ok := v.(string); ok {
   549  		switch strval {
   550  		case "", "0", "f", "false", "no", "off":
   551  			return false, nil
   552  		case "1", "t", "true", "yes", "on":
   553  			return true, nil
   554  		}
   555  	}
   556  	return cast.ToBoolE(v)
   557  }
   558  
   559  func toIntSlice[T ~int32 | ~int64](v any) ([]T, error) {
   560  	intValues, err := cast.ToIntSliceE(v)
   561  	if err != nil {
   562  		return nil, err
   563  	}
   564  	out := make([]T, len(intValues))
   565  	for i, x := range intValues {
   566  		out[i] = T(x)
   567  	}
   568  	return out, nil
   569  }