github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/plugin/plugin.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package plugin
    15  
import (
	"context"
	"path/filepath"
	gplugin "plugin"
	"strconv"
	"strings"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/whtcorpsinc/errors"
	"github.com/whtcorpsinc/milevadb/petri"
	"github.com/whtcorpsinc/milevadb/soliton"
	"github.com/whtcorpsinc/milevadb/soliton/logutil"
	"github.com/whtcorpsinc/milevadb/stochastikctx/variable"
	"go.etcd.io/etcd/clientv3"
	"go.uber.org/zap"
)
    33  
// pluginGlobal holds the process-wide plugin registry. Its tiPlugins pointer
// is set by Load, read through copyOnWriteContext.plugins (atomic.LoadPointer)
// and cleared by Shutdown with atomic.CompareAndSwapPointer.
var pluginGlobal copyOnWriteContext
    36  
// copyOnWriteContext keeps the loaded plugin set behind a single pointer so the
// whole set can be swapped atomically, following a copy-on-write idiom
// (plugins.clone produces copies of the maps and slices). Note that Init and
// Shutdown still update per-plugin State fields in place.
type copyOnWriteContext struct {
	tiPlugins unsafe.Pointer // *plugins; always access via sync/atomic
}
    41  
// plugins collects the information about loaded plugins.
type plugins struct {
	// plugins groups the loaded plugins by HoTT, in the order add was called.
	plugins      map[HoTT][]Plugin
	// versions maps plugin names to their versions; Load also seeds it with the
	// environment component versions from Config.EnvVersion.
	versions     map[string]uint16
	// dyingPlugins is copied by clone; nothing else in this file reads it.
	dyingPlugins []Plugin
}
    48  
    49  // clone deep copies plugins info.
    50  func (p *plugins) clone() *plugins {
    51  	np := &plugins{
    52  		plugins:      make(map[HoTT][]Plugin, len(p.plugins)),
    53  		versions:     make(map[string]uint16, len(p.versions)),
    54  		dyingPlugins: make([]Plugin, len(p.dyingPlugins)),
    55  	}
    56  	for key, value := range p.plugins {
    57  		np.plugins[key] = append([]Plugin(nil), value...)
    58  	}
    59  	for key, value := range p.versions {
    60  		np.versions[key] = value
    61  	}
    62  	copy(np.dyingPlugins, p.dyingPlugins)
    63  	return np
    64  }
    65  
    66  // add adds a plugin to loaded plugin defCauslection.
    67  func (p plugins) add(plugin *Plugin) {
    68  	plugins, ok := p.plugins[plugin.HoTT]
    69  	if !ok {
    70  		plugins = make([]Plugin, 0)
    71  	}
    72  	plugins = append(plugins, *plugin)
    73  	p.plugins[plugin.HoTT] = plugins
    74  	p.versions[plugin.Name] = plugin.Version
    75  }
    76  
    77  // plugins got plugin in COW context.
    78  func (p copyOnWriteContext) plugins() *plugins {
    79  	return (*plugins)(atomic.LoadPointer(&p.tiPlugins))
    80  }
    81  
// Config presents the init configuration for the plugin framework. It is
// consumed by Load and Init.
type Config struct {
	// Plugins lists the IDs of the plugins to load; each ID is decoded with
	// ID.Decode into a plugin name and version.
	Plugins        []string
	// PluginDir is the directory handed to loadOne for every plugin.
	PluginDir      string
	// GlobalSysVar, when non-nil, receives the SysVars of each plugin that
	// passes validation in Load.
	GlobalSysVar   *map[string]*variable.SysVar
	// PluginVarNames, when non-nil, gets the names of those plugin system
	// variables whose scope is not variable.ScopeStochastik appended to it.
	PluginVarNames *[]string
	// SkipWhenFail makes Load/Init log a failing plugin and skip or disable it
	// instead of returning the error.
	SkipWhenFail   bool
	// EnvVersion holds environment component versions; Load uses them for
	// RequireVersion checks and treats their names as already taken.
	EnvVersion     map[string]uint16
	// EtcdClient, when non-nil, lets Init start a flush watcher for every
	// plugin that has an OnFlush hook.
	EtcdClient     *clientv3.Client
}
    92  
// Plugin presents a MilevaDB plugin: its manifest plus the runtime state this
// package tracks for it.
type Plugin struct {
	// Manifest is the value returned by the plugin's manifest function (see
	// loadOne); its fields are promoted onto Plugin.
	*Manifest
	// library is the Go plugin handle opened by loadManifestByGoPlugin.
	library  *gplugin.Plugin
	// Path is the library file path built by loadManifestByGoPlugin.
	Path     string
	// Disabled is 1 when the plugin is disabled and 0 otherwise. It is accessed
	// with sync/atomic because flush watcher goroutines update it.
	Disabled uint32
	// State is the lifecycle state (this file uses Ready, Disable and Dying).
	State    State
}
   101  
   102  // StateValue returns readable state string.
   103  func (p *Plugin) StateValue() string {
   104  	flag := "enable"
   105  	if atomic.LoadUint32(&p.Disabled) == 1 {
   106  		flag = "disable"
   107  	}
   108  	return p.State.String() + "-" + flag
   109  }
   110  
   111  // DisableFlag changes the disable flag of plugin.
   112  func (p *Plugin) DisableFlag(disable bool) {
   113  	if disable {
   114  		atomic.StoreUint32(&p.Disabled, 1)
   115  	} else {
   116  		atomic.StoreUint32(&p.Disabled, 0)
   117  	}
   118  }
   119  
   120  func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins) error {
   121  	if p.RequireVersion != nil {
   122  		for component, reqVer := range p.RequireVersion {
   123  			if ver, ok := tiPlugins.versions[component]; !ok || ver < reqVer {
   124  				return errRequireVersionCheckFail.GenWithStackByArgs(p.Name, component, reqVer, ver)
   125  			}
   126  		}
   127  	}
   128  	if p.SysVars != nil {
   129  		for varName := range p.SysVars {
   130  			if !strings.HasPrefix(varName, p.Name) {
   131  				return errInvalidPluginSysVarName.GenWithStackByArgs(p.Name, varName, p.Name)
   132  			}
   133  		}
   134  	}
   135  	if p.Manifest.Validate != nil {
   136  		if err := p.Manifest.Validate(ctx, p.Manifest); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  // Load load plugin by config param.
   144  // This method need be called before petri init to inject global variable info during bootstrap.
   145  func Load(ctx context.Context, cfg Config) (err error) {
   146  	tiPlugins := &plugins{
   147  		plugins:      make(map[HoTT][]Plugin),
   148  		versions:     make(map[string]uint16, len(cfg.EnvVersion)),
   149  		dyingPlugins: make([]Plugin, 0),
   150  	}
   151  
   152  	// Setup component version info for plugin running env.
   153  	for component, version := range cfg.EnvVersion {
   154  		tiPlugins.versions[component] = version
   155  	}
   156  
   157  	// Load plugin dl & manifest.
   158  	for _, pluginID := range cfg.Plugins {
   159  		var pName string
   160  		pName, _, err = ID(pluginID).Decode()
   161  		if err != nil {
   162  			err = errors.Trace(err)
   163  			return
   164  		}
   165  		// Check duplicate.
   166  		_, dup := tiPlugins.versions[pName]
   167  		if dup {
   168  			if cfg.SkipWhenFail {
   169  				logutil.Logger(ctx).Warn("duplicate load %s and ignored", zap.String("pluginName", pName))
   170  				continue
   171  			}
   172  			err = errDuplicatePlugin.GenWithStackByArgs(pluginID)
   173  			return
   174  		}
   175  		// Load dl.
   176  		var plugin Plugin
   177  		plugin, err = loadOne(cfg.PluginDir, ID(pluginID))
   178  		if err != nil {
   179  			if cfg.SkipWhenFail {
   180  				logutil.Logger(ctx).Warn("load plugin failure and ignored", zap.String("pluginID", pluginID), zap.Error(err))
   181  				continue
   182  			}
   183  			return
   184  		}
   185  		tiPlugins.add(&plugin)
   186  	}
   187  
   188  	// Cross validate & Load plugins.
   189  	for HoTT := range tiPlugins.plugins {
   190  		for i := range tiPlugins.plugins[HoTT] {
   191  			if err = tiPlugins.plugins[HoTT][i].validate(ctx, tiPlugins); err != nil {
   192  				if cfg.SkipWhenFail {
   193  					logutil.Logger(ctx).Warn("validate plugin fail and disable plugin",
   194  						zap.String("plugin", tiPlugins.plugins[HoTT][i].Name), zap.Error(err))
   195  					tiPlugins.plugins[HoTT][i].State = Disable
   196  					err = nil
   197  					continue
   198  				}
   199  				return
   200  			}
   201  			if cfg.GlobalSysVar != nil {
   202  				for key, value := range tiPlugins.plugins[HoTT][i].SysVars {
   203  					(*cfg.GlobalSysVar)[key] = value
   204  					if value.Scope != variable.ScopeStochastik && cfg.PluginVarNames != nil {
   205  						*cfg.PluginVarNames = append(*cfg.PluginVarNames, key)
   206  					}
   207  				}
   208  			}
   209  		}
   210  	}
   211  	pluginGlobal = copyOnWriteContext{tiPlugins: unsafe.Pointer(tiPlugins)}
   212  	err = nil
   213  	return
   214  }
   215  
   216  // Init initializes the loaded plugin by config param.
   217  // This method must be called after `Load` but before any other plugin method call, so it call got MilevaDB petri info.
   218  func Init(ctx context.Context, cfg Config) (err error) {
   219  	tiPlugins := pluginGlobal.plugins()
   220  	if tiPlugins == nil {
   221  		return nil
   222  	}
   223  	for HoTT := range tiPlugins.plugins {
   224  		for i := range tiPlugins.plugins[HoTT] {
   225  			p := tiPlugins.plugins[HoTT][i]
   226  			if err = p.OnInit(ctx, p.Manifest); err != nil {
   227  				if cfg.SkipWhenFail {
   228  					logutil.Logger(ctx).Warn("call Plugin OnInit failure, err: %v",
   229  						zap.String("plugin", p.Name), zap.Error(err))
   230  					tiPlugins.plugins[HoTT][i].State = Disable
   231  					err = nil
   232  					continue
   233  				}
   234  				return
   235  			}
   236  			if p.OnFlush != nil && cfg.EtcdClient != nil {
   237  				const pluginWatchPrefix = "/milevadb/plugins/"
   238  				ctx, cancel := context.WithCancel(context.Background())
   239  				watcher := &flushWatcher{
   240  					ctx:      ctx,
   241  					cancel:   cancel,
   242  					path:     pluginWatchPrefix + tiPlugins.plugins[HoTT][i].Name,
   243  					etcd:     cfg.EtcdClient,
   244  					manifest: tiPlugins.plugins[HoTT][i].Manifest,
   245  					plugin:   &tiPlugins.plugins[HoTT][i],
   246  				}
   247  				tiPlugins.plugins[HoTT][i].flushWatcher = watcher
   248  				go soliton.WithRecovery(watcher.watchLoop, nil)
   249  			}
   250  			tiPlugins.plugins[HoTT][i].State = Ready
   251  		}
   252  	}
   253  	return
   254  }
   255  
// flushWatcher watches a plugin's etcd key and, on every change, reloads the
// plugin's disabled flag and calls the manifest's OnFlush hook. Init creates
// one per plugin that has OnFlush when an etcd client is configured.
type flushWatcher struct {
	// ctx bounds the watch; cancel is called by Shutdown to stop watchLoop.
	ctx      context.Context
	cancel   context.CancelFunc
	// path is the watched etcd key, "/milevadb/plugins/" + plugin name (set in Init).
	path     string
	etcd     *clientv3.Client
	// manifest is the plugin's manifest whose OnFlush hook is invoked.
	manifest *Manifest
	// plugin points at the registry entry whose Disabled flag is updated.
	plugin   *Plugin
}
   264  
   265  func (w *flushWatcher) watchLoop() {
   266  	watchChan := w.etcd.Watch(w.ctx, w.path)
   267  	for {
   268  		select {
   269  		case <-w.ctx.Done():
   270  			return
   271  		case <-watchChan:
   272  			disabled, err := w.getPluginDisabledFlag()
   273  			if err != nil {
   274  				logutil.BgLogger().Error("get plugin disabled flag failure", zap.String("plugin", w.manifest.Name), zap.Error(err))
   275  			}
   276  			if disabled {
   277  				atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 1)
   278  			} else {
   279  				atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 0)
   280  			}
   281  			err = w.manifest.OnFlush(w.ctx, w.manifest)
   282  			if err != nil {
   283  				logutil.BgLogger().Error("notify plugin flush event failed", zap.String("plugin", w.manifest.Name), zap.Error(err))
   284  			}
   285  		}
   286  	}
   287  }
   288  
   289  func (w *flushWatcher) getPluginDisabledFlag() (bool, error) {
   290  	if w == nil || w.etcd == nil {
   291  		return true, errors.New("etcd is need to get plugin enable status")
   292  	}
   293  	resp, err := w.etcd.Get(context.Background(), w.manifest.flushWatcher.path)
   294  	if err != nil {
   295  		return true, errors.Trace(err)
   296  	}
   297  	if len(resp.Ekvs) == 0 {
   298  		return false, nil
   299  	}
   300  	return string(resp.Ekvs[0].Value) == "1", nil
   301  }
   302  
// loadFn resolves the manifest function for pluginID located in dir. An
// implementation may record load details on plugin (loadManifestByGoPlugin,
// which has this signature, sets Path and library).
type loadFn func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error)

// testHook, when non-nil, makes loadOne call testHook.loadOne instead of
// opening a Go plugin library (presumably only set by tests — confirm).
var testHook *struct {
	loadOne loadFn
}
   308  
   309  func loadOne(dir string, pluginID ID) (plugin Plugin, err error) {
   310  	pName, pVersion, err := pluginID.Decode()
   311  	if err != nil {
   312  		err = errors.Trace(err)
   313  		return
   314  	}
   315  	var manifest func() *Manifest
   316  	if testHook == nil {
   317  		manifest, err = loadManifestByGoPlugin(&plugin, dir, pluginID)
   318  	} else {
   319  		manifest, err = testHook.loadOne(&plugin, dir, pluginID)
   320  	}
   321  	if err != nil {
   322  		return
   323  	}
   324  	plugin.Manifest = manifest()
   325  	if plugin.Name != pName {
   326  		err = errInvalidPluginName.GenWithStackByArgs(string(pluginID), plugin.Name)
   327  		return
   328  	}
   329  	if strconv.Itoa(int(plugin.Version)) != pVersion {
   330  		err = errInvalidPluginVersion.GenWithStackByArgs(string(pluginID))
   331  		return
   332  	}
   333  	return
   334  }
   335  
   336  func loadManifestByGoPlugin(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
   337  	plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix)
   338  	plugin.library, err = gplugin.Open(plugin.Path)
   339  	if err != nil {
   340  		err = errors.Trace(err)
   341  		return
   342  	}
   343  	manifestSym, err := plugin.library.Lookup(ManifestSymbol)
   344  	if err != nil {
   345  		err = errors.Trace(err)
   346  		return
   347  	}
   348  	var ok bool
   349  	manifest, ok = manifestSym.(func() *Manifest)
   350  	if !ok {
   351  		err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID))
   352  		return
   353  	}
   354  	return
   355  }
   356  
   357  // Shutdown cleanups all plugin resources.
   358  // Notice: it just cleanups the resource of plugin, but cannot unload plugins(limited by go plugin).
   359  func Shutdown(ctx context.Context) {
   360  	for {
   361  		tiPlugins := pluginGlobal.plugins()
   362  		if tiPlugins == nil {
   363  			return
   364  		}
   365  		for _, plugins := range tiPlugins.plugins {
   366  			for _, p := range plugins {
   367  				p.State = Dying
   368  				if p.flushWatcher != nil {
   369  					p.flushWatcher.cancel()
   370  				}
   371  				if p.OnShutdown == nil {
   372  					continue
   373  				}
   374  				if err := p.OnShutdown(ctx, p.Manifest); err != nil {
   375  					logutil.Logger(ctx).Error("call OnShutdown for failure",
   376  						zap.String("plugin", p.Name), zap.Error(err))
   377  				}
   378  			}
   379  		}
   380  		if atomic.CompareAndSwapPointer(&pluginGlobal.tiPlugins, unsafe.Pointer(tiPlugins), nil) {
   381  			return
   382  		}
   383  	}
   384  }
   385  
   386  // Get finds and returns plugin by HoTT and name parameters.
   387  func Get(HoTT HoTT, name string) *Plugin {
   388  	plugins := pluginGlobal.plugins()
   389  	if plugins == nil {
   390  		return nil
   391  	}
   392  	for _, p := range plugins.plugins[HoTT] {
   393  		if p.Name == name {
   394  			return &p
   395  		}
   396  	}
   397  	return nil
   398  }
   399  
   400  // ForeachPlugin loops all ready plugins.
   401  func ForeachPlugin(HoTT HoTT, fn func(plugin *Plugin) error) error {
   402  	plugins := pluginGlobal.plugins()
   403  	if plugins == nil {
   404  		return nil
   405  	}
   406  	for i := range plugins.plugins[HoTT] {
   407  		p := &plugins.plugins[HoTT][i]
   408  		if p.State != Ready {
   409  			continue
   410  		}
   411  		if atomic.LoadUint32(&p.Disabled) == 1 {
   412  			continue
   413  		}
   414  		err := fn(p)
   415  		if err != nil {
   416  			return err
   417  		}
   418  	}
   419  	return nil
   420  }
   421  
   422  // IsEnable checks plugin's enable state.
   423  func IsEnable(HoTT HoTT) bool {
   424  	plugins := pluginGlobal.plugins()
   425  	if plugins == nil {
   426  		return false
   427  	}
   428  	for i := range plugins.plugins[HoTT] {
   429  		p := &plugins.plugins[HoTT][i]
   430  		if p.State == Ready && atomic.LoadUint32(&p.Disabled) != 1 {
   431  			return true
   432  		}
   433  	}
   434  	return false
   435  }
   436  
   437  // GetAll finds and returns all plugins.
   438  func GetAll() map[HoTT][]Plugin {
   439  	plugins := pluginGlobal.plugins()
   440  	if plugins == nil {
   441  		return nil
   442  	}
   443  	return plugins.plugins
   444  }
   445  
   446  // NotifyFlush notify plugins to do flush logic.
   447  func NotifyFlush(dom *petri.Petri, pluginName string) error {
   448  	p := getByName(pluginName)
   449  	if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready {
   450  		return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with FIDel", pluginName)
   451  	}
   452  	_, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(p.Disabled)))
   453  	if err != nil {
   454  		return err
   455  	}
   456  	return nil
   457  }
   458  
   459  // ChangeDisableFlagAndFlush changes plugin disable flag and notify other nodes to do same change.
   460  func ChangeDisableFlagAndFlush(dom *petri.Petri, pluginName string, disable bool) error {
   461  	p := getByName(pluginName)
   462  	if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready {
   463  		return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with FIDel", pluginName)
   464  	}
   465  	disableInt := uint32(0)
   466  	if disable {
   467  		disableInt = 1
   468  	}
   469  	atomic.StoreUint32(&p.Disabled, disableInt)
   470  	_, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(disableInt)))
   471  	if err != nil {
   472  		return err
   473  	}
   474  	return nil
   475  }
   476  
   477  func getByName(pluginName string) *Plugin {
   478  	for _, plugins := range GetAll() {
   479  		for _, p := range plugins {
   480  			if p.Name == pluginName {
   481  				return &p
   482  			}
   483  		}
   484  	}
   485  	return nil
   486  }