github.com/ferranbt/nomad@v0.9.3-0.20190607002617-85c449b7667c/plugins/shared/cmd/launcher/command/device.go (about)

     1  package command
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"os/exec"
    10  	"strings"
    11  	"time"
    12  
    13  	hclog "github.com/hashicorp/go-hclog"
    14  	plugin "github.com/hashicorp/go-plugin"
    15  	"github.com/hashicorp/hcl"
    16  	"github.com/hashicorp/hcl/hcl/ast"
    17  	"github.com/hashicorp/hcl2/hcldec"
    18  	"github.com/hashicorp/nomad/helper/pluginutils/hclspecutils"
    19  	"github.com/hashicorp/nomad/helper/pluginutils/hclutils"
    20  	"github.com/hashicorp/nomad/plugins/base"
    21  	"github.com/hashicorp/nomad/plugins/device"
    22  	"github.com/kr/pretty"
    23  	"github.com/mitchellh/cli"
    24  	"github.com/zclconf/go-cty/cty/msgpack"
    25  )
    26  
    27  func DeviceCommandFactory(meta Meta) cli.CommandFactory {
    28  	return func() (cli.Command, error) {
    29  		return &Device{Meta: meta}, nil
    30  	}
    31  }
    32  
// Device implements the "device" subcommand of nomad-plugin-launcher: it
// launches a device plugin binary and drives it from an interactive REPL.
type Device struct {
	Meta

	// dev is the dispensed device plugin client; all plugin calls go through it.
	dev device.DevicePlugin

	// spec is the plugin's config schema, obtained from ConfigSchema and
	// converted to an hcldec.Spec. Run stores it before applying the config.
	spec hcldec.Spec
}
    42  
    43  func (c *Device) Help() string {
    44  	helpText := `
    45  Usage: nomad-plugin-launcher device <device-binary> <config_file>
    46  
    47    Device launches the given device binary and provides a REPL for interacting
    48    with it.
    49  
    50  General Options:
    51  
    52  ` + generalOptionsUsage() + `
    53  
    54  Device Options:
    55  
    56    -trace
    57      Enable trace level log output.
    58  `
    59  
    60  	return strings.TrimSpace(helpText)
    61  }
    62  
    63  func (c *Device) Synopsis() string {
    64  	return "REPL for interacting with device plugins"
    65  }
    66  
    67  func (c *Device) Run(args []string) int {
    68  	var trace bool
    69  	cmdFlags := c.FlagSet("device")
    70  	cmdFlags.Usage = func() { c.Ui.Output(c.Help()) }
    71  	cmdFlags.BoolVar(&trace, "trace", false, "")
    72  
    73  	if err := cmdFlags.Parse(args); err != nil {
    74  		c.logger.Error("failed to parse flags:", "error", err)
    75  		return 1
    76  	}
    77  	if trace {
    78  		c.logger.SetLevel(hclog.Trace)
    79  	} else if c.verbose {
    80  		c.logger.SetLevel(hclog.Debug)
    81  	}
    82  
    83  	args = cmdFlags.Args()
    84  	numArgs := len(args)
    85  	if numArgs < 1 {
    86  		c.logger.Error("expected at least 1 args (device binary)", "args", args)
    87  		return 1
    88  	} else if numArgs > 2 {
    89  		c.logger.Error("expected at most 2 args (device binary and config file)", "args", args)
    90  		return 1
    91  	}
    92  
    93  	binary := args[0]
    94  	var config []byte
    95  	if numArgs == 2 {
    96  		var err error
    97  		config, err = ioutil.ReadFile(args[1])
    98  		if err != nil {
    99  			c.logger.Error("failed to read config file", "error", err)
   100  			return 1
   101  		}
   102  
   103  		c.logger.Trace("read config", "config", string(config))
   104  	}
   105  
   106  	// Get the plugin
   107  	dev, cleanup, err := c.getDevicePlugin(binary)
   108  	if err != nil {
   109  		c.logger.Error("failed to launch device plugin", "error", err)
   110  		return 1
   111  	}
   112  	defer cleanup()
   113  	c.dev = dev
   114  
   115  	spec, err := c.getSpec()
   116  	if err != nil {
   117  		c.logger.Error("failed to get config spec", "error", err)
   118  		return 1
   119  	}
   120  	c.spec = spec
   121  
   122  	if err := c.setConfig(spec, device.ApiVersion010, config, nil); err != nil {
   123  		c.logger.Error("failed to set config", "error", err)
   124  		return 1
   125  	}
   126  
   127  	if err := c.startRepl(); err != nil {
   128  		c.logger.Error("error interacting with plugin", "error", err)
   129  		return 1
   130  	}
   131  
   132  	return 0
   133  }
   134  
   135  func (c *Device) getDevicePlugin(binary string) (device.DevicePlugin, func(), error) {
   136  	// Launch the plugin
   137  	client := plugin.NewClient(&plugin.ClientConfig{
   138  		HandshakeConfig: base.Handshake,
   139  		Plugins: map[string]plugin.Plugin{
   140  			base.PluginTypeBase:   &base.PluginBase{},
   141  			base.PluginTypeDevice: &device.PluginDevice{},
   142  		},
   143  		Cmd:              exec.Command(binary),
   144  		AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC},
   145  		Logger:           c.logger,
   146  	})
   147  
   148  	// Connect via RPC
   149  	rpcClient, err := client.Client()
   150  	if err != nil {
   151  		client.Kill()
   152  		return nil, nil, err
   153  	}
   154  
   155  	// Request the plugin
   156  	raw, err := rpcClient.Dispense(base.PluginTypeDevice)
   157  	if err != nil {
   158  		client.Kill()
   159  		return nil, nil, err
   160  	}
   161  
   162  	// We should have a KV store now! This feels like a normal interface
   163  	// implementation but is in fact over an RPC connection.
   164  	dev := raw.(device.DevicePlugin)
   165  	return dev, func() { client.Kill() }, nil
   166  }
   167  
   168  func (c *Device) getSpec() (hcldec.Spec, error) {
   169  	// Get the schema so we can parse the config
   170  	spec, err := c.dev.ConfigSchema()
   171  	if err != nil {
   172  		return nil, fmt.Errorf("failed to get config schema: %v", err)
   173  	}
   174  
   175  	c.logger.Trace("device spec", "spec", hclog.Fmt("% #v", pretty.Formatter(spec)))
   176  
   177  	// Convert the schema
   178  	schema, diag := hclspecutils.Convert(spec)
   179  	if diag.HasErrors() {
   180  		errStr := "failed to convert HCL schema: "
   181  		for _, err := range diag.Errs() {
   182  			errStr = fmt.Sprintf("%s\n* %s", errStr, err.Error())
   183  		}
   184  		return nil, errors.New(errStr)
   185  	}
   186  
   187  	return schema, nil
   188  }
   189  
   190  func (c *Device) setConfig(spec hcldec.Spec, apiVersion string, config []byte, nmdCfg *base.AgentConfig) error {
   191  	// Parse the config into hcl
   192  	configVal, err := hclConfigToInterface(config)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	c.logger.Trace("raw hcl config", "config", hclog.Fmt("% #v", pretty.Formatter(configVal)))
   198  
   199  	val, diag := hclutils.ParseHclInterface(configVal, spec, nil)
   200  	if diag.HasErrors() {
   201  		errStr := "failed to parse config"
   202  		for _, err := range diag.Errs() {
   203  			errStr = fmt.Sprintf("%s\n* %s", errStr, err.Error())
   204  		}
   205  		return errors.New(errStr)
   206  	}
   207  	c.logger.Trace("parsed hcl config", "config", hclog.Fmt("% #v", pretty.Formatter(val)))
   208  
   209  	cdata, err := msgpack.Marshal(val, val.Type())
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	req := &base.Config{
   215  		PluginConfig: config,
   216  		AgentConfig:  nmdCfg,
   217  		ApiVersion:   apiVersion,
   218  	}
   219  
   220  	c.logger.Trace("msgpack config", "config", string(cdata))
   221  	if err := c.dev.SetConfig(req); err != nil {
   222  		return err
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  func hclConfigToInterface(config []byte) (interface{}, error) {
   229  	if len(config) == 0 {
   230  		return map[string]interface{}{}, nil
   231  	}
   232  
   233  	// Parse as we do in the jobspec parser
   234  	root, err := hcl.Parse(string(config))
   235  	if err != nil {
   236  		return nil, fmt.Errorf("failed to hcl parse the config: %v", err)
   237  	}
   238  
   239  	// Top-level item should be a list
   240  	list, ok := root.Node.(*ast.ObjectList)
   241  	if !ok {
   242  		return nil, fmt.Errorf("root should be an object")
   243  	}
   244  
   245  	var m map[string]interface{}
   246  	if err := hcl.DecodeObject(&m, list.Items[0]); err != nil {
   247  		return nil, fmt.Errorf("failed to decode object: %v", err)
   248  	}
   249  
   250  	return m["config"], nil
   251  }
   252  
   253  func (c *Device) startRepl() error {
   254  	// Start the output goroutine
   255  	ctx, cancel := context.WithCancel(context.Background())
   256  	defer cancel()
   257  	fingerprint := make(chan context.Context)
   258  	stats := make(chan context.Context)
   259  	reserve := make(chan []string)
   260  	go c.replOutput(ctx, fingerprint, stats, reserve)
   261  
   262  	c.Ui.Output("> Availabile commands are: exit(), fingerprint(), stop_fingerprint(), stats(), stop_stats(), reserve(id1, id2, ...)")
   263  	var fingerprintCtx, statsCtx context.Context
   264  	var fingerprintCancel, statsCancel context.CancelFunc
   265  
   266  	for {
   267  		in, err := c.Ui.Ask("> ")
   268  		if err != nil {
   269  			if fingerprintCancel != nil {
   270  				fingerprintCancel()
   271  			}
   272  			if statsCancel != nil {
   273  				statsCancel()
   274  			}
   275  			return err
   276  		}
   277  
   278  		switch {
   279  		case in == "exit()":
   280  			if fingerprintCancel != nil {
   281  				fingerprintCancel()
   282  			}
   283  			if statsCancel != nil {
   284  				statsCancel()
   285  			}
   286  			return nil
   287  		case in == "fingerprint()":
   288  			if fingerprintCtx != nil {
   289  				continue
   290  			}
   291  			fingerprintCtx, fingerprintCancel = context.WithCancel(ctx)
   292  			fingerprint <- fingerprintCtx
   293  		case in == "stop_fingerprint()":
   294  			if fingerprintCtx == nil {
   295  				continue
   296  			}
   297  			fingerprintCancel()
   298  			fingerprintCtx = nil
   299  		case in == "stats()":
   300  			if statsCtx != nil {
   301  				continue
   302  			}
   303  			statsCtx, statsCancel = context.WithCancel(ctx)
   304  			stats <- statsCtx
   305  		case in == "stop_stats()":
   306  			if statsCtx == nil {
   307  				continue
   308  			}
   309  			statsCancel()
   310  			statsCtx = nil
   311  		case strings.HasPrefix(in, "reserve(") && strings.HasSuffix(in, ")"):
   312  			listString := strings.TrimSuffix(strings.TrimPrefix(in, "reserve("), ")")
   313  			ids := strings.Split(strings.TrimSpace(listString), ",")
   314  			reserve <- ids
   315  		default:
   316  			c.Ui.Error(fmt.Sprintf("> Unknown command %q", in))
   317  		}
   318  	}
   319  }
   320  
   321  func (c *Device) replOutput(ctx context.Context, startFingerprint, startStats <-chan context.Context, reserve <-chan []string) {
   322  	var fingerprint <-chan *device.FingerprintResponse
   323  	var stats <-chan *device.StatsResponse
   324  	for {
   325  		select {
   326  		case <-ctx.Done():
   327  			return
   328  		case ctx := <-startFingerprint:
   329  			var err error
   330  			fingerprint, err = c.dev.Fingerprint(ctx)
   331  			if err != nil {
   332  				c.Ui.Error(fmt.Sprintf("fingerprint: %s", err))
   333  				os.Exit(1)
   334  			}
   335  		case resp, ok := <-fingerprint:
   336  			if !ok {
   337  				c.Ui.Output("> fingerprint: fingerprint output closed")
   338  				fingerprint = nil
   339  				continue
   340  			}
   341  
   342  			if resp == nil {
   343  				c.Ui.Warn("> fingerprint: received nil result")
   344  				os.Exit(1)
   345  			}
   346  
   347  			c.Ui.Output(fmt.Sprintf("> fingerprint: % #v", pretty.Formatter(resp)))
   348  		case ctx := <-startStats:
   349  			var err error
   350  			stats, err = c.dev.Stats(ctx, 1*time.Second)
   351  			if err != nil {
   352  				c.Ui.Error(fmt.Sprintf("stats: %s", err))
   353  				os.Exit(1)
   354  			}
   355  		case resp, ok := <-stats:
   356  			if !ok {
   357  				c.Ui.Output("> stats: stats output closed")
   358  				stats = nil
   359  				continue
   360  			}
   361  
   362  			if resp == nil {
   363  				c.Ui.Warn("> stats: received nil result")
   364  				os.Exit(1)
   365  			}
   366  
   367  			c.Ui.Output(fmt.Sprintf("> stats: % #v", pretty.Formatter(resp)))
   368  		case ids := <-reserve:
   369  			resp, err := c.dev.Reserve(ids)
   370  			if err != nil {
   371  				c.Ui.Warn(fmt.Sprintf("> reserve(%s): %v", strings.Join(ids, ", "), err))
   372  			} else {
   373  				c.Ui.Output(fmt.Sprintf("> reserve(%s): % #v", strings.Join(ids, ", "), pretty.Formatter(resp)))
   374  			}
   375  		}
   376  	}
   377  }