github.com/yanndegat/hiera@v0.6.8/session/pluginloader.go (about) 1 package session 2 3 import ( 4 "bufio" 5 "bytes" 6 "encoding/json" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net" 11 "net/http" 12 "net/url" 13 "os" 14 "os/exec" 15 "strconv" 16 "strings" 17 "sync" 18 "time" 19 20 "github.com/lyraproj/dgo/dgo" 21 "github.com/lyraproj/dgo/loader" 22 "github.com/lyraproj/dgo/streamer" 23 "github.com/lyraproj/dgo/vf" 24 "github.com/lyraproj/hierasdk/hiera" 25 log "github.com/sirupsen/logrus" 26 ) 27 28 // a plugin corresponds to a loaded process 29 type plugin struct { 30 lock sync.Mutex 31 wGroup sync.WaitGroup 32 process *os.Process 33 path string 34 addr string 35 network string 36 functions map[string]interface{} 37 } 38 39 // a pluginRegistry keeps track of loaded plugins 40 type pluginRegistry struct { 41 lock sync.Mutex 42 plugins map[string]*plugin 43 } 44 45 // stopAll will stop all plugins that this registry is aware of and empty the registry 46 func (r *pluginRegistry) stopAll() { 47 r.lock.Lock() 48 defer r.lock.Unlock() 49 50 for _, p := range r.plugins { 51 p.kill() 52 } 53 r.plugins = nil 54 } 55 56 func createPipe(path, name string, fn func() (io.ReadCloser, error)) io.ReadCloser { 57 pipe, err := fn() 58 if err != nil { 59 panic(fmt.Errorf(`unable to create %s pipe to plugin %s: %s`, name, path, err.Error())) 60 } 61 return pipe 62 } 63 64 // copyErrToLog propagates everything written on the plugin's stderr to the StandardLogger of this process. 
func copyErrToLog(path string, cmdErr io.Reader, wGroup *sync.WaitGroup) {
	defer wGroup.Done()
	out := log.StandardLogger().Out

	// A 64KiB buffer keeps most lines whole. Longer lines are returned in chunks
	// (isPrefix == true), and the newline is written only after the last chunk.
	reader := bufio.NewReaderSize(cmdErr, 0x10000)
	for {
		line, isPrefix, err := reader.ReadLine()
		if err != nil {
			if err != io.EOF {
				log.Errorf(`error reading stderr of plugin %s: %s`, path, err.Error())
			}
			return
		}
		// Forwarding is best-effort; a failing log writer must not stop the drain.
		_, _ = out.Write(line)
		if !isPrefix {
			_, _ = out.Write([]byte{'\n'})
		}
	}
}

// awaitMetaData decodes the first JSON object read from cmdOut. It sends on
// metaCh either the decoded map or the decode error, then signals wGroup.
// The send blocks until it is received. A caller that may stop listening
// (e.g. on timeout) must give metaCh a buffer of at least one, or this
// goroutine never finishes.
func awaitMetaData(metaCh chan interface{}, cmdOut io.Reader, wGroup *sync.WaitGroup) {
	defer wGroup.Done()
	var meta map[string]interface{}
	dc := json.NewDecoder(cmdOut)
	if err := dc.Decode(&meta); err != nil {
		metaCh <- err
		return
	}
	metaCh <- meta
}

// ignoreOut reads and discards whatever the plugin writes on its stdout after
// the meta-data, so that the plugin never blocks on a full pipe. It signals
// wGroup when the stream ends or reading fails.
func ignoreOut(cmdOut io.Reader, wGroup *sync.WaitGroup) {
	defer wGroup.Done()
	// io.Copy returns on io.EOF *and* on any other read error. A loop that only
	// checks for io.EOF spins forever on errors such as os.ErrClosed, and then
	// never calls Done, which hangs every wGroup.Wait().
	_, _ = io.Copy(ioutil.Discard, cmdOut)
}

// Values accepted for the "pluginTransport" option.
const (
	pluginTransportUnix = "unix"
	pluginTransportTCP  = "tcp"
)

// defaultUnixSocketDir is the last-resort socket directory, used only when
// neither the "unixSocketDir" option nor os.TempDir() provides one.
var defaultUnixSocketDir = "/tmp"

// getUnixSocketDir returns the "unixSocketDir" option when it is a string,
// otherwise os.TempDir(), and defaultUnixSocketDir if that is empty.
func getUnixSocketDir(opts dgo.Map) string {
	if v, ok := opts.Get("unixSocketDir").(dgo.String); ok {
		return v.GoString()
	}
	if v := os.TempDir(); v != "" {
		return v
	}
	return defaultUnixSocketDir
}

// getPluginTransport returns the "pluginTransport" option when it names a
// supported transport (pluginTransportUnix or pluginTransportTCP). Any other
// value, or no value, falls back to getDefaultPluginTransport(), which is
// defined elsewhere (presumably per platform).
func getPluginTransport(opts dgo.Map) string {
	if v, ok := opts.Get("pluginTransport").(dgo.String); ok {
		switch s := v.GoString(); s {
		case pluginTransportUnix, pluginTransportTCP:
			return s
		}
	}
	return getDefaultPluginTransport()
}

// startPlugin starts the plugin executable found at path. It returns a value,
// built with loader.Multiple, that maps every function name the plugin
// advertised to a dispatcher calling back into the plugin. A plugin already
// registered under path is reused instead of being started again. Failures
// are reported by panicking.
func (r *pluginRegistry) startPlugin(opts dgo.Map, path string) dgo.Value {
	r.lock.Lock()
	defer r.lock.Unlock()

	// Indexing a nil map is safe in Go and simply reports "not found".
	if p, ok := r.plugins[path]; ok {
		return p.functionMap()
	}

	cmd := initCmd(opts, path)
	cmdErr := createPipe(path, `stderr`, cmd.StderrPipe)
	cmdOut := createPipe(path, `stdout`, cmd.StdoutPipe)
	if err := cmd.Start(); err != nil {
		panic(fmt.Errorf(`unable to start plugin %s: %s`, path, err.Error()))
	}

	// If anything below panics, kill the plugin process and reap it so it does
	// not linger as a zombie, then re-raise the panic.
	defer func() {
		if rec := recover(); rec != nil {
			_ = cmd.Process.Kill()
			_, _ = cmd.Process.Wait()
			panic(rec)
		}
	}()

	p := &plugin{path: path, process: cmd.Process}
	p.wGroup.Add(1)
	go copyErrToLog(path, cmdErr, &p.wGroup)

	// Buffered so that awaitMetaData can always deliver its result and exit,
	// even after we stop listening because the timeout below fired.
	metaCh := make(chan interface{}, 1)
	p.wGroup.Add(1)
	go awaitMetaData(metaCh, cmdOut, &p.wGroup)

	// Give the plugin some time to respond with its meta-data.
	var meta map[string]interface{}
	select {
	case <-time.After(3 * time.Second):
		panic(fmt.Errorf(`timeout while waiting for plugin %s to start`, path))
	case mv := <-metaCh:
		if err, ok := mv.(error); ok {
			panic(fmt.Errorf(`error reading meta data of plugin %s: %s`, path, err.Error()))
		}
		meta = mv.(map[string]interface{})
	}

	// Anything else the plugin writes on stdout is drained and ignored.
	p.wGroup.Add(1)
	go ignoreOut(cmdOut, &p.wGroup)

	p.initialize(meta)
	// Build the function map before registering the plugin. If the advertised
	// functions are malformed, the panic then leaves no killed plugin behind
	// in the registry.
	fm := p.functionMap()
	if r.plugins == nil {
		r.plugins = make(map[string]*plugin)
	}
	r.plugins[path] = p
	return fm
}

// initCmd prepares, but does not start, the command that runs the plugin at
// path. The plugin inherits this process' environment plus the variables
// HIERA_MAGIC_COOKIE, HIERA_PLUGIN_SOCKET_DIR and HIERA_PLUGIN_TRANSPORT.
func initCmd(opts dgo.Map, path string) *exec.Cmd {
	cmd := exec.Command(path)
	cmd.Env = os.Environ()
	cmd.Env = append(cmd.Env, `HIERA_MAGIC_COOKIE=`+strconv.Itoa(hiera.MagicCookie))
	cmd.Env = append(cmd.Env, `HIERA_PLUGIN_SOCKET_DIR=`+getUnixSocketDir(opts))
	cmd.Env = append(cmd.Env, `HIERA_PLUGIN_TRANSPORT=`+getPluginTransport(opts))
	cmd.SysProcAttr = procAttrs // defined elsewhere, presumably per platform
	return cmd
}

// kill stops the plugin process. It first asks the process to terminate via
// terminateProc (defined elsewhere) and allows three seconds for it to exit
// before killing it. If terminateProc fails, the process is killed at once.
// kill returns after the goroutines draining the plugin's output have
// finished. Calling kill on an already stopped plugin is a no-op.
func (p *plugin) kill() {
	p.lock.Lock()
	// Deferred right away so that the early return below also releases the
	// lock. Otherwise a second kill() on the same plugin would deadlock.
	defer p.lock.Unlock()

	process := p.process
	if process == nil {
		return
	}

	// Runs before the Unlock above (defers are LIFO).
	defer func() {
		p.wGroup.Wait()
		p.process = nil
	}()

	if err := terminateProc(process); err != nil {
		// Graceful terminate failed. Just kill it, and reap it to avoid a zombie.
		_ = process.Kill()
		_, _ = process.Wait()
		return
	}

	// Buffered so that the waiter goroutine can exit even when we stop
	// listening because the timeout fired.
	done := make(chan struct{}, 1)
	go func() {
		_, _ = process.Wait()
		done <- struct{}{}
	}()
	select {
	case <-done:
	case <-time.After(3 * time.Second):
		// The waiter goroutine above reaps the process once it dies.
		_ = process.Kill()
	}
}

// initialize validates the plugin's meta-data and stores the values it needs.
// It panics if the protocol version is not hiera.ProtoVersion, or if
// "address" or "functions" is missing or mistyped. A missing "network" is
// logged and defaults to tcp.
func (p *plugin) initialize(meta map[string]interface{}) {
	// encoding/json decodes JSON numbers into float64.
	v, ok := meta[`version`].(float64)
	if !(ok && int(v) == hiera.ProtoVersion) {
		panic(fmt.Errorf(`plugin %s uses unsupported protocol %v`, p.path, v))
	}
	p.addr, ok = meta[`address`].(string)
	if !ok {
		panic(fmt.Errorf(`plugin %s did not provide a valid address`, p.path))
	}
	p.network, ok = meta[`network`].(string)
	if !ok {
		log.Printf(`plugin %s did not provide a valid network, assuming tcp`, p.path)
		p.network = `tcp`
	}
	p.functions, ok = meta[`functions`].(map[string]interface{})
	if !ok {
		panic(fmt.Errorf(`plugin %s did not provide a valid functions map`, p.path))
	}
}

// luDispatch creates the dgo.Function that forwards calls of the named plugin
// function to the plugin.
type luDispatch func(string) dgo.Function

// functionMap builds a loader.Multiple value mapping every advertised function
// name to its dispatcher. Keys of p.functions select the lookup kind:
// "data_dig", "data_hash", or anything else for lookup_key. It panics with a
// descriptive error when the plugin's function lists are malformed.
func (p *plugin) functionMap() dgo.Value {
	m := vf.MutableMap()
	for kind, v := range p.functions {
		names, ok := v.([]interface{})
		if !ok {
			panic(fmt.Errorf(`plugin %s provided an invalid function list for %s`, p.path, kind))
		}
		var df luDispatch
		switch kind {
		case `data_dig`:
			df = p.dataDigDispatch
		case `data_hash`:
			df = p.dataHashDispatch
		default:
			df = p.lookupKeyDispatch
		}
		for _, x := range names {
			n, ok := x.(string)
			if !ok {
				panic(fmt.Errorf(`plugin %s provided an invalid function name %v for %s`, p.path, x, kind))
			}
			m.Put(n, df(n))
		}
	}
	return loader.Multiple(m)
}

// dataDigDispatch returns a function that sends the dig key, encoded as JSON,
// to the plugin's data_dig endpoint for the function called name.
func (p *plugin) dataDigDispatch(name string) dgo.Function {
	return vf.Value(func(pc hiera.ProviderContext, key dgo.Array) dgo.Value {
		params := makeOptions(pc)
		jp := streamer.MarshalJSON(key, nil)
		params.Add(`key`, string(jp))
		return p.callPlugin(`data_dig`, name, params)
	}).(dgo.Function)
}

// dataHashDispatch returns a function that calls the plugin's data_hash
// endpoint for the function called name.
func (p *plugin) dataHashDispatch(name string) dgo.Function {
	return vf.Value(func(pc hiera.ProviderContext) dgo.Value {
		return p.callPlugin(`data_hash`, name, makeOptions(pc))
	}).(dgo.Function)
}

// lookupKeyDispatch returns a function that sends the key to the plugin's
// lookup_key endpoint for the function called name.
func (p *plugin) lookupKeyDispatch(name string) dgo.Function {
	return vf.Value(func(pc hiera.ProviderContext, key string) dgo.Value {
		params := makeOptions(pc)
		params.Add(`key`, key)
		return p.callPlugin(`lookup_key`, name, params)
	}).(dgo.Function)
}

// makeOptions returns query parameters carrying the provider's options map as
// JSON under "options". It returns no parameters when the map is empty.
func makeOptions(pc hiera.ProviderContext) url.Values {
	params := make(url.Values)
	opts := pc.OptionsMap()
	if opts.Len() > 0 {
		bld := bytes.Buffer{}
		streamer.New(nil, streamer.DefaultOptions()).Stream(opts, streamer.JSON(&bld))
		params.Add(`options`, strings.TrimSpace(bld.String()))
	}
	return params
}

// callPlugin performs GET /<luType>/<name>?<params> against the plugin.
// A 200 response body is returned as the unmarshalled JSON value and a 404
// yields nil. Any other status, or a transport or read error, panics.
func (p *plugin) callPlugin(luType, name string, params url.Values) dgo.Value {
	// The URL host is not used for connecting: Dial below always connects to
	// p.addr over p.network. For the unix transport, p.addr is presumably a
	// socket path, which cannot serve as a URL host, so the network name is
	// used in its place.
	host := p.addr
	if p.network == pluginTransportUnix {
		host = p.network
	}
	ad, err := url.Parse(fmt.Sprintf(`http://%s/%s/%s`, host, luType, name))
	if err != nil {
		panic(err)
	}
	if len(params) > 0 {
		ad.RawQuery = params.Encode()
	}
	us := ad.String()
	client := http.Client{
		Timeout: 5 * time.Second,
		Transport: &http.Transport{
			// A new Transport is created for every call and never reused. With
			// keep-alives on, each call would leave behind an idle connection
			// (plus its goroutines) that nothing reuses or closes.
			DisableKeepAlives: true,
			Dial: func(_, _ string) (net.Conn, error) {
				return net.Dial(p.network, p.addr)
			},
		},
	}
	resp, err := client.Get(us)
	if err != nil {
		panic(err.Error())
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	switch resp.StatusCode {
	case http.StatusOK:
		var bts []byte
		if bts, err = ioutil.ReadAll(resp.Body); err == nil {
			return streamer.UnmarshalJSON(bts, nil)
		}
	case http.StatusNotFound:
		return nil
	default:
		var bts []byte
		if bts, err = ioutil.ReadAll(resp.Body); err == nil {
			err = fmt.Errorf(`%s %s: %s`, us, resp.Status, string(bts))
		} else {
			err = fmt.Errorf(`%s %s`, us, resp.Status)
		}
	}
	panic(err)
}