github.com/wfusion/gofusion@v1.1.14/internal/configor/utils.go (about)

     1  // Fork from github.com/jinzhu/configor@v1.2.2-0.20230118083828-f7a0fc7c9fc6
     2  // Here is the license:
     3  //
     4  // The MIT License (MIT)
     5  //
     6  // Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
     7  //
     8  // Permission is hereby granted, free of charge, to any person obtaining a copy
     9  // of this software and associated documentation files (the "Software"), to deal
    10  // in the Software without restriction, including without limitation the rights
    11  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    12  // copies of the Software, and to permit persons to whom the Software is
    13  // furnished to do so, subject to the following conditions:
    14  //
    15  // The above copyright notice and this permission notice shall be included in all
    16  // copies or substantial portions of the Software.
    17  //
    18  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    19  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    20  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    21  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    22  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    23  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    24  // SOFTWARE.
    25  
    26  package configor
    27  
    28  import (
    29  	"bytes"
    30  	"fmt"
    31  	"io"
    32  	"io/fs"
    33  	"io/ioutil"
    34  	"log"
    35  	"os"
    36  	"path"
    37  	"reflect"
    38  	"runtime/debug"
    39  	"strings"
    40  	"time"
    41  
    42  	"github.com/BurntSushi/toml"
    43  	"github.com/pkg/errors"
    44  	"gopkg.in/yaml.v3"
    45  
    46  	"github.com/wfusion/gofusion/common/utils"
    47  	"github.com/wfusion/gofusion/common/utils/serialize/json"
    48  )
    49  
    50  // UnmatchedTomlKeysError errors are returned by the Load function when
    51  // ErrorOnUnmatchedKeys is set to true and there are unmatched keys in the input
    52  // toml config file. The string returned by Error() contains the names of the
    53  // missing keys.
    54  type UnmatchedTomlKeysError struct {
    55  	Keys []toml.Key
    56  }
    57  
    58  func (e *UnmatchedTomlKeysError) Error() string {
    59  	return fmt.Sprintf("There are keys in the config file that do not match any field in the given struct: %v", e.Keys)
    60  }
    61  
    62  func (c *Configor) getENVPrefix(config any) string {
    63  	if c.Config.ENVPrefix == "" {
    64  		if prefix := os.Getenv("CONFIGOR_ENV_PREFIX"); prefix != "" {
    65  			return prefix
    66  		}
    67  		return "Configor"
    68  	}
    69  	return c.Config.ENVPrefix
    70  }
    71  
    72  func (c *Configor) getConfigurationFileWithENVPrefix(file, env string) (string, time.Time, string, error) {
    73  	var (
    74  		envFile string
    75  		extname = path.Ext(file)
    76  	)
    77  
    78  	if extname == "" {
    79  		envFile = fmt.Sprintf("%v.%v", file, env)
    80  	} else {
    81  		envFile = fmt.Sprintf("%v.%v%v", strings.TrimSuffix(file, extname), env, extname)
    82  	}
    83  
    84  	if fileInfo, err := c.statFunc(envFile); err == nil && fileInfo.Mode().IsRegular() {
    85  		fileHash, _ := c.hashFunc(envFile)
    86  		return envFile, fileInfo.ModTime(), fileHash, nil
    87  	}
    88  	return "", time.Now(), "", fmt.Errorf("failed to find file %v", file)
    89  }
    90  
    91  func (c *Configor) getConfigurationFiles(watchMode bool, files ...string) (
    92  	[]string, map[string]time.Time, map[string]string) {
    93  	resultKeys := make([]string, 0, len(files))
    94  	hashResult := make(map[string]string, len(files))
    95  	modTimeResult := make(map[string]time.Time, len(files))
    96  	if !watchMode && (c.Config.Debug || c.Config.Verbose) {
    97  		fmt.Printf("Current environment: '%v'\n", c.GetEnvironment())
    98  	}
    99  
   100  	for i := len(files) - 1; i >= 0; i-- {
   101  		foundFile := false
   102  		file := files[i]
   103  
   104  		// check configuration
   105  		if fileInfo, err := c.statFunc(file); err == nil && fileInfo.Mode().IsRegular() {
   106  			foundFile = true
   107  			resultKeys = append(resultKeys, file)
   108  			modTimeResult[file] = fileInfo.ModTime()
   109  			if hash, err := c.hashFunc(file); err == nil {
   110  				hashResult[file] = hash
   111  			}
   112  		}
   113  
   114  		// check configuration with env
   115  		if file, modTime, hash, err := c.getConfigurationFileWithENVPrefix(file, c.GetEnvironment()); err == nil {
   116  			foundFile = true
   117  			resultKeys = append(resultKeys, file)
   118  			modTimeResult[file] = modTime
   119  			if hash != "" {
   120  				hashResult[file] = hash
   121  			}
   122  		}
   123  
   124  		// check example configuration
   125  		if !foundFile {
   126  			if example, modTime, hash, err := c.getConfigurationFileWithENVPrefix(file, "example"); err == nil {
   127  				if !watchMode && !c.Silent {
   128  					log.Printf("Failed to find configuration %v, using example file %v\n", file, example)
   129  				}
   130  				resultKeys = append(resultKeys, example)
   131  				modTimeResult[example] = modTime
   132  				if hash != "" {
   133  					hashResult[file] = hash
   134  				}
   135  			} else if !c.Silent {
   136  				fmt.Printf("Failed to find configuration %v\n", file)
   137  			}
   138  		}
   139  	}
   140  	return resultKeys, modTimeResult, hashResult
   141  }
   142  
   143  func (c *Configor) processFile(config any, file string, errorOnUnmatchedKeys bool) error {
   144  	readFile := ioutil.ReadFile
   145  	if c.FS != nil {
   146  		readFile = func(filename string) ([]byte, error) {
   147  			return fs.ReadFile(c.FS, filename)
   148  		}
   149  	}
   150  	data, err := readFile(file)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	switch {
   156  	case strings.HasSuffix(file, ".yaml") || strings.HasSuffix(file, ".yml"):
   157  		if errorOnUnmatchedKeys {
   158  			decoder := yaml.NewDecoder(bytes.NewBuffer(data))
   159  			decoder.KnownFields(true)
   160  			return decoder.Decode(config)
   161  		}
   162  		return yaml.Unmarshal(data, config)
   163  	case strings.HasSuffix(file, ".toml"):
   164  		return unmarshalToml(data, config, errorOnUnmatchedKeys)
   165  	case strings.HasSuffix(file, ".json"):
   166  		return unmarshalJSON(data, config, errorOnUnmatchedKeys)
   167  	default:
   168  		if err := unmarshalToml(data, config, errorOnUnmatchedKeys); err == nil {
   169  			return nil
   170  		} else if errUnmatchedKeys, ok := err.(*UnmatchedTomlKeysError); ok {
   171  			return errUnmatchedKeys
   172  		}
   173  
   174  		if err := unmarshalJSON(data, config, errorOnUnmatchedKeys); err == nil {
   175  			return nil
   176  		} else if strings.Contains(err.Error(), "json: unknown field") {
   177  			return err
   178  		}
   179  
   180  		var yamlError error
   181  		if errorOnUnmatchedKeys {
   182  			decoder := yaml.NewDecoder(bytes.NewBuffer(data))
   183  			decoder.KnownFields(true)
   184  			yamlError = decoder.Decode(config)
   185  		} else {
   186  			yamlError = yaml.Unmarshal(data, config)
   187  		}
   188  
   189  		if yamlError == nil {
   190  			return nil
   191  		} else if yErr, ok := yamlError.(*yaml.TypeError); ok {
   192  			return yErr
   193  		}
   194  
   195  		return errors.New("failed to decode config")
   196  	}
   197  }
   198  
   199  func unmarshalToml(data []byte, config any, errorOnUnmatchedKeys bool) error {
   200  	metadata, err := toml.Decode(string(data), config)
   201  	if err == nil && len(metadata.Undecoded()) > 0 && errorOnUnmatchedKeys {
   202  		return &UnmatchedTomlKeysError{Keys: metadata.Undecoded()}
   203  	}
   204  	return err
   205  }
   206  
   207  // unmarshalJSON unmarshals the given data into the config interface.
   208  // If the errorOnUnmatchedKeys boolean is true, an error will be returned if there
   209  // are keys in the data that do not match fields in the config interface.
   210  func unmarshalJSON(data []byte, config any, errorOnUnmatchedKeys bool) error {
   211  	reader := strings.NewReader(string(data))
   212  	decoder := json.NewDecoder(reader)
   213  
   214  	if errorOnUnmatchedKeys {
   215  		decoder.DisallowUnknownFields()
   216  	}
   217  
   218  	err := decoder.Decode(config)
   219  	if err != nil && err != io.EOF {
   220  		return err
   221  	}
   222  	return nil
   223  }
   224  
   225  func getPrefixForStruct(prefixes []string, fieldStruct *reflect.StructField) []string {
   226  	if fieldStruct.Anonymous && fieldStruct.Tag.Get("anonymous") == "true" {
   227  		return prefixes
   228  	}
   229  	return append(prefixes, fieldStruct.Name)
   230  }
   231  
   232  func (c *Configor) processTags(config any, prefixes ...string) error {
   233  	configValue := reflect.Indirect(reflect.ValueOf(config))
   234  	if configValue.Kind() != reflect.Struct {
   235  		return errors.New("invalid config, should be struct")
   236  	}
   237  
   238  	configType := configValue.Type()
   239  	for i := 0; i < configType.NumField(); i++ {
   240  		var (
   241  			envNames    []string
   242  			fieldStruct = configType.Field(i)
   243  			field       = configValue.Field(i)
   244  			envName     = fieldStruct.Tag.Get("env") // read configuration from shell env
   245  		)
   246  
   247  		if !field.CanAddr() || !field.CanInterface() {
   248  			continue
   249  		}
   250  
   251  		if envName == "" {
   252  			envNames = append(envNames,
   253  				strings.Join(append(prefixes, fieldStruct.Name), "_")) // Configor_DB_Name
   254  			envNames = append(envNames,
   255  				strings.ToUpper(strings.Join(append(prefixes, fieldStruct.Name), "_"))) // CONFIGOR_DB_NAME
   256  		} else {
   257  			envNames = []string{envName}
   258  		}
   259  
   260  		if c.Config.Verbose {
   261  			fmt.Printf("Trying to load struct `%v`'s field `%v` from env %v\n",
   262  				configType.Name(), fieldStruct.Name, strings.Join(envNames, ", "))
   263  		}
   264  
   265  		// Load From Shell ENV
   266  		for _, env := range envNames {
   267  			if value := os.Getenv(env); value != "" {
   268  				if c.Config.Debug || c.Config.Verbose {
   269  					fmt.Printf("Loading configuration for struct `%v`'s field `%v` from env %v...\n",
   270  						configType.Name(), fieldStruct.Name, env)
   271  				}
   272  
   273  				switch reflect.Indirect(field).Kind() {
   274  				case reflect.Bool:
   275  					switch strings.ToLower(value) {
   276  					case "", "0", "f", "false":
   277  						field.Set(reflect.ValueOf(false))
   278  					default:
   279  						field.Set(reflect.ValueOf(true))
   280  					}
   281  				case reflect.String:
   282  					field.Set(reflect.ValueOf(value))
   283  				default:
   284  					if err := yaml.Unmarshal([]byte(value), field.Addr().Interface()); err != nil {
   285  						return err
   286  					}
   287  				}
   288  				break
   289  			}
   290  		}
   291  
   292  		if isBlank := reflect.DeepEqual(field.Interface(), reflect.Zero(field.Type()).Interface()); isBlank &&
   293  			fieldStruct.Tag.Get("required") == "true" {
   294  			// return error if it is required but blank
   295  			return errors.New(fieldStruct.Name + " is required, but blank")
   296  		}
   297  
   298  		field = utils.IndirectValue(field)
   299  		if field.Kind() == reflect.Struct {
   300  			if err := c.processTags(field.Addr().Interface(),
   301  				getPrefixForStruct(prefixes, &fieldStruct)...); err != nil {
   302  				return err
   303  			}
   304  		}
   305  
   306  		if field.Kind() == reflect.Slice {
   307  			if arrLen := field.Len(); arrLen > 0 {
   308  				for i := 0; i < arrLen; i++ {
   309  					if reflect.Indirect(field.Index(i)).Kind() == reflect.Struct {
   310  						if err := c.processTags(field.Index(i).Addr().Interface(),
   311  							append(getPrefixForStruct(prefixes, &fieldStruct), fmt.Sprint(i))...); err != nil {
   312  							return err
   313  						}
   314  					}
   315  				}
   316  			} else {
   317  				defer func(field reflect.Value, fieldStruct reflect.StructField) {
   318  					if !configValue.IsZero() {
   319  						// load slice from env
   320  						newVal := reflect.New(field.Type().Elem()).Elem()
   321  						if newVal.Kind() == reflect.Struct {
   322  							idx := 0
   323  							for {
   324  								newVal = reflect.New(field.Type().Elem()).Elem()
   325  								if err := c.processTags(newVal.Addr().Interface(), append(
   326  									getPrefixForStruct(prefixes, &fieldStruct), fmt.Sprint(idx))...); err != nil {
   327  									return // err
   328  								} else if reflect.DeepEqual(newVal.Interface(),
   329  									reflect.New(field.Type().Elem()).Elem().Interface()) {
   330  									break
   331  								} else {
   332  									idx++
   333  									field.Set(reflect.Append(field, newVal))
   334  								}
   335  							}
   336  						}
   337  					}
   338  				}(field, fieldStruct)
   339  			}
   340  		}
   341  	}
   342  	return nil
   343  }
   344  
   345  func (c *Configor) load(config any, watchMode bool, files ...string) (err error, changed bool) {
   346  	defer func() {
   347  		if r := recover(); r != nil {
   348  			err = errors.Errorf("panic %s =>\n%s", r, debug.Stack())
   349  			return
   350  		}
   351  		if c.Config.Debug || c.Config.Verbose {
   352  			if err != nil {
   353  				fmt.Printf("Failed to load configuration from %v, got %v\n", files, err)
   354  			}
   355  
   356  			fmt.Printf("Configuration:\n  %#v\n", config)
   357  		}
   358  	}()
   359  
   360  	configFiles, configModTimeMap, hashMap := c.getConfigurationFiles(watchMode, files...)
   361  	if watchMode && len(configModTimeMap) == len(c.configModTimes) && len(hashMap) == len(c.configHash) {
   362  		var changed bool
   363  		for f, curModTime := range configModTimeMap {
   364  			curHash := hashMap[f]
   365  			preHash, ok1 := c.configHash[f]
   366  			preModTime, ok2 := c.configModTimes[f]
   367  			if changed = !ok1 || !ok2 || curModTime.After(preModTime) || curHash != preHash; changed {
   368  				break
   369  			}
   370  		}
   371  
   372  		if !changed {
   373  			return nil, false
   374  		}
   375  	}
   376  
   377  	type withBeforeCallback interface {
   378  		BeforeLoad(opts ...utils.OptionExtender)
   379  	}
   380  	type withAfterCallback interface {
   381  		AfterLoad(opts ...utils.OptionExtender)
   382  	}
   383  	if cb, ok := config.(withBeforeCallback); ok {
   384  		cb.BeforeLoad()
   385  	}
   386  	if cb, ok := config.(withAfterCallback); ok {
   387  		defer cb.AfterLoad()
   388  	}
   389  
   390  	for _, file := range configFiles {
   391  		if c.Config.Debug || c.Config.Verbose {
   392  			fmt.Printf("Loading configurations from file '%v'...\n", file)
   393  		}
   394  		if err = c.processFile(config, file, c.GetErrorOnUnmatchedKeys()); err != nil {
   395  			return err, true
   396  		}
   397  	}
   398  
   399  	// process defaults after process file because map struct should be assigned first
   400  	_ = utils.ParseTag(config, utils.ParseTagName("default"), utils.ParseTagUnmarshalType(utils.UnmarshalTypeYaml))
   401  
   402  	// process file again to ensure read config from file
   403  	for _, file := range configFiles {
   404  		if c.Config.Debug || c.Config.Verbose {
   405  			fmt.Printf("Loading configurations from file '%v'...\n", file)
   406  		}
   407  		if err = c.processFile(config, file, c.GetErrorOnUnmatchedKeys()); err != nil {
   408  			return err, true
   409  		}
   410  	}
   411  
   412  	c.configHash = hashMap
   413  	c.configModTimes = configModTimeMap
   414  
   415  	if prefix := c.getENVPrefix(config); prefix == "-" {
   416  		err = c.processTags(config)
   417  	} else {
   418  		err = c.processTags(config, prefix)
   419  	}
   420  
   421  	// process defaults
   422  	_ = utils.ParseTag(config, utils.ParseTagName("default"), utils.ParseTagUnmarshalType(utils.UnmarshalTypeYaml))
   423  
   424  	return err, true
   425  }