github.com/yanndegat/hiera@v0.6.8/session/pluginloader.go (about)

     1  package session
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/http"
    12  	"net/url"
    13  	"os"
    14  	"os/exec"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/lyraproj/dgo/dgo"
    21  	"github.com/lyraproj/dgo/loader"
    22  	"github.com/lyraproj/dgo/streamer"
    23  	"github.com/lyraproj/dgo/vf"
    24  	"github.com/lyraproj/hierasdk/hiera"
    25  	log "github.com/sirupsen/logrus"
    26  )
    27  
    28  // a plugin corresponds to a loaded process
    29  type plugin struct {
    30  	lock      sync.Mutex
    31  	wGroup    sync.WaitGroup
    32  	process   *os.Process
    33  	path      string
    34  	addr      string
    35  	network   string
    36  	functions map[string]interface{}
    37  }
    38  
    39  // a pluginRegistry keeps track of loaded plugins
    40  type pluginRegistry struct {
    41  	lock    sync.Mutex
    42  	plugins map[string]*plugin
    43  }
    44  
    45  // stopAll will stop all plugins that this registry is aware of and empty the registry
    46  func (r *pluginRegistry) stopAll() {
    47  	r.lock.Lock()
    48  	defer r.lock.Unlock()
    49  
    50  	for _, p := range r.plugins {
    51  		p.kill()
    52  	}
    53  	r.plugins = nil
    54  }
    55  
    56  func createPipe(path, name string, fn func() (io.ReadCloser, error)) io.ReadCloser {
    57  	pipe, err := fn()
    58  	if err != nil {
    59  		panic(fmt.Errorf(`unable to create %s pipe to plugin %s: %s`, name, path, err.Error()))
    60  	}
    61  	return pipe
    62  }
    63  
    64  // copyErrToLog propagates everything written on the plugin's stderr to the StandardLogger of this process.
    65  func copyErrToLog(path string, cmdErr io.Reader, wGroup *sync.WaitGroup) {
    66  	defer wGroup.Done()
    67  	out := log.StandardLogger().Out
    68  	reader := bufio.NewReaderSize(cmdErr, 0x10000)
    69  	for {
    70  		line, pfx, err := reader.ReadLine()
    71  		if err != nil {
    72  			if err != io.EOF {
    73  				log.Errorf(`error reading stderr of plugin %s: %s`, path, err.Error())
    74  			}
    75  			return
    76  		}
    77  		_, _ = out.Write(line)
    78  		if !pfx {
    79  			_, _ = out.Write([]byte{'\n'})
    80  		}
    81  	}
    82  }
    83  
    84  func awaitMetaData(metaCh chan interface{}, cmdOut io.Reader, wGroup *sync.WaitGroup) {
    85  	defer wGroup.Done()
    86  	var meta map[string]interface{}
    87  	dc := json.NewDecoder(cmdOut)
    88  	err := dc.Decode(&meta)
    89  	if err != nil {
    90  		metaCh <- err
    91  	} else {
    92  		metaCh <- meta
    93  	}
    94  }
    95  
    96  func ignoreOut(cmdOut io.Reader, wGroup *sync.WaitGroup) {
    97  	defer wGroup.Done()
    98  	toss := make([]byte, 0x1000)
    99  	for {
   100  		_, err := cmdOut.Read(toss)
   101  		if err == io.EOF {
   102  			return
   103  		}
   104  	}
   105  }
   106  
   107  const pluginTransportUnix = "unix"
   108  const pluginTransportTCP = "tcp"
   109  
   110  var defaultUnixSocketDir = "/tmp"
   111  
   112  // getUnixSocketDir resolves value of unixSocketDir
   113  func getUnixSocketDir(opts dgo.Map) string {
   114  	if v, ok := opts.Get("unixSocketDir").(dgo.String); ok {
   115  		return v.GoString()
   116  	}
   117  	if v := os.TempDir(); v != "" {
   118  		return v
   119  	}
   120  	return defaultUnixSocketDir
   121  }
   122  
   123  // getPluginTransport resolves value of pluginTransport
   124  func getPluginTransport(opts dgo.Map) string {
   125  	if v, ok := opts.Get("pluginTransport").(dgo.String); ok {
   126  		s := v.GoString()
   127  		switch s {
   128  		case
   129  			pluginTransportUnix,
   130  			pluginTransportTCP:
   131  			return s
   132  		}
   133  	}
   134  	return getDefaultPluginTransport()
   135  }
   136  
   137  // startPlugin will start the plugin loaded from the given path and register the functions that it makes available
   138  // with the given loader.
   139  func (r *pluginRegistry) startPlugin(opts dgo.Map, path string) dgo.Value {
   140  	r.lock.Lock()
   141  	defer r.lock.Unlock()
   142  
   143  	if r.plugins != nil {
   144  		if p, ok := r.plugins[path]; ok {
   145  			return p.functionMap()
   146  		}
   147  	}
   148  	cmd := initCmd(opts, path)
   149  	cmdErr := createPipe(path, `stderr`, cmd.StderrPipe)
   150  	cmdOut := createPipe(path, `stdout`, cmd.StdoutPipe)
   151  	if err := cmd.Start(); err != nil {
   152  		panic(fmt.Errorf(`unable to start plugin %s: %s`, path, err.Error()))
   153  	}
   154  
   155  	// Make sure the plugin process is killed if there is an error
   156  	defer func() {
   157  		if r := recover(); r != nil {
   158  			_ = cmd.Process.Kill()
   159  			panic(r)
   160  		}
   161  	}()
   162  
   163  	p := &plugin{path: path, process: cmd.Process}
   164  	p.wGroup.Add(1)
   165  	go copyErrToLog(path, cmdErr, &p.wGroup)
   166  
   167  	metaCh := make(chan interface{})
   168  	p.wGroup.Add(1)
   169  	go awaitMetaData(metaCh, cmdOut, &p.wGroup)
   170  
   171  	// Give plugin some time to respond with meta-info
   172  	timeout := time.After(time.Second * 3)
   173  	var meta map[string]interface{}
   174  	select {
   175  	case <-timeout:
   176  		panic(fmt.Errorf(`timeout while waiting for plugin %s to start`, path))
   177  	case mv := <-metaCh:
   178  		if err, ok := mv.(error); ok {
   179  			panic(fmt.Errorf(`error reading meta data of plugin %s: %s`, path, err.Error()))
   180  		}
   181  		meta = mv.(map[string]interface{})
   182  	}
   183  
   184  	// start a go routine that ignores other stuff that is written on plugin's stdout
   185  	p.wGroup.Add(1)
   186  	go ignoreOut(cmdOut, &p.wGroup)
   187  
   188  	if r.plugins == nil {
   189  		r.plugins = make(map[string]*plugin)
   190  	}
   191  	p.initialize(meta)
   192  	r.plugins[path] = p
   193  
   194  	return p.functionMap()
   195  }
   196  
   197  func initCmd(opts dgo.Map, path string) *exec.Cmd {
   198  	cmd := exec.Command(path)
   199  	cmd.Env = os.Environ()
   200  	cmd.Env = append(cmd.Env, `HIERA_MAGIC_COOKIE=`+strconv.Itoa(hiera.MagicCookie))
   201  	cmd.Env = append(cmd.Env, `HIERA_PLUGIN_SOCKET_DIR=`+getUnixSocketDir(opts))
   202  	cmd.Env = append(cmd.Env, `HIERA_PLUGIN_TRANSPORT=`+getPluginTransport(opts))
   203  	cmd.SysProcAttr = procAttrs
   204  	return cmd
   205  }
   206  
   207  func (p *plugin) kill() {
   208  	p.lock.Lock()
   209  	process := p.process
   210  	if process == nil {
   211  		return
   212  	}
   213  
   214  	defer func() {
   215  		p.wGroup.Wait()
   216  		p.process = nil
   217  		p.lock.Unlock()
   218  	}()
   219  
   220  	graceful := true
   221  	if err := terminateProc(process); err != nil {
   222  		graceful = false
   223  	}
   224  
   225  	if graceful {
   226  		done := make(chan bool)
   227  		go func() {
   228  			_, _ = process.Wait()
   229  			done <- true
   230  		}()
   231  		select {
   232  		case <-done:
   233  		case <-time.After(time.Second * 3):
   234  			_ = process.Kill()
   235  		}
   236  	} else {
   237  		// Graceful terminate failed. Just kill it!
   238  		_ = process.Kill()
   239  	}
   240  }
   241  
   242  // initialize the plugin with the given meta-data
   243  func (p *plugin) initialize(meta map[string]interface{}) {
   244  	v, ok := meta[`version`].(float64)
   245  	if !(ok && int(v) == hiera.ProtoVersion) {
   246  		panic(fmt.Errorf(`plugin %s uses unsupported protocol %v`, p.path, v))
   247  	}
   248  	p.addr, ok = meta[`address`].(string)
   249  	if !ok {
   250  		panic(fmt.Errorf(`plugin %s did not provide a valid address`, p.path))
   251  	}
   252  	p.network, ok = meta[`network`].(string)
   253  	if !ok {
   254  		log.Printf(`plugin %s did not provide a valid network, assuming tcp`, p.path)
   255  		p.network = `tcp`
   256  	}
   257  	p.functions, ok = meta[`functions`].(map[string]interface{})
   258  	if !ok {
   259  		panic(fmt.Errorf(`plugin %s did not provide a valid functions map`, p.path))
   260  	}
   261  }
   262  
   263  type luDispatch func(string) dgo.Function
   264  
   265  func (p *plugin) functionMap() dgo.Value {
   266  	m := vf.MutableMap()
   267  	for k, v := range p.functions {
   268  		names := v.([]interface{})
   269  		var df luDispatch
   270  		switch k {
   271  		case `data_dig`:
   272  			df = p.dataDigDispatch
   273  		case `data_hash`:
   274  			df = p.dataHashDispatch
   275  		default:
   276  			df = p.lookupKeyDispatch
   277  		}
   278  		for _, x := range names {
   279  			n := x.(string)
   280  			m.Put(n, df(n))
   281  		}
   282  	}
   283  	return loader.Multiple(m)
   284  }
   285  
   286  func (p *plugin) dataDigDispatch(name string) dgo.Function {
   287  	return vf.Value(func(pc hiera.ProviderContext, key dgo.Array) dgo.Value {
   288  		params := makeOptions(pc)
   289  		jp := streamer.MarshalJSON(key, nil)
   290  		params.Add(`key`, string(jp))
   291  		return p.callPlugin(`data_dig`, name, params)
   292  	}).(dgo.Function)
   293  }
   294  
   295  func (p *plugin) dataHashDispatch(name string) dgo.Function {
   296  	return vf.Value(func(pc hiera.ProviderContext) dgo.Value {
   297  		return p.callPlugin(`data_hash`, name, makeOptions(pc))
   298  	}).(dgo.Function)
   299  }
   300  
   301  func (p *plugin) lookupKeyDispatch(name string) dgo.Function {
   302  	return vf.Value(func(pc hiera.ProviderContext, key string) dgo.Value {
   303  		params := makeOptions(pc)
   304  		params.Add(`key`, key)
   305  		return p.callPlugin(`lookup_key`, name, params)
   306  	}).(dgo.Function)
   307  }
   308  
   309  func makeOptions(pc hiera.ProviderContext) url.Values {
   310  	params := make(url.Values)
   311  	opts := pc.OptionsMap()
   312  	if opts.Len() > 0 {
   313  		bld := bytes.Buffer{}
   314  		streamer.New(nil, streamer.DefaultOptions()).Stream(opts, streamer.JSON(&bld))
   315  		params.Add(`options`, strings.TrimSpace(bld.String()))
   316  	}
   317  	return params
   318  }
   319  
   320  func (p *plugin) callPlugin(luType, name string, params url.Values) dgo.Value {
   321  	var ad *url.URL
   322  	var err error
   323  
   324  	if p.network == pluginTransportUnix {
   325  		ad, err = url.Parse(fmt.Sprintf(`http://%s/%s/%s`, p.network, luType, name))
   326  	} else {
   327  		ad, err = url.Parse(fmt.Sprintf(`http://%s/%s/%s`, p.addr, luType, name))
   328  	}
   329  	if err != nil {
   330  		panic(err)
   331  	}
   332  	if len(params) > 0 {
   333  		ad.RawQuery = params.Encode()
   334  	}
   335  	us := ad.String()
   336  	client := http.Client{
   337  		Timeout: time.Second * 5,
   338  		Transport: &http.Transport{
   339  			Dial: func(_, _ string) (net.Conn, error) {
   340  				return net.Dial(p.network, p.addr)
   341  			},
   342  		},
   343  	}
   344  	resp, err := client.Get(us)
   345  	if err != nil {
   346  		panic(err.Error())
   347  	}
   348  
   349  	defer func() {
   350  		_ = resp.Body.Close()
   351  	}()
   352  	switch resp.StatusCode {
   353  	case http.StatusOK:
   354  		var bts []byte
   355  		if bts, err = ioutil.ReadAll(resp.Body); err == nil {
   356  			return streamer.UnmarshalJSON(bts, nil)
   357  		}
   358  	case http.StatusNotFound:
   359  		return nil
   360  	default:
   361  		var bts []byte
   362  		if bts, err = ioutil.ReadAll(resp.Body); err == nil {
   363  			err = fmt.Errorf(`%s %s: %s`, us, resp.Status, string(bts))
   364  		} else {
   365  			err = fmt.Errorf(`%s %s`, us, resp.Status)
   366  		}
   367  	}
   368  	panic(err)
   369  }