github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/conf/conf.go (about)

     1  package conf
     2  
     3  import (
     4  	"encoding"
     5  	"fmt"
     6  	"go/ast"
     7  	"os"
     8  	"reflect"
     9  	"strings"
    10  
    11  	"github.com/sirupsen/logrus"
    12  
    13  	"github.com/artisanhe/tools/courier/transport_http/transform"
    14  	"github.com/artisanhe/tools/reflectx"
    15  	"github.com/artisanhe/tools/strutil"
    16  	"github.com/artisanhe/tools/validate"
    17  )
    18  
    19  func UnmarshalConf(c interface{}, prefix string) EnvVars {
    20  	rv := reflectx.Indirect(reflect.ValueOf(c))
    21  	tpe := reflectx.IndirectType(reflect.TypeOf(c))
    22  
    23  	if !rv.CanSet() || rv.Type().Kind() != reflect.Struct {
    24  		panic("UnmarshalConf need an variable which can set")
    25  	}
    26  
    27  	ok, errMsgs := NewScanner(prefix).Unmarshal(rv, tpe)
    28  	if !ok {
    29  		for k, v := range errMsgs {
    30  			logrus.Errorf("%s: %s", k, v)
    31  		}
    32  		logrus.Panic()
    33  	}
    34  
    35  	envVars, err := CollectEnvVars(rv, prefix)
    36  	if err != nil {
    37  		panic(err)
    38  	}
    39  	InitialRoot(rv)
    40  	return envVars
    41  }
    42  
    43  type ICanInit interface {
    44  	Init()
    45  }
    46  
    47  func InitialRoot(rv reflect.Value) {
    48  	tpe := rv.Type()
    49  	for i := 0; i < tpe.NumField(); i++ {
    50  		value := rv.Field(i)
    51  		if conf, ok := value.Interface().(ICanInit); ok {
    52  			conf.Init()
    53  		}
    54  	}
    55  }
    56  
    57  // check and modify value
    58  type IDefaultsMarshaller interface {
    59  	MarshalDefaults(v interface{})
    60  }
    61  
    62  func NewScanner(prefix string) *Scanner {
    63  	if prefix == "" {
    64  		prefix = "s"
    65  	}
    66  	return &Scanner{
    67  		prefix: prefix,
    68  	}
    69  }
    70  
    71  type Scanner struct {
    72  	prefix    string
    73  	walker    transform.PathWalker
    74  	errMsgMap transform.ErrMsgMap
    75  }
    76  
    77  func (vs *Scanner) Unmarshal(rv reflect.Value, tpe reflect.Type) (bool, transform.ErrMsgMap) {
    78  	vs.marshalAndValidate(rv, tpe, "")
    79  	return len(vs.errMsgMap) == 0, vs.errMsgMap
    80  }
    81  
    82  func (vs *Scanner) setErrMsg(path string, msg string) {
    83  	if vs.errMsgMap == nil {
    84  		vs.errMsgMap = transform.ErrMsgMap{}
    85  	}
    86  	vs.errMsgMap[path] = msg
    87  }
    88  
    89  func (vs *Scanner) getEnvKey() string {
    90  	key := strings.ToUpper(vs.prefix)
    91  
    92  	for _, p := range vs.walker.Paths() {
    93  		key += strings.ToUpper(fmt.Sprintf("_%v", p))
    94  	}
    95  
    96  	return key
    97  }
    98  
    99  func (vs *Scanner) marshalAndValidate(rv reflect.Value, tpe reflect.Type, tagValidate string) {
   100  	v := rv.Interface()
   101  	if rv.Kind() != reflect.Ptr {
   102  		v = rv.Addr().Interface()
   103  	}
   104  
   105  	if defaultsMarshaller, ok := v.(IDefaultsMarshaller); ok {
   106  		defaultsMarshaller.MarshalDefaults(v)
   107  	}
   108  
   109  	if _, ok := v.(encoding.TextUnmarshaler); ok {
   110  		errMsg := marshalEnvValueAndValidate(rv, tpe, vs.getEnvKey(), tagValidate)
   111  		if errMsg != "" {
   112  			vs.setErrMsg(vs.walker.String(), errMsg)
   113  		}
   114  		return
   115  	}
   116  
   117  	tpe = reflectx.IndirectType(tpe)
   118  
   119  	switch tpe.Kind() {
   120  	case reflect.Func:
   121  		// skip func
   122  	case reflect.Struct:
   123  		if rv.Kind() == reflect.Ptr {
   124  			if rv.IsNil() && rv.CanSet() {
   125  				rv.Set(reflect.New(reflectx.IndirectType(tpe)))
   126  			}
   127  		}
   128  
   129  		rv = reflectx.Indirect(rv)
   130  
   131  		for i := 0; i < tpe.NumField(); i++ {
   132  			field := tpe.Field(i)
   133  			if !ast.IsExported(field.Name) {
   134  				continue
   135  			}
   136  
   137  			if !field.Anonymous {
   138  				vs.walker.Enter(field.Name)
   139  			}
   140  
   141  			tagValidate, _ := transform.GetTagValidate(&field)
   142  
   143  			vs.marshalAndValidate(rv.Field(i), field.Type, tagValidate)
   144  
   145  			if !field.Anonymous {
   146  				vs.walker.Exit()
   147  			}
   148  		}
   149  	default:
   150  		errMsg := marshalEnvValueAndValidate(rv, tpe, vs.getEnvKey(), tagValidate)
   151  		if errMsg != "" {
   152  			vs.setErrMsg(vs.walker.String(), errMsg)
   153  		}
   154  	}
   155  }
   156  
   157  func marshalEnvValueAndValidate(
   158  	rv reflect.Value,
   159  	tpe reflect.Type,
   160  	envKey string,
   161  	tagValidate string,
   162  ) string {
   163  	envValue, _ := os.LookupEnv(envKey)
   164  
   165  	isPtr := rv.Kind() == reflect.Ptr
   166  
   167  	if isPtr && rv.IsNil() {
   168  		// initial ptr
   169  		if rv.CanSet() {
   170  			rv.Set(reflect.New(reflectx.IndirectType(tpe)))
   171  		}
   172  	}
   173  
   174  	rv = reflectx.Indirect(rv)
   175  
   176  	if envValue != "" && rv.CanSet() {
   177  		err := strutil.ConvertFromStr(envValue, rv)
   178  		if err != nil {
   179  			return fmt.Sprintf("%s can't set wrong default value %s", rv.Type().Name(), envValue)
   180  		}
   181  	}
   182  
   183  	if tagValidate != "" {
   184  		isValid, msg := validate.ValidateItem(tagValidate, rv.Interface(), "")
   185  		if !isValid {
   186  			return msg
   187  		}
   188  	}
   189  
   190  	return ""
   191  }
   192  
   193  func getConfTagFlags(tagStr string) map[string]bool {
   194  	flagList := strings.Split(tagStr, ",")
   195  	flags := map[string]bool{}
   196  	for _, f := range flagList {
   197  		flags[f] = true
   198  	}
   199  	return flags
   200  }