github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/envconfig/envconfig.go (about)

     1  // Copyright (c) 2013 Kelsey Hightower. All rights reserved.
     2  // Use of this source code is governed by the MIT License that can be found in
     3  // the LICENSE file.
     4  
     5  package envconfig
     6  
     7  import (
     8  	"encoding"
     9  	"fmt"
    10  	"github.com/pkg/errors"
    11  	"math"
    12  	"os"
    13  	"reflect"
    14  	"regexp"
    15  	"strconv"
    16  	"strings"
    17  	"time"
    18  )
    19  
    20  // ErrInvalidSpecification indicates that a specification is of the wrong type.
    21  var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
    22  
    23  var gatherRegexp = regexp.MustCompile("([^A-Z]+|[A-Z]+[^A-Z]+|[A-Z]+)")
    24  var acronymRegexp = regexp.MustCompile("([A-Z]+)([A-Z][^A-Z]+)")
    25  
    26  // A ParseError occurs when an environment variable cannot be converted to
    27  // the type required by a struct field during assignment.
    28  type ParseError struct {
    29  	KeyName   string
    30  	FieldName string
    31  	TypeName  string
    32  	Value     string
    33  	Err       error
    34  }
    35  
    36  // Decoder has the same semantics as Setter, but takes higher precedence.
    37  // It is provided for historical compatibility.
    38  type Decoder interface {
    39  	Decode(value string) error
    40  }
    41  
    42  // Setter is implemented by types can self-deserialize values.
    43  // Any type that implements flag.Value also implements Setter.
    44  type Setter interface {
    45  	Set(value string) error
    46  }
    47  
    48  func (e *ParseError) Error() string {
    49  	return fmt.Sprintf("envconfig.Process: assigning %[1]s to %[2]s: converting '%[3]s' to type %[4]s. details: %+v", e.KeyName, e.FieldName, e.Value, e.TypeName, e.Err)
    50  }
    51  
    52  type SeparatorType string
    53  
    54  const (
    55  	INDEX SeparatorType = "index"
    56  	COMMA SeparatorType = "comma"
    57  )
    58  
    59  // varInfo maintains information about the configuration variable
    60  type varInfo struct {
    61  	Name     string
    62  	Alt      string
    63  	Key      string
    64  	Field    reflect.Value
    65  	Tags     reflect.StructTag
    66  	ValuesBy SeparatorType
    67  }
    68  
    69  // GatherInfo gathers information about the specified struct
    70  func gatherInfo(prefix string, spec interface{}) ([]varInfo, error) {
    71  	s := reflect.ValueOf(spec)
    72  
    73  	if s.Kind() != reflect.Ptr {
    74  		return nil, ErrInvalidSpecification
    75  	}
    76  	s = s.Elem()
    77  	if s.Kind() != reflect.Struct {
    78  		return nil, ErrInvalidSpecification
    79  	}
    80  	typeOfSpec := s.Type()
    81  
    82  	// over allocate an info array, we will extend if needed later
    83  	infos := make([]varInfo, 0, s.NumField())
    84  	for i := 0; i < s.NumField(); i++ {
    85  		f := s.Field(i)
    86  		ftype := typeOfSpec.Field(i)
    87  		if !f.CanSet() || isTrue(ftype.Tag.Get("ignored")) {
    88  			continue
    89  		}
    90  
    91  		for f.Kind() == reflect.Ptr {
    92  			if f.IsNil() {
    93  				if f.Type().Elem().Kind() != reflect.Struct {
    94  					// nil pointer to a non-struct: leave it alone
    95  					break
    96  				}
    97  				// nil pointer to struct: create a zero instance
    98  				f.Set(reflect.New(f.Type().Elem()))
    99  			}
   100  			f = f.Elem()
   101  		}
   102  
   103  		// Capture information about the config variable
   104  		info := varInfo{
   105  			Name:  ftype.Name,
   106  			Field: f,
   107  			Tags:  ftype.Tag,
   108  			Alt:   strings.ToUpper(ftype.Tag.Get("envconfig")),
   109  		}
   110  
   111  		// Default to the field name as the env var name (will be upcased)
   112  		info.Key = info.Name
   113  
   114  		// Best effort to un-pick camel casing as separate words
   115  		if isTrue(ftype.Tag.Get("split_words")) {
   116  			words := gatherRegexp.FindAllStringSubmatch(ftype.Name, -1)
   117  			if len(words) > 0 {
   118  				var name []string
   119  				for _, words := range words {
   120  					if m := acronymRegexp.FindStringSubmatch(words[0]); len(m) == 3 {
   121  						name = append(name, m[1], m[2])
   122  					} else {
   123  						name = append(name, words[0])
   124  					}
   125  				}
   126  
   127  				info.Key = strings.Join(name, "_")
   128  			}
   129  		}
   130  		if info.Alt != "" {
   131  			info.Key = info.Alt
   132  		}
   133  		if prefix != "" {
   134  			info.Key = fmt.Sprintf("%s_%s", prefix, info.Key)
   135  		}
   136  		info.Key = strings.ToUpper(info.Key)
   137  		info.ValuesBy = SeparatorType(ftype.Tag.Get("values_by"))
   138  
   139  		infos = append(infos, info)
   140  
   141  		if f.Kind() == reflect.Struct {
   142  			// honor Decode if present
   143  			if decoderFrom(f) == nil && setterFrom(f) == nil && textUnmarshaler(f) == nil && binaryUnmarshaler(f) == nil {
   144  				innerPrefix := prefix
   145  				if !ftype.Anonymous {
   146  					innerPrefix = info.Key
   147  				}
   148  
   149  				embeddedPtr := f.Addr().Interface()
   150  				embeddedInfos, err := gatherInfo(innerPrefix, embeddedPtr)
   151  				if err != nil {
   152  					return nil, err
   153  				}
   154  				infos = append(infos[:len(infos)-1], embeddedInfos...)
   155  
   156  				continue
   157  			}
   158  		}
   159  	}
   160  	return infos, nil
   161  }
   162  
   163  // CheckDisallowed checks that no environment variables with the prefix are set
   164  // that we don't know how or want to parse. This is likely only meaningful with
   165  // a non-empty prefix.
   166  func CheckDisallowed(prefix string, spec interface{}) error {
   167  	infos, err := gatherInfo(prefix, spec)
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	vars := make(map[string]struct{})
   173  	for _, info := range infos {
   174  		vars[info.Key] = struct{}{}
   175  	}
   176  
   177  	if prefix != "" {
   178  		prefix = strings.ToUpper(prefix) + "_"
   179  	}
   180  
   181  	for _, env := range os.Environ() {
   182  		if !strings.HasPrefix(env, prefix) {
   183  			continue
   184  		}
   185  		v := strings.SplitN(env, "=", 2)[0]
   186  		if _, found := vars[v]; !found {
   187  			return fmt.Errorf("unknown environment variable %s", v)
   188  		}
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  // Process populates the specified struct based on environment variables
   195  func Process(prefix string, spec interface{}) error {
   196  	infos, err := gatherInfo(prefix, spec)
   197  
   198  	for _, info := range infos {
   199  
   200  		if info.ValuesBy == INDEX {
   201  			index := 0
   202  			end := false
   203  			for !end {
   204  				key := info.Key + "_" + strconv.Itoa(index)
   205  				if end, err = doProcess(key, info); err != nil {
   206  					return errors.WithStack(err)
   207  				}
   208  				index++
   209  			}
   210  			if info.Field.IsNil() || info.Field.IsZero() {
   211  				if err = checkRequired(info); err != nil {
   212  					return errors.WithStack(err)
   213  				}
   214  				def := info.Tags.Get("default")
   215  				if def != "" {
   216  					err = processFieldWithConfig(def, info.Field, ProcessConfig{
   217  						ValuesBy: info.ValuesBy,
   218  					})
   219  				}
   220  			}
   221  			continue
   222  		}
   223  
   224  		if _, err = doProcess(info.Key, info); err != nil {
   225  			return errors.WithStack(err)
   226  		}
   227  	}
   228  
   229  	return err
   230  }
   231  
   232  func checkRequired(info varInfo) error {
   233  	def := info.Tags.Get("default")
   234  	req := info.Tags.Get("required")
   235  	if def == "" {
   236  		if isTrue(req) {
   237  			key := info.Key
   238  			if info.Alt != "" {
   239  				key = info.Alt
   240  			}
   241  			return fmt.Errorf("required key %s missing value", key)
   242  		}
   243  	}
   244  	return nil
   245  }
   246  
   247  func doProcess(key string, info varInfo) (end bool, err error) {
   248  	// `os.Getenv` cannot differentiate between an explicitly set empty value
   249  	// and an unset value. `os.LookupEnv` is preferred to `syscall.Getenv`,
   250  	// but it is only available in go1.5 or newer. We're using Go build tags
   251  	// here to use os.LookupEnv for >=go1.5
   252  	value, ok := lookupEnv(key)
   253  
   254  	if info.ValuesBy == INDEX && !ok {
   255  		return true, nil
   256  	}
   257  
   258  	if !ok && info.Alt != "" {
   259  		value, ok = lookupEnv(info.Alt)
   260  	}
   261  
   262  	if info.ValuesBy != INDEX {
   263  		if !ok {
   264  			if err = checkRequired(info); err != nil {
   265  				return false, errors.WithStack(err)
   266  			}
   267  			def := info.Tags.Get("default")
   268  			if def == "" {
   269  				return false, nil
   270  			}
   271  			value = def
   272  		}
   273  	}
   274  
   275  	err = processFieldWithConfig(value, info.Field, ProcessConfig{
   276  		ValuesBy: info.ValuesBy,
   277  	})
   278  	if err != nil {
   279  		return false, &ParseError{
   280  			KeyName:   info.Key,
   281  			FieldName: info.Name,
   282  			TypeName:  info.Field.Type().String(),
   283  			Value:     value,
   284  			Err:       errors.WithStack(err),
   285  		}
   286  	}
   287  
   288  	return false, nil
   289  }
   290  
   291  // MustProcess is the same as Process but panics if an error occurs
   292  func MustProcess(prefix string, spec interface{}) {
   293  	if err := Process(prefix, spec); err != nil {
   294  		panic(err)
   295  	}
   296  }
   297  
   298  func processField(value string, field reflect.Value) error {
   299  	typ := field.Type()
   300  
   301  	decoder := decoderFrom(field)
   302  	if decoder != nil {
   303  		return decoder.Decode(value)
   304  	}
   305  	// look for Set method if Decode not defined
   306  	setter := setterFrom(field)
   307  	if setter != nil {
   308  		return setter.Set(value)
   309  	}
   310  
   311  	if t := textUnmarshaler(field); t != nil {
   312  		return t.UnmarshalText([]byte(value))
   313  	}
   314  
   315  	if b := binaryUnmarshaler(field); b != nil {
   316  		return b.UnmarshalBinary([]byte(value))
   317  	}
   318  
   319  	if typ.Kind() == reflect.Ptr {
   320  		typ = typ.Elem()
   321  		if field.IsNil() {
   322  			field.Set(reflect.New(typ))
   323  		}
   324  		field = field.Elem()
   325  	}
   326  
   327  	switch typ.Kind() {
   328  	case reflect.String:
   329  		field.SetString(value)
   330  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   331  		var (
   332  			val int64
   333  			err error
   334  		)
   335  		if field.Kind() == reflect.Int64 && typ.PkgPath() == "time" && typ.Name() == "Duration" {
   336  			var d time.Duration
   337  			d, err = time.ParseDuration(value)
   338  			val = int64(d)
   339  		} else {
   340  			val, err = strconv.ParseInt(value, 0, typ.Bits())
   341  		}
   342  		if err != nil {
   343  			return err
   344  		}
   345  
   346  		field.SetInt(val)
   347  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   348  		val, err := strconv.ParseUint(value, 0, typ.Bits())
   349  		if err != nil {
   350  			return err
   351  		}
   352  		field.SetUint(val)
   353  	case reflect.Bool:
   354  		val, err := strconv.ParseBool(value)
   355  		if err != nil {
   356  			return err
   357  		}
   358  		field.SetBool(val)
   359  	case reflect.Float32, reflect.Float64:
   360  		val, err := strconv.ParseFloat(value, typ.Bits())
   361  		if err != nil {
   362  			return err
   363  		}
   364  		field.SetFloat(val)
   365  	case reflect.Slice:
   366  		sl := reflect.MakeSlice(typ, 0, 0)
   367  		if typ.Elem().Kind() == reflect.Uint8 {
   368  			sl = reflect.ValueOf([]byte(value))
   369  		} else if strings.TrimSpace(value) != "" {
   370  			vals := strings.Split(value, ",")
   371  			sl = reflect.MakeSlice(typ, len(vals), len(vals))
   372  			for i, val := range vals {
   373  				err := processField(val, sl.Index(i))
   374  				if err != nil {
   375  					return err
   376  				}
   377  			}
   378  		}
   379  		field.Set(sl)
   380  	case reflect.Map:
   381  		mp := reflect.MakeMap(typ)
   382  		if strings.TrimSpace(value) != "" {
   383  			pairs := strings.Split(value, ",")
   384  			for _, pair := range pairs {
   385  				kvpair := strings.Split(pair, ":")
   386  				if len(kvpair) != 2 {
   387  					return fmt.Errorf("invalid map item: %q", pair)
   388  				}
   389  				k := reflect.New(typ.Key()).Elem()
   390  				err := processField(kvpair[0], k)
   391  				if err != nil {
   392  					return err
   393  				}
   394  				v := reflect.New(typ.Elem()).Elem()
   395  				err = processField(kvpair[1], v)
   396  				if err != nil {
   397  					return err
   398  				}
   399  				mp.SetMapIndex(k, v)
   400  			}
   401  		}
   402  		field.Set(mp)
   403  	}
   404  
   405  	return nil
   406  }
   407  
   408  type ProcessConfig struct {
   409  	ValuesBy SeparatorType
   410  }
   411  
   412  const (
   413  	CAP_GROW_STEP = 8
   414  )
   415  
   416  func processFieldWithConfig(value string, field reflect.Value, config ProcessConfig) error {
   417  	typ := field.Type()
   418  
   419  	decoder := decoderFrom(field)
   420  	if decoder != nil {
   421  		return decoder.Decode(value)
   422  	}
   423  	// look for Set method if Decode not defined
   424  	setter := setterFrom(field)
   425  	if setter != nil {
   426  		return setter.Set(value)
   427  	}
   428  
   429  	if t := textUnmarshaler(field); t != nil {
   430  		return t.UnmarshalText([]byte(value))
   431  	}
   432  
   433  	if b := binaryUnmarshaler(field); b != nil {
   434  		return b.UnmarshalBinary([]byte(value))
   435  	}
   436  
   437  	if typ.Kind() == reflect.Ptr {
   438  		typ = typ.Elem()
   439  		if field.IsNil() {
   440  			field.Set(reflect.New(typ))
   441  		}
   442  		field = field.Elem()
   443  	}
   444  
   445  	switch typ.Kind() {
   446  	case reflect.String:
   447  		field.SetString(value)
   448  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   449  		var (
   450  			val int64
   451  			err error
   452  		)
   453  		if field.Kind() == reflect.Int64 && typ.PkgPath() == "time" && typ.Name() == "Duration" {
   454  			var d time.Duration
   455  			d, err = time.ParseDuration(value)
   456  			val = int64(d)
   457  		} else {
   458  			val, err = strconv.ParseInt(value, 0, typ.Bits())
   459  		}
   460  		if err != nil {
   461  			return errors.WithStack(err)
   462  		}
   463  
   464  		field.SetInt(val)
   465  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   466  		val, err := strconv.ParseUint(value, 0, typ.Bits())
   467  		if err != nil {
   468  			return errors.WithStack(err)
   469  		}
   470  		field.SetUint(val)
   471  	case reflect.Bool:
   472  		val, err := strconv.ParseBool(value)
   473  		if err != nil {
   474  			return errors.WithStack(err)
   475  		}
   476  		field.SetBool(val)
   477  	case reflect.Float32, reflect.Float64:
   478  		val, err := strconv.ParseFloat(value, typ.Bits())
   479  		if err != nil {
   480  			return errors.WithStack(err)
   481  		}
   482  		field.SetFloat(val)
   483  	case reflect.Slice:
   484  		if config.ValuesBy == INDEX {
   485  			if field.IsNil() || field.IsZero() {
   486  				sl := reflect.MakeSlice(typ, 0, 0)
   487  				if typ.Elem().Kind() == reflect.Uint8 {
   488  					return errors.WithStack(errors.New("Not support uint8 slice field to be processed with index values_by tag"))
   489  				} else if strings.TrimSpace(value) != "" {
   490  					sl = reflect.MakeSlice(typ, 1, CAP_GROW_STEP)
   491  					err := processFieldWithConfig(value, sl.Index(0), config)
   492  					if err != nil {
   493  						return errors.WithStack(err)
   494  					}
   495  				}
   496  				field.Set(sl)
   497  			} else {
   498  				size := field.Len()
   499  				if size >= int(math.Floor(float64(field.Cap())*0.75)) {
   500  					field.Grow(CAP_GROW_STEP)
   501  				}
   502  				field.SetLen(size + 1)
   503  				if typ.Elem().Kind() == reflect.Uint8 {
   504  					return errors.WithStack(errors.New("Not support uint8 slice field to be processed with index values_by tag"))
   505  				} else if strings.TrimSpace(value) != "" {
   506  					err := processFieldWithConfig(value, field.Index(size), config)
   507  					if err != nil {
   508  						return errors.WithStack(err)
   509  					}
   510  				}
   511  			}
   512  		} else {
   513  			sl := reflect.MakeSlice(typ, 0, 0)
   514  			if typ.Elem().Kind() == reflect.Uint8 {
   515  				sl = reflect.ValueOf([]byte(value))
   516  			} else if strings.TrimSpace(value) != "" {
   517  				vals := strings.Split(value, ",")
   518  				sl = reflect.MakeSlice(typ, len(vals), len(vals))
   519  				for i, val := range vals {
   520  					err := processFieldWithConfig(val, sl.Index(i), config)
   521  					if err != nil {
   522  						return errors.WithStack(err)
   523  					}
   524  				}
   525  			}
   526  			field.Set(sl)
   527  		}
   528  	case reflect.Map:
   529  		mp := reflect.MakeMap(typ)
   530  		if strings.TrimSpace(value) != "" {
   531  			pairs := strings.Split(value, ",")
   532  			for _, pair := range pairs {
   533  				kvpair := strings.Split(pair, ":")
   534  				if len(kvpair) != 2 {
   535  					return fmt.Errorf("invalid map item: %q", pair)
   536  				}
   537  				k := reflect.New(typ.Key()).Elem()
   538  				err := processFieldWithConfig(kvpair[0], k, config)
   539  				if err != nil {
   540  					return errors.WithStack(err)
   541  				}
   542  				v := reflect.New(typ.Elem()).Elem()
   543  				err = processFieldWithConfig(kvpair[1], v, config)
   544  				if err != nil {
   545  					return errors.WithStack(err)
   546  				}
   547  				mp.SetMapIndex(k, v)
   548  			}
   549  		}
   550  		field.Set(mp)
   551  	}
   552  
   553  	return nil
   554  }
   555  
   556  func interfaceFrom(field reflect.Value, fn func(interface{}, *bool)) {
   557  	// it may be impossible for a struct field to fail this check
   558  	if !field.CanInterface() {
   559  		return
   560  	}
   561  	var ok bool
   562  	fn(field.Interface(), &ok)
   563  	if !ok && field.CanAddr() {
   564  		fn(field.Addr().Interface(), &ok)
   565  	}
   566  }
   567  
   568  func decoderFrom(field reflect.Value) (d Decoder) {
   569  	interfaceFrom(field, func(v interface{}, ok *bool) { d, *ok = v.(Decoder) })
   570  	return d
   571  }
   572  
   573  func setterFrom(field reflect.Value) (s Setter) {
   574  	interfaceFrom(field, func(v interface{}, ok *bool) { s, *ok = v.(Setter) })
   575  	return s
   576  }
   577  
   578  func textUnmarshaler(field reflect.Value) (t encoding.TextUnmarshaler) {
   579  	interfaceFrom(field, func(v interface{}, ok *bool) { t, *ok = v.(encoding.TextUnmarshaler) })
   580  	return t
   581  }
   582  
   583  func binaryUnmarshaler(field reflect.Value) (b encoding.BinaryUnmarshaler) {
   584  	interfaceFrom(field, func(v interface{}, ok *bool) { b, *ok = v.(encoding.BinaryUnmarshaler) })
   585  	return b
   586  }
   587  
   588  func isTrue(s string) bool {
   589  	b, _ := strconv.ParseBool(s)
   590  	return b
   591  }