github.com/yandex/pandora@v0.5.32/core/plugin/registry_test.go (about)

     1  package plugin
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"testing"
     7  
     8  	"github.com/mitchellh/mapstructure"
     9  	"github.com/pkg/errors"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestNewDefaultConfigContainerExpectationFail(t *testing.T) {
    15  	tests := []struct {
    16  		name                     string
    17  		constructor              any
    18  		newDefaultConfigOptional []any
    19  	}{
    20  		{
    21  			name:        "invalid type",
    22  			constructor: func(int) ptestPlugin { return nil },
    23  		},
    24  		{
    25  			name:        "invalid ptr type",
    26  			constructor: func(*int) ptestPlugin { return nil },
    27  		},
    28  		{
    29  			name:        "to many args",
    30  			constructor: func(_, _ ptestConfig) ptestPlugin { return nil },
    31  		},
    32  		{
    33  			name:                     "default without config",
    34  			constructor:              func() ptestPlugin { return nil },
    35  			newDefaultConfigOptional: []any{func() *ptestConfig { return nil }}},
    36  		{
    37  			name:                     "invalid default config",
    38  			constructor:              func(ptestConfig) ptestPlugin { return nil },
    39  			newDefaultConfigOptional: []any{func() *ptestConfig { return nil }}},
    40  		{
    41  			name:                     "default config accepts args",
    42  			constructor:              func(*ptestConfig) ptestPlugin { return nil },
    43  			newDefaultConfigOptional: []any{func(int) *ptestConfig { return nil }},
    44  		},
    45  	}
    46  	for _, tt := range tests {
    47  		t.Run(tt.name, func(t *testing.T) {
    48  			newDefaultConfig := getNewDefaultConfig(tt.newDefaultConfigOptional)
    49  			defer recoverExpectationFail(t)
    50  			newDefaultConfigContainer(reflect.TypeOf(tt.constructor), newDefaultConfig)
    51  		})
    52  	}
    53  }
    54  
    55  func TestNewDefaultConfigContainerExpectationOk(t *testing.T) {
    56  	tests := []struct {
    57  		name                     string
    58  		constructor              any
    59  		newDefaultConfigOptional []any
    60  	}{
    61  
    62  		{
    63  			name:        "no default config",
    64  			constructor: ptestNewConf},
    65  		{
    66  			name:        "no default ptr config",
    67  			constructor: ptestNewPtrConf},
    68  		{
    69  			name:                     "default config",
    70  			constructor:              ptestNewConf,
    71  			newDefaultConfigOptional: []any{ptestDefaultConf}},
    72  		{
    73  			name:                     "default ptr config",
    74  			constructor:              ptestNewPtrConf,
    75  			newDefaultConfigOptional: []any{ptestNewDefaultPtrConf},
    76  		},
    77  	}
    78  	for _, tt := range tests {
    79  		t.Run(tt.name, func(t *testing.T) {
    80  			newDefaultConfig := getNewDefaultConfig(tt.newDefaultConfigOptional)
    81  			container := newDefaultConfigContainer(reflect.TypeOf(tt.constructor), newDefaultConfig)
    82  			conf, err := container.Get(ptestFillConf)
    83  			assert.NoError(t, err)
    84  			assert.Len(t, conf, 1)
    85  			ptestExpectConfigValue(t, conf[0].Interface(), ptestFilledValue)
    86  		})
    87  	}
    88  }
    89  
    90  // new default config container fill no config failed
    91  func TestNewDefault(t *testing.T) {
    92  	container := newDefaultConfigContainer(ptestNewErrType(), nil)
    93  	_, err := container.Get(ptestFillConf)
    94  	assert.Error(t, err)
    95  }
    96  
    97  func TestRegistry(t *testing.T) {
    98  	t.Run("register name collision panics", func(t *testing.T) {
    99  		r := NewRegistry()
   100  		r.ptestRegister(ptestNewImpl)
   101  		defer recoverExpectationFail(t)
   102  		r.ptestRegister(ptestNewImpl)
   103  	})
   104  
   105  	t.Run("lookup", func(t *testing.T) {
   106  		r := NewRegistry()
   107  		r.ptestRegister(ptestNewImpl)
   108  		assert.True(t, r.Lookup(ptestType()))
   109  		assert.False(t, r.Lookup(reflect.TypeOf(0)))
   110  		assert.False(t, r.Lookup(reflect.TypeOf(&ptestImpl{})))
   111  		assert.False(t, r.Lookup(reflect.TypeOf((*io.Writer)(nil)).Elem()))
   112  	})
   113  
   114  	t.Run("lookup factory", func(t *testing.T) {
   115  		r := NewRegistry()
   116  		r.ptestRegister(ptestNewImpl)
   117  		assert.True(t, r.LookupFactory(ptestNewType()))
   118  		assert.True(t, r.LookupFactory(ptestNewErrType()))
   119  
   120  		assert.False(t, r.LookupFactory(reflect.TypeOf(0)))
   121  		assert.False(t, r.LookupFactory(reflect.TypeOf(&ptestImpl{})))
   122  		assert.False(t, r.LookupFactory(reflect.TypeOf((*io.Writer)(nil)).Elem()))
   123  	})
   124  }
   125  
   126  func TestNew(t *testing.T) {
   127  	type New func(r *Registry, fillConfOptional ...func(conf interface{}) error) (interface{}, error)
   128  
   129  	testNewOk := func(t *testing.T, r *Registry, testNew New, fillConfOptional ...func(conf interface{}) error) (pluginVal string) {
   130  		plugin, err := testNew(r, fillConfOptional...)
   131  		require.NoError(t, err)
   132  		return plugin.(*ptestImpl).Value
   133  	}
   134  
   135  	tests := []struct {
   136  		name string
   137  
   138  		assert func(t *testing.T, r *Registry, testNew New)
   139  	}{
   140  		{
   141  			name: "plugin constructor. no conf",
   142  			assert: func(t *testing.T, r *Registry, testNew New) {
   143  				r.ptestRegister(ptestNewImpl)
   144  				got := testNewOk(t, r, testNew)
   145  				assert.Equal(t, ptestInitValue, got)
   146  			},
   147  		},
   148  		{
   149  			name: "plugin conf: nil error",
   150  			assert: func(t *testing.T, r *Registry, testNew New) {
   151  				r.ptestRegister(ptestNewErr)
   152  				got := testNewOk(t, r, testNew)
   153  				assert.Equal(t, ptestInitValue, got)
   154  			},
   155  		},
   156  		{
   157  			name: "plugin conf: non-nil error",
   158  			assert: func(t *testing.T, r *Registry, testNew New) {
   159  				r.ptestRegister(ptestNewErrFailing)
   160  				_, err := testNew(r)
   161  				assert.Error(t, err)
   162  				assert.ErrorIs(t, err, ptestCreateFailedErr)
   163  			},
   164  		},
   165  		{
   166  			name: "plugin conf: no conf, fill conf error",
   167  			assert: func(t *testing.T, r *Registry, testNew New) {
   168  				r.ptestRegister(ptestNewImpl)
   169  				expectedErr := errors.New("fill conf err")
   170  				_, err := testNew(r, func(_ interface{}) error { return expectedErr })
   171  				assert.ErrorIs(t, err, expectedErr)
   172  			},
   173  		},
   174  		{
   175  			name: "plugin conf: no default",
   176  			assert: func(t *testing.T, r *Registry, testNew New) {
   177  				r.ptestRegister(ptestNewConf)
   178  				got := testNewOk(t, r, testNew)
   179  				assert.Equal(t, "", got)
   180  			},
   181  		},
   182  		{
   183  			name: "plugin conf: default",
   184  			assert: func(t *testing.T, r *Registry, testNew New) {
   185  				r.ptestRegister(ptestNewConf, ptestDefaultConf)
   186  				got := testNewOk(t, r, testNew)
   187  				assert.Equal(t, ptestDefaultValue, got)
   188  			},
   189  		},
   190  		{
   191  			name: "plugin conf: fill conf default",
   192  			assert: func(t *testing.T, r *Registry, testNew New) {
   193  				r.ptestRegister(ptestNewConf, ptestDefaultConf)
   194  				got := testNewOk(t, r, testNew, ptestFillConf)
   195  				assert.Equal(t, ptestFilledValue, got)
   196  			},
   197  		},
   198  		{
   199  			name: "plugin conf: fill conf no default",
   200  			assert: func(t *testing.T, r *Registry, testNew New) {
   201  				r.ptestRegister(ptestNewConf)
   202  				got := testNewOk(t, r, testNew, ptestFillConf)
   203  				assert.Equal(t, ptestFilledValue, got)
   204  			},
   205  		},
   206  		{
   207  			name: "plugin conf: fill ptr conf no default",
   208  			assert: func(t *testing.T, r *Registry, testNew New) {
   209  				r.ptestRegister(ptestNewPtrConf)
   210  				got := testNewOk(t, r, testNew, ptestFillConf)
   211  				assert.Equal(t, ptestFilledValue, got)
   212  			},
   213  		},
   214  		{
   215  			name: "plugin conf: no default ptr conf not nil",
   216  			assert: func(t *testing.T, r *Registry, testNew New) {
   217  				r.ptestRegister(ptestNewPtrConf)
   218  				got := testNewOk(t, r, testNew)
   219  				assert.Equal(t, "", got)
   220  			},
   221  		},
   222  		{
   223  			name: "plugin conf: nil default, conf not nil",
   224  			assert: func(t *testing.T, r *Registry, testNew New) {
   225  				r.ptestRegister(ptestNewPtrConf, func() *ptestConfig { return nil })
   226  				got := testNewOk(t, r, testNew)
   227  				assert.Equal(t, "", got)
   228  			},
   229  		},
   230  		{
   231  			name: "plugin conf: fill nil default",
   232  			assert: func(t *testing.T, r *Registry, testNew New) {
   233  				r.ptestRegister(ptestNewPtrConf, func() *ptestConfig { return nil })
   234  				got := testNewOk(t, r, testNew, ptestFillConf)
   235  				assert.Equal(t, ptestFilledValue, got)
   236  			},
   237  		},
   238  		{
   239  			name: "plugin conf: more than one fill conf panics",
   240  			assert: func(t *testing.T, r *Registry, testNew New) {
   241  				r.ptestRegister(ptestNewPtrConf)
   242  				defer recoverExpectationFail(t)
   243  				testNew(r, ptestFillConf, ptestFillConf)
   244  			},
   245  		},
   246  		{
   247  			name: "factory constructor; no conf",
   248  			assert: func(t *testing.T, r *Registry, testNew New) {
   249  				r.ptestRegister(ptestNewFactory)
   250  				got := testNewOk(t, r, testNew)
   251  				assert.Equal(t, ptestInitValue, got)
   252  			},
   253  		},
   254  		{
   255  			name: "factory constructor; nil error",
   256  			assert: func(t *testing.T, r *Registry, testNew New) {
   257  				r.ptestRegister(func() (ptestPlugin, error) {
   258  					return ptestNewImpl(), nil
   259  				})
   260  				got := testNewOk(t, r, testNew)
   261  				assert.Equal(t, ptestInitValue, got)
   262  			},
   263  		},
   264  		{
   265  			name: "factory constructor; non-nil error",
   266  			assert: func(t *testing.T, r *Registry, testNew New) {
   267  				r.ptestRegister(ptestNewFactoryFactoryErrFailing)
   268  				_, err := testNew(r)
   269  				assert.Error(t, err)
   270  				assert.ErrorIs(t, err, ptestCreateFailedErr)
   271  			},
   272  		},
   273  		{
   274  			name: "factory constructor; no conf, fill conf error",
   275  			assert: func(t *testing.T, r *Registry, testNew New) {
   276  				r.ptestRegister(ptestNewFactory)
   277  				expectedErr := errors.New("fill conf err")
   278  				_, err := testNew(r, func(_ interface{}) error { return expectedErr })
   279  				assert.ErrorIs(t, err, expectedErr)
   280  			},
   281  		},
   282  		{
   283  			name: "factory constructor; no default",
   284  			assert: func(t *testing.T, r *Registry, testNew New) {
   285  				r.ptestRegister(ptestNewFactoryConf)
   286  				got := testNewOk(t, r, testNew)
   287  				assert.Equal(t, "", got)
   288  			},
   289  		},
   290  		{
   291  			name: "factory constructor; default",
   292  			assert: func(t *testing.T, r *Registry, testNew New) {
   293  				r.ptestRegister(ptestNewFactoryConf, ptestDefaultConf)
   294  				got := testNewOk(t, r, testNew)
   295  				assert.Equal(t, ptestDefaultValue, got)
   296  			},
   297  		},
   298  		{
   299  			name: "factory constructor; fill conf default",
   300  			assert: func(t *testing.T, r *Registry, testNew New) {
   301  				r.ptestRegister(ptestNewFactoryConf, ptestDefaultConf)
   302  				got := testNewOk(t, r, testNew, ptestFillConf)
   303  				assert.Equal(t, ptestFilledValue, got)
   304  			},
   305  		},
   306  		{
   307  			name: "factory constructor; fill conf no default",
   308  			assert: func(t *testing.T, r *Registry, testNew New) {
   309  				r.ptestRegister(ptestNewFactoryConf)
   310  				got := testNewOk(t, r, testNew, ptestFillConf)
   311  				assert.Equal(t, ptestFilledValue, got)
   312  			},
   313  		},
   314  		{
   315  			name: "factory constructor; fill ptr conf no default",
   316  			assert: func(t *testing.T, r *Registry, testNew New) {
   317  				r.ptestRegister(ptestNewFactoryPtrConf)
   318  				got := testNewOk(t, r, testNew, ptestFillConf)
   319  				assert.Equal(t, ptestFilledValue, got)
   320  			},
   321  		},
   322  		{
   323  			name: "factory constructor; no default ptr conf not nil",
   324  			assert: func(t *testing.T, r *Registry, testNew New) {
   325  				r.ptestRegister(ptestNewFactoryPtrConf)
   326  				got := testNewOk(t, r, testNew)
   327  				assert.Equal(t, "", got)
   328  			},
   329  		},
   330  		{
   331  			name: "factory constructor; nil default, conf not nil",
   332  			assert: func(t *testing.T, r *Registry, testNew New) {
   333  				r.ptestRegister(ptestNewFactoryPtrConf, func() *ptestConfig { return nil })
   334  				got := testNewOk(t, r, testNew)
   335  				assert.Equal(t, "", got)
   336  			},
   337  		},
   338  		{
   339  			name: "factory constructor; fill nil default",
   340  			assert: func(t *testing.T, r *Registry, testNew New) {
   341  				r.ptestRegister(ptestNewFactoryPtrConf, func() *ptestConfig { return nil })
   342  				got := testNewOk(t, r, testNew, ptestFillConf)
   343  				assert.Equal(t, ptestFilledValue, got)
   344  			},
   345  		},
   346  		{
   347  			name: "factory constructor; more than one fill conf panics",
   348  			assert: func(t *testing.T, r *Registry, testNew New) {
   349  				r.ptestRegister(ptestNewFactoryPtrConf)
   350  				defer recoverExpectationFail(t)
   351  				testNew(r, ptestFillConf, ptestFillConf)
   352  			},
   353  		},
   354  	}
   355  	for _, tt := range tests {
   356  		t.Run(tt.name, func(t *testing.T) {
   357  			r := NewRegistry()
   358  			testNew := (*Registry).ptestNew
   359  			tt.assert(t, r, testNew)
   360  
   361  			r = NewRegistry()
   362  			testNew = (*Registry).ptestNewFactory
   363  			tt.assert(t, r, testNew)
   364  		})
   365  	}
   366  }
   367  
   368  func TestDecode(t *testing.T) {
   369  	r := NewRegistry()
   370  	const nameKey = "type"
   371  
   372  	var hook mapstructure.DecodeHookFunc
   373  	decode := func(input, result interface{}) error {
   374  		decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
   375  			DecodeHook:  hook,
   376  			ErrorUnused: true,
   377  			Result:      result,
   378  		})
   379  		if err != nil {
   380  			return err
   381  		}
   382  		return decoder.Decode(input)
   383  	}
   384  	hook = mapstructure.ComposeDecodeHookFunc(
   385  		func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) {
   386  			if !r.Lookup(to) {
   387  				return data, nil
   388  			}
   389  			// NOTE: could be map[interface{}]interface{} here.
   390  			input := data.(map[string]interface{})
   391  			// NOTE: should be case insensitive behaviour.
   392  			pluginName := input[nameKey].(string)
   393  			delete(input, nameKey)
   394  			return r.New(to, pluginName, func(conf interface{}) error {
   395  				// NOTE: should error, if conf has "type" field.
   396  				return decode(input, conf)
   397  			})
   398  		})
   399  
   400  	r.Register(ptestType(), "my-plugin", ptestNewConf, ptestDefaultConf)
   401  	input := map[string]interface{}{
   402  		"plugin": map[string]interface{}{
   403  			nameKey: "my-plugin",
   404  			"value": ptestFilledValue,
   405  		},
   406  	}
   407  	type Config struct {
   408  		Plugin ptestPlugin
   409  	}
   410  	var conf Config
   411  	err := decode(input, &conf)
   412  	assert.NoError(t, err)
   413  	actualValue := conf.Plugin.(*ptestImpl).Value
   414  	assert.Equal(t, ptestFilledValue, actualValue)
   415  }
   416  
   417  func recoverExpectationFail(t *testing.T) {
   418  	r := recover()
   419  	assert.NotNil(t, r)
   420  	assert.Contains(t, r, "expectation failed")
   421  }