trpc.group/trpc-go/trpc-go@v1.0.2/config/trpc_config.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package config
    15  
    16  import (
    17  	"encoding/json"
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  	"sync"
    22  
    23  	"github.com/BurntSushi/toml"
    24  	"github.com/spf13/cast"
    25  	yaml "gopkg.in/yaml.v3"
    26  
    27  	"trpc.group/trpc-go/trpc-go/log"
    28  )
    29  
    30  var (
    31  	// ErrConfigNotExist is config not exist error
    32  	ErrConfigNotExist = errors.New("trpc/config: config not exist")
    33  
    34  	// ErrProviderNotExist is provider not exist error
    35  	ErrProviderNotExist = errors.New("trpc/config: provider not exist")
    36  
    37  	// ErrCodecNotExist is codec not exist error
    38  	ErrCodecNotExist = errors.New("trpc/config: codec not exist")
    39  )
    40  
    41  func init() {
    42  	RegisterCodec(&YamlCodec{})
    43  	RegisterCodec(&JSONCodec{})
    44  	RegisterCodec(&TomlCodec{})
    45  }
    46  
    47  // LoadOption defines the option function for loading configuration.
    48  type LoadOption func(*TrpcConfig)
    49  
    50  // TrpcConfigLoader is a config loader for trpc.
    51  type TrpcConfigLoader struct {
    52  	configMap map[string]Config
    53  	rwl       sync.RWMutex
    54  }
    55  
    56  // Load returns the config specified by input parameter.
    57  func (loader *TrpcConfigLoader) Load(path string, opts ...LoadOption) (Config, error) {
    58  	yc := newTrpcConfig(path)
    59  	for _, o := range opts {
    60  		o(yc)
    61  	}
    62  	if yc.decoder == nil {
    63  		return nil, ErrCodecNotExist
    64  	}
    65  	if yc.p == nil {
    66  		return nil, ErrProviderNotExist
    67  	}
    68  
    69  	key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path)
    70  	loader.rwl.RLock()
    71  	if c, ok := loader.configMap[key]; ok {
    72  		loader.rwl.RUnlock()
    73  		return c, nil
    74  	}
    75  	loader.rwl.RUnlock()
    76  
    77  	if err := yc.Load(); err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	loader.rwl.Lock()
    82  	loader.configMap[key] = yc
    83  	loader.rwl.Unlock()
    84  
    85  	yc.p.Watch(func(p string, data []byte) {
    86  		if p == path {
    87  			loader.rwl.Lock()
    88  			delete(loader.configMap, key)
    89  			loader.rwl.Unlock()
    90  		}
    91  	})
    92  	return yc, nil
    93  }
    94  
    95  // Reload reloads config data.
    96  func (loader *TrpcConfigLoader) Reload(path string, opts ...LoadOption) error {
    97  	yc := newTrpcConfig(path)
    98  	for _, o := range opts {
    99  		o(yc)
   100  	}
   101  	key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path)
   102  	loader.rwl.RLock()
   103  	if config, ok := loader.configMap[key]; ok {
   104  		loader.rwl.RUnlock()
   105  		config.Reload()
   106  		return nil
   107  	}
   108  	loader.rwl.RUnlock()
   109  	return ErrConfigNotExist
   110  }
   111  
   112  func newTrpcConfigLoad() *TrpcConfigLoader {
   113  	return &TrpcConfigLoader{configMap: map[string]Config{}, rwl: sync.RWMutex{}}
   114  }
   115  
   116  // DefaultConfigLoader is the default config loader.
   117  var DefaultConfigLoader = newTrpcConfigLoad()
   118  
   119  // YamlCodec is yaml codec.
   120  type YamlCodec struct{}
   121  
   122  // Name returns yaml codec's name.
   123  func (*YamlCodec) Name() string {
   124  	return "yaml"
   125  }
   126  
   127  // Unmarshal deserializes the in bytes into out parameter by yaml.
   128  func (c *YamlCodec) Unmarshal(in []byte, out interface{}) error {
   129  	return yaml.Unmarshal(in, out)
   130  }
   131  
   132  // JSONCodec is json codec.
   133  type JSONCodec struct{}
   134  
   135  // Name returns json codec's name.
   136  func (*JSONCodec) Name() string {
   137  	return "json"
   138  }
   139  
   140  // Unmarshal deserializes the in bytes into out parameter by json.
   141  func (c *JSONCodec) Unmarshal(in []byte, out interface{}) error {
   142  	return json.Unmarshal(in, out)
   143  }
   144  
   145  // TomlCodec is toml codec.
   146  type TomlCodec struct{}
   147  
   148  // Name returns toml codec's name.
   149  func (*TomlCodec) Name() string {
   150  	return "toml"
   151  }
   152  
   153  // Unmarshal deserializes the in bytes into out parameter by toml.
   154  func (c *TomlCodec) Unmarshal(in []byte, out interface{}) error {
   155  	return toml.Unmarshal(in, out)
   156  }
   157  
   158  // TrpcConfig is used to parse yaml config file for trpc.
   159  type TrpcConfig struct {
   160  	p                DataProvider
   161  	unmarshalledData interface{}
   162  	path             string
   163  	decoder          Codec
   164  	rawData          []byte
   165  }
   166  
   167  func newTrpcConfig(path string) *TrpcConfig {
   168  	return &TrpcConfig{
   169  		p:                GetProvider("file"),
   170  		unmarshalledData: make(map[string]interface{}),
   171  		path:             path,
   172  		decoder:          &YamlCodec{},
   173  	}
   174  }
   175  
   176  // Unmarshal deserializes the config into input param.
   177  func (c *TrpcConfig) Unmarshal(out interface{}) error {
   178  	return c.decoder.Unmarshal(c.rawData, out)
   179  }
   180  
   181  // Load loads config.
   182  func (c *TrpcConfig) Load() error {
   183  	if c.p == nil {
   184  		return ErrProviderNotExist
   185  	}
   186  
   187  	data, err := c.p.Read(c.path)
   188  	if err != nil {
   189  		return fmt.Errorf("trpc/config: failed to load %s: %s", c.path, err.Error())
   190  	}
   191  
   192  	c.rawData = data
   193  	if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil {
   194  		return fmt.Errorf("trpc/config: failed to parse %s: %s", c.path, err.Error())
   195  	}
   196  	return nil
   197  }
   198  
   199  // Reload reloads config.
   200  func (c *TrpcConfig) Reload() {
   201  	if c.p == nil {
   202  		return
   203  	}
   204  
   205  	data, err := c.p.Read(c.path)
   206  	if err != nil {
   207  		log.Tracef("trpc/config: failed to reload %s: %v", c.path, err)
   208  		return
   209  	}
   210  
   211  	c.rawData = data
   212  	if err := c.decoder.Unmarshal(data, &c.unmarshalledData); err != nil {
   213  		log.Tracef("trpc/config: failed to parse %s: %v", c.path, err)
   214  		return
   215  	}
   216  }
   217  
   218  // Get returns config value by key. If key is absent will return the default value.
   219  func (c *TrpcConfig) Get(key string, defaultValue interface{}) interface{} {
   220  	if v, ok := c.search(key); ok {
   221  		return v
   222  	}
   223  	return defaultValue
   224  }
   225  
   226  // Bytes returns original config data as bytes.
   227  func (c *TrpcConfig) Bytes() []byte {
   228  	return c.rawData
   229  }
   230  
   231  // GetInt returns int value by key, the second parameter
   232  // is default value when key is absent or type conversion fails.
   233  func (c *TrpcConfig) GetInt(key string, defaultValue int) int {
   234  	return c.findWithDefaultValue(key, defaultValue).(int)
   235  }
   236  
   237  // GetInt32 returns int32 value by key, the second parameter
   238  // is default value when key is absent or type conversion fails.
   239  func (c *TrpcConfig) GetInt32(key string, defaultValue int32) int32 {
   240  	return c.findWithDefaultValue(key, defaultValue).(int32)
   241  }
   242  
   243  // GetInt64 returns int64 value by key, the second parameter
   244  // is default value when key is absent or type conversion fails.
   245  func (c *TrpcConfig) GetInt64(key string, defaultValue int64) int64 {
   246  	return c.findWithDefaultValue(key, defaultValue).(int64)
   247  }
   248  
   249  // GetUint returns uint value by key, the second parameter
   250  // is default value when key is absent or type conversion fails.
   251  func (c *TrpcConfig) GetUint(key string, defaultValue uint) uint {
   252  	return c.findWithDefaultValue(key, defaultValue).(uint)
   253  }
   254  
   255  // GetUint32 returns uint32 value by key, the second parameter
   256  // is default value when key is absent or type conversion fails.
   257  func (c *TrpcConfig) GetUint32(key string, defaultValue uint32) uint32 {
   258  	return c.findWithDefaultValue(key, defaultValue).(uint32)
   259  }
   260  
   261  // GetUint64 returns uint64 value by key, the second parameter
   262  // is default value when key is absent or type conversion fails.
   263  func (c *TrpcConfig) GetUint64(key string, defaultValue uint64) uint64 {
   264  	return c.findWithDefaultValue(key, defaultValue).(uint64)
   265  }
   266  
   267  // GetFloat64 returns float64 value by key, the second parameter
   268  // is default value when key is absent or type conversion fails.
   269  func (c *TrpcConfig) GetFloat64(key string, defaultValue float64) float64 {
   270  	return c.findWithDefaultValue(key, defaultValue).(float64)
   271  }
   272  
   273  // GetFloat32 returns float32 value by key, the second parameter
   274  // is default value when key is absent or type conversion fails.
   275  func (c *TrpcConfig) GetFloat32(key string, defaultValue float32) float32 {
   276  	return c.findWithDefaultValue(key, defaultValue).(float32)
   277  }
   278  
   279  // GetBool returns bool value by key, the second parameter
   280  // is default value when key is absent or type conversion fails.
   281  func (c *TrpcConfig) GetBool(key string, defaultValue bool) bool {
   282  	return c.findWithDefaultValue(key, defaultValue).(bool)
   283  }
   284  
   285  // GetString returns string value by key, the second parameter
   286  // is default value when key is absent or type conversion fails.
   287  func (c *TrpcConfig) GetString(key string, defaultValue string) string {
   288  	return c.findWithDefaultValue(key, defaultValue).(string)
   289  }
   290  
   291  // IsSet returns if the config specified by key exists.
   292  func (c *TrpcConfig) IsSet(key string) bool {
   293  	_, ok := c.search(key)
   294  	return ok
   295  }
   296  
   297  // findWithDefaultValue ensures that the type of `value` is same as `defaultValue`
   298  func (c *TrpcConfig) findWithDefaultValue(key string, defaultValue interface{}) (value interface{}) {
   299  	v, ok := c.search(key)
   300  	if !ok {
   301  		return defaultValue
   302  	}
   303  
   304  	var err error
   305  	switch defaultValue.(type) {
   306  	case bool:
   307  		v, err = cast.ToBoolE(v)
   308  	case string:
   309  		v, err = cast.ToStringE(v)
   310  	case int:
   311  		v, err = cast.ToIntE(v)
   312  	case int32:
   313  		v, err = cast.ToInt32E(v)
   314  	case int64:
   315  		v, err = cast.ToInt64E(v)
   316  	case uint:
   317  		v, err = cast.ToUintE(v)
   318  	case uint32:
   319  		v, err = cast.ToUint32E(v)
   320  	case uint64:
   321  		v, err = cast.ToUint64E(v)
   322  	case float64:
   323  		v, err = cast.ToFloat64E(v)
   324  	case float32:
   325  		v, err = cast.ToFloat32E(v)
   326  	default:
   327  	}
   328  
   329  	if err != nil {
   330  		return defaultValue
   331  	}
   332  	return v
   333  }
   334  
   335  func (c *TrpcConfig) search(key string) (interface{}, bool) {
   336  	unmarshalledData, ok := c.unmarshalledData.(map[string]interface{})
   337  	if !ok {
   338  		return nil, false
   339  	}
   340  
   341  	subkeys := strings.Split(key, ".")
   342  	value, err := search(unmarshalledData, subkeys)
   343  	if err != nil {
   344  		log.Debugf("trpc config: search key %s failed: %+v", key, err)
   345  		return value, false
   346  	}
   347  
   348  	return value, true
   349  }
   350  
   351  func search(unmarshalledData map[string]interface{}, keys []string) (interface{}, error) {
   352  	if len(keys) == 0 {
   353  		return nil, ErrConfigNotExist
   354  	}
   355  
   356  	key, ok := unmarshalledData[keys[0]]
   357  	if !ok {
   358  		return nil, ErrConfigNotExist
   359  	}
   360  
   361  	if len(keys) == 1 {
   362  		return key, nil
   363  	}
   364  	switch key := key.(type) {
   365  	case map[interface{}]interface{}:
   366  		return search(cast.ToStringMap(key), keys[1:])
   367  	case map[string]interface{}:
   368  		return search(key, keys[1:])
   369  	default:
   370  		return nil, ErrConfigNotExist
   371  	}
   372  }