trpc.group/trpc-go/trpc-go@v1.0.3/config/trpc_config_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package config
    15  
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"trpc.group/trpc-go/trpc-go/errs"
	"trpc.group/trpc-go/trpc-go/log"
)
    30  
    31  func Test_search(t *testing.T) {
    32  	type args struct {
    33  		unmarshalledData map[string]interface{}
    34  		keys             []string
    35  	}
    36  	tests := []struct {
    37  		name    string
    38  		args    args
    39  		want    interface{}
    40  		wantErr assert.ErrorAssertionFunc
    41  	}{
    42  		{
    43  			name: "empty keys",
    44  			args: args{
    45  				keys: nil,
    46  			},
    47  			want: nil,
    48  			wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
    49  				if !errors.Is(err, ErrConfigNotExist) {
    50  					t.Errorf("received unexpected error got: %+v, want: +%v", err, ErrCodecNotExist)
    51  					return false
    52  				}
    53  				return true
    54  			},
    55  		},
    56  		{
    57  			name: "key doesn't match",
    58  			args: args{
    59  				unmarshalledData: map[string]interface{}{
    60  					"1": []string{"x", "y"},
    61  				},
    62  				keys: []string{"not-1"},
    63  			},
    64  			want: nil,
    65  			wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
    66  				if !errors.Is(err, ErrConfigNotExist) {
    67  					t.Errorf("received unexpected error got: %+v, want: +%v", err, ErrCodecNotExist)
    68  					return false
    69  				}
    70  				return true
    71  			},
    72  		},
    73  		{
    74  			name: "value of unmarshalledData isn't map type",
    75  			args: args{
    76  				unmarshalledData: map[string]interface{}{
    77  					"1": []string{"x", "y"},
    78  				},
    79  				keys: []string{"1", "2"},
    80  			},
    81  			want: nil,
    82  			wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
    83  				if !errors.Is(err, ErrConfigNotExist) {
    84  					t.Errorf("received unexpected error got: %+v, want: +%v", err, ErrCodecNotExist)
    85  					return false
    86  				}
    87  				return true
    88  			},
    89  		},
    90  		{
    91  			name: "value of unmarshalledData is map[interface{}]interface{} type",
    92  			args: args{
    93  				unmarshalledData: map[string]interface{}{
    94  					"1": map[interface{}]interface{}{"x": "y"},
    95  				},
    96  				keys: []string{"1", "x"},
    97  			},
    98  			want:    "y",
    99  			wantErr: assert.NoError,
   100  		},
   101  		{
   102  			name: "value of unmarshalledData is map[string]interface{} type",
   103  			args: args{
   104  				unmarshalledData: map[string]interface{}{
   105  					"1": map[string]interface{}{"x": "y"},
   106  				},
   107  				keys: []string{"1", "x"},
   108  			},
   109  			want:    "y",
   110  			wantErr: assert.NoError,
   111  		},
   112  	}
   113  	for _, tt := range tests {
   114  		t.Run(tt.name, func(t *testing.T) {
   115  			got, err := search(tt.args.unmarshalledData, tt.args.keys)
   116  			if !tt.wantErr(t, err, fmt.Sprintf("search(%v, %v)", tt.args.unmarshalledData, tt.args.keys)) {
   117  				return
   118  			}
   119  			assert.Equalf(t, tt.want, got, "search(%v, %v)", tt.args.unmarshalledData, tt.args.keys)
   120  		})
   121  	}
   122  }
   123  
   124  func TestTrpcConfig_Load(t *testing.T) {
   125  	t.Run("parse failed", func(t *testing.T) {
   126  		c, err := newTrpcConfig("../testdata/trpc_go.yaml")
   127  		require.Nil(t, err)
   128  		c.decoder = &TomlCodec{}
   129  		err = c.Load()
   130  		require.Contains(t, errs.Msg(err), "failed to parse")
   131  	})
   132  }
   133  func TestYamlCodec_Unmarshal(t *testing.T) {
   134  	t.Run("interface", func(t *testing.T) {
   135  		var tt interface{}
   136  		tt = map[string]interface{}{}
   137  		require.Nil(t, GetCodec("yaml").Unmarshal([]byte("[1, 2]"), &tt))
   138  	})
   139  	t.Run("map[string]interface{}", func(t *testing.T) {
   140  		tt := map[string]interface{}{}
   141  		require.NotNil(t, GetCodec("yaml").Unmarshal([]byte("[1, 2]"), &tt))
   142  	})
   143  }
   144  
   145  func TestEnvExpanded(t *testing.T) {
   146  	RegisterProvider(NewEnvProvider(t.Name(), []byte(`
   147  password: ${pwd}
   148  `)))
   149  
   150  	t.Setenv("pwd", t.Name())
   151  	cfg, err := DefaultConfigLoader.Load(
   152  		t.Name(),
   153  		WithProvider(t.Name()),
   154  		WithExpandEnv())
   155  	require.Nil(t, err)
   156  
   157  	require.Equal(t, t.Name(), cfg.GetString("password", ""))
   158  	require.Contains(t, string(cfg.Bytes()), fmt.Sprintf("password: %s", t.Name()))
   159  }
   160  
   161  func TestCodecUnmarshalDstMustBeMap(t *testing.T) {
   162  	filePath := t.TempDir() + "/conf.map"
   163  	require.Nil(t, os.WriteFile(filePath, []byte{}, 0600))
   164  	RegisterCodec(dstMustBeMapCodec{})
   165  	_, err := DefaultConfigLoader.Load(filePath, WithCodec(dstMustBeMapCodec{}.Name()))
   166  	require.Nil(t, err)
   167  }
   168  
   169  func NewEnvProvider(name string, data []byte) *EnvProvider {
   170  	return &EnvProvider{
   171  		name: name,
   172  		data: data,
   173  	}
   174  }
   175  
   176  type EnvProvider struct {
   177  	name string
   178  	data []byte
   179  }
   180  
   181  func (ep *EnvProvider) Name() string {
   182  	return ep.name
   183  }
   184  
   185  func (ep *EnvProvider) Read(string) ([]byte, error) {
   186  	return ep.data, nil
   187  }
   188  
   189  func (ep *EnvProvider) Watch(cb ProviderCallback) {
   190  	cb("", ep.data)
   191  }
   192  
   193  func TestWatch(t *testing.T) {
   194  	p := manualTriggerWatchProvider{}
   195  	var msgs = make(chan WatchMessage)
   196  	SetDefaultWatchHook(func(msg WatchMessage) {
   197  		if msg.Error != nil {
   198  			log.Errorf("config watch error: %+v", msg)
   199  		} else {
   200  			log.Infof("config watch error: %+v", msg)
   201  		}
   202  		msgs <- msg
   203  	})
   204  
   205  	RegisterProvider(&p)
   206  	p.Set("key", []byte(`key: value`))
   207  	ops := []LoadOption{WithProvider(p.Name()), WithCodec("yaml"), WithWatch()}
   208  	c1, err := DefaultConfigLoader.Load("key", ops...)
   209  	require.Nilf(t, err, "first load config:%+v", c1)
   210  	require.True(t, c1.IsSet("key"), "first load config key exist")
   211  	require.Equal(t, c1.Get("key", "default"), "value", "first load config get key value")
   212  
   213  	var c2 Config
   214  	c2, err = DefaultConfigLoader.Load("key", ops...)
   215  	require.Nil(t, err, "second load config:%+v", c2)
   216  	require.Equal(t, c1, c2, "first and second load config not equal")
   217  	require.True(t, c2.IsSet("key"), "second load config key exist")
   218  	require.Equal(t, c2.Get("key", "default"), "value", "second load config get key value")
   219  
   220  	var gw sync.WaitGroup
   221  	gw.Add(1)
   222  	go func() {
   223  		defer gw.Done()
   224  		tt := time.NewTimer(time.Second)
   225  		select {
   226  		case <-msgs:
   227  		case <-tt.C:
   228  			t.Errorf("receive message timeout")
   229  		}
   230  	}()
   231  
   232  	p.Set("key", []byte(`:key: value:`))
   233  	gw.Wait()
   234  
   235  	var c3 Config
   236  	c3, err = DefaultConfigLoader.Load("key", WithProvider(p.Name()), WithWatchHook(func(msg WatchMessage) {
   237  		msgs <- msg
   238  	}))
   239  	require.Contains(t, errs.Msg(err), "failed to parse")
   240  	require.Nil(t, c3, "update error")
   241  
   242  	require.True(t, c2.IsSet("key"), "third load config key exist")
   243  	require.Equal(t, c2.Get("key", "default"), "value", "third load config get key value")
   244  
   245  	gw.Add(1)
   246  	go func() {
   247  		defer gw.Done()
   248  		for i := 0; i < 2; i++ {
   249  			tt := time.NewTimer(time.Second)
   250  			select {
   251  			case <-msgs:
   252  			case <-tt.C:
   253  				t.Errorf("receive message timeout number%d ", i)
   254  			}
   255  		}
   256  	}()
   257  	p.Set("key", []byte(`key: value2`))
   258  	gw.Wait()
   259  
   260  	require.Truef(t, c2.IsSet("key"), "after update config and get key exist")
   261  	require.Equal(t, c2.Get("key", "default"), "value2", "after update config and config get value")
   262  }
   263  
   264  var _ DataProvider = (*manualTriggerWatchProvider)(nil)
   265  
   266  type manualTriggerWatchProvider struct {
   267  	values    sync.Map
   268  	callbacks []ProviderCallback
   269  }
   270  
   271  func (m *manualTriggerWatchProvider) Name() string {
   272  	return "manual_trigger_watch_provider"
   273  }
   274  
   275  func (m *manualTriggerWatchProvider) Read(s string) ([]byte, error) {
   276  	if v, ok := m.values.Load(s); ok {
   277  		return v.([]byte), nil
   278  	}
   279  	return nil, fmt.Errorf("not found config")
   280  }
   281  
   282  func (m *manualTriggerWatchProvider) Watch(callback ProviderCallback) {
   283  	m.callbacks = append(m.callbacks, callback)
   284  }
   285  
   286  func (m *manualTriggerWatchProvider) Set(key string, v []byte) {
   287  	m.values.Store(key, v)
   288  	for _, callback := range m.callbacks {
   289  		callback(key, v)
   290  	}
   291  }
   292  
   293  type dstMustBeMapCodec struct{}
   294  
   295  func (c dstMustBeMapCodec) Name() string {
   296  	return "map"
   297  }
   298  
   299  func (c dstMustBeMapCodec) Unmarshal(bts []byte, dst interface{}) error {
   300  	rv := reflect.ValueOf(dst)
   301  	if rv.Kind() != reflect.Ptr ||
   302  		rv.Elem().Kind() != reflect.Interface ||
   303  		rv.Elem().Elem().Kind() != reflect.Map ||
   304  		rv.Elem().Elem().Type().Key().Kind() != reflect.String ||
   305  		rv.Elem().Elem().Type().Elem().Kind() != reflect.Interface {
   306  		return errors.New("the dst of codec.Unmarshal must be a map")
   307  	}
   308  	return nil
   309  }