github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/appbuilder/rpc/rpc.go (about)

     1  // Copyright (c) 2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package rpc
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"reflect"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/AlecAivazis/survey/v2"
    21  	"github.com/choria-io/appbuilder/builder"
    22  	"github.com/choria-io/fisk"
    23  	"github.com/choria-io/go-choria/choria"
    24  	"github.com/choria-io/go-choria/client/discovery"
    25  	"github.com/choria-io/go-choria/config"
    26  	"github.com/choria-io/go-choria/inter"
    27  	"github.com/choria-io/go-choria/protocol"
    28  	"github.com/choria-io/go-choria/providers/agent/mcorpc/client"
    29  	"github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent"
    30  	"github.com/choria-io/go-choria/providers/agent/mcorpc/replyfmt"
    31  	"github.com/choria-io/go-choria/providers/appbuilder"
    32  	"github.com/gosuri/uiprogress"
    33  	"github.com/sirupsen/logrus"
    34  )
    35  
// Flag is a command line flag definition for the rpc command type. It embeds
// the generic appbuilder flag and adds an optional reply filter.
type Flag struct {
	builder.GenericFlag
	// ReplyFilter is a reply filter expression associated with this flag.
	// reqOptions runs it through the state template parser and passes the
	// result to client.ReplyExprFilter when the flag is considered matched.
	ReplyFilter string `json:"reply_filter"`
}
    40  
// Request describes the RPC request a command performs.
type Request struct {
	// Agent is the agent to call; Validate rejects an empty value.
	Agent  string                     `json:"agent"`
	// Action is the action to invoke on the agent; Validate rejects an empty value.
	Action string                     `json:"action"`
	// Params are the request inputs. Each value is parsed as a state template
	// in reqOptions; values that render to an empty string are not sent.
	Params map[string]string          `json:"inputs"`
	// Filter holds optional discovery options that setupFilter processes and
	// merges into the options used for discovery.
	Filter *discovery.StandardOptions `json:"filter"`
}
    47  
// Command is the JSON definition of an rpc type command as decoded by
// NewRPCCommand.
type Command struct {
	// StandardFilter adds the standard filter, flat file and selection flags to the command.
	StandardFilter        bool               `json:"std_filters"`
	// OutputFormatFlags adds the --senders, --json and --table flags; it may not be combined with OutputFormat.
	OutputFormatFlags     bool               `json:"output_format_flags"`
	// OutputFormat fixes the output format; CreateCommand accepts senders, json or table.
	OutputFormat          string             `json:"output_format"`
	// Display fixes which results are shown in text output (ok, failed, all, none are recognised by renderResults).
	Display               string             `json:"display"`
	// DisplayFlag adds the --display flag; it may not be combined with Display.
	DisplayFlag           bool               `json:"display_flag"`
	// BatchFlags adds the --batch and --batch-sleep flags; it may not be combined with BatchSize.
	BatchFlags            bool               `json:"batch_flags"`
	// BatchSize fixes the batch size; values above 0 enable batching.
	BatchSize             int                `json:"batch"`
	// BatchSleep fixes the sleep between batches when BatchSize is set.
	BatchSleep            int                `json:"batch_sleep"`
	// NoProgress disables the progress bar and marks the run as not noisy.
	NoProgress            bool               `json:"no_progress"`
	// AllNodesConfirmPrompt, when set, is asked as a confirmation when the
	// resulting filter is empty and no nodes file is in use.
	AllNodesConfirmPrompt string             `json:"all_nodes_confirm_prompt"`
	// Flags are the command specific flags registered by CreateCommand.
	Flags                 []Flag             `json:"flags"`
	// Request is the RPC request to perform.
	Request               Request            `json:"request"`
	// Transform, when set, receives the JSON rendering of the results and its output is printed instead.
	Transform             *builder.Transform `json:"transform"`

	builder.GenericCommand
	builder.GenericSubCommands
}
    66  
// RPC is the appbuilder command implementation for commands of type "rpc".
type RPC struct {
	// b is the app builder the command belongs to.
	b           *builder.AppBuilder
	// cmd is the fisk command created by CreateCommand.
	cmd         *fisk.CmdClause
	// fo holds discovery options; created in CreateCommand when std_filters
	// is enabled, otherwise lazily in setupFilter.
	fo          *discovery.StandardOptions
	// def is the parsed command definition.
	def         *Command
	// cfg is the configuration obtained from the builder, used in templates.
	cfg         any
	// arguments and flags hold the parsed command line values used for templating.
	arguments   map[string]any
	flags       map[string]any
	// senders, json and table select the output format.
	senders     bool
	json        bool
	table       bool
	// display selects which results the text renderer shows.
	display     string
	// batch and batchSleep configure batched requests; batch 0 means no batching.
	batch       int
	batchSleep  int
	// progressBar is set by configureProgressBar when a bar is being drawn.
	progressBar *uiprogress.Bar
	log         builder.Logger
	// ctx is the context obtained from the builder, used for RPC calls and discovery.
	ctx         context.Context
}
    85  
    86  func NewRPCCommand(b *builder.AppBuilder, j json.RawMessage, log builder.Logger) (builder.Command, error) {
    87  	rpc := &RPC{
    88  		arguments: map[string]any{},
    89  		flags:     map[string]any{},
    90  		def:       &Command{},
    91  		cfg:       b.Configuration(),
    92  		ctx:       b.Context(),
    93  		b:         b,
    94  		log:       log,
    95  	}
    96  
    97  	err := json.Unmarshal(j, rpc.def)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	return rpc, nil
   103  }
   104  
// Register registers NewRPCCommand with the builder as the constructor for
// commands of type "rpc" and returns any registration error.
func Register() error {
	return builder.RegisterCommand("rpc", NewRPCCommand)
}
   108  
// MustRegister registers NewRPCCommand with the builder as the constructor
// for commands of type "rpc" without returning an error.
// NOTE(review): presumably panics on failure, per the Must naming — confirm in builder.
func MustRegister() {
	builder.MustRegisterCommand("rpc", NewRPCCommand)
}
   112  
// String returns the command name followed by " (rpc)".
func (r *RPC) String() string { return fmt.Sprintf("%s (rpc)", r.def.Name) }
   114  
   115  func (r *RPC) Validate(log builder.Logger) error {
   116  	if r.def.Type != "rpc" {
   117  		return fmt.Errorf("not a rpc command")
   118  	}
   119  
   120  	var errs []string
   121  
   122  	err := r.def.GenericCommand.Validate(log)
   123  	if err != nil {
   124  		errs = append(errs, err.Error())
   125  	}
   126  
   127  	if r.def.Transform != nil {
   128  		err := r.def.Transform.Validate(log)
   129  		if err != nil {
   130  			errs = append(errs, err.Error())
   131  		}
   132  	}
   133  
   134  	if r.def.Request.Agent == "" {
   135  		errs = append(errs, "agent is required")
   136  	}
   137  	if r.def.Request.Action == "" {
   138  		errs = append(errs, "action is required")
   139  	}
   140  
   141  	if len(errs) > 0 {
   142  		return errors.New(strings.Join(errs, ", "))
   143  	}
   144  
   145  	return nil
   146  }
   147  
// SubCommands returns the raw JSON definitions of the sub commands declared
// in the command definition.
func (r *RPC) SubCommands() []json.RawMessage {
	return r.def.Commands
}
   151  
   152  func (r *RPC) CreateCommand(app builder.KingpinCommand) (*fisk.CmdClause, error) {
   153  	r.cmd = builder.CreateGenericCommand(app, &r.def.GenericCommand, r.arguments, r.flags, r.b, r.runCommand)
   154  
   155  	switch {
   156  	case r.def.OutputFormatFlags && r.def.OutputFormat != "":
   157  		return nil, fmt.Errorf("only one of output_format_flags and output_format may be supplied to command %s", r.def.Name)
   158  
   159  	case r.def.OutputFormatFlags:
   160  		r.cmd.Flag("senders", "List only the names of matching nodes").BoolVar(&r.senders)
   161  		r.cmd.Flag("json", "Render results as JSON").BoolVar(&r.json)
   162  		r.cmd.Flag("table", "Render results as a table").BoolVar(&r.table)
   163  
   164  	case r.def.OutputFormat == "senders":
   165  		r.senders = true
   166  
   167  	case r.def.OutputFormat == "json":
   168  		r.json = true
   169  
   170  	case r.def.OutputFormat == "table":
   171  		r.table = true
   172  
   173  	case r.def.OutputFormat != "":
   174  		return nil, fmt.Errorf("invalid output format %q, valid formats are senders, json and table", r.def.OutputFormat)
   175  	}
   176  
   177  	if r.def.StandardFilter {
   178  		r.fo = discovery.NewStandardOptions()
   179  		r.fo.AddFilterFlags(r.cmd)
   180  		r.fo.AddFlatFileFlags(r.cmd)
   181  		r.fo.AddSelectionFlags(r.cmd)
   182  	}
   183  
   184  	switch {
   185  	case r.def.BatchFlags && r.def.BatchSize > 0:
   186  		return nil, fmt.Errorf("only one of batch_flags and batch may be supplied in command %s", r.def.Name)
   187  
   188  	case r.def.BatchFlags:
   189  		r.cmd.Flag("batch", "Do requests in batches").PlaceHolder("SIZE").IntVar(&r.batch)
   190  		r.cmd.Flag("batch-sleep", "Sleep time between batches").PlaceHolder("SECONDS").IntVar(&r.batchSleep)
   191  
   192  	case r.def.BatchSize > 0:
   193  		r.batch = r.def.BatchSize
   194  		r.batchSleep = r.def.BatchSleep
   195  	}
   196  
   197  	switch {
   198  	case r.def.DisplayFlag && r.def.Display != "":
   199  		return nil, fmt.Errorf("only one of display_flag and display may be supplied in command %s", r.def.Name)
   200  
   201  	case r.def.DisplayFlag:
   202  		r.cmd.Flag("display", "Display only a subset of results (ok, failed, all, none)").EnumVar(&r.display, "ok", "failed", "all", "none")
   203  
   204  	case r.def.Display != "":
   205  		r.display = r.def.Display
   206  	}
   207  
   208  	// because we define our own concept of a flag here we cant rely on flag processing from appbuilder
   209  	// so have to basically duplicate all this from appbuilder code.
   210  	for _, f := range r.def.Flags {
   211  		flag := r.cmd.Flag(f.Name, f.Description)
   212  		if f.Required {
   213  			flag.Required()
   214  		}
   215  
   216  		if f.PlaceHolder != "" {
   217  			flag.PlaceHolder(f.PlaceHolder)
   218  		}
   219  
   220  		if f.Default != nil {
   221  			flag.Default(fmt.Sprintf("%v", f.Default))
   222  		}
   223  
   224  		if f.EnvVar != "" {
   225  			flag.Envar(f.EnvVar)
   226  		}
   227  
   228  		if f.Short != "" {
   229  			flag.Short([]rune(f.Short)[0])
   230  		}
   231  
   232  		switch {
   233  		case len(f.Enum) > 0:
   234  			r.flags[f.Name] = flag.Enum(f.Enum...)
   235  		case f.Bool:
   236  			if f.Default == true || f.Default == "true" {
   237  				r.flags[f.Name] = flag.Bool()
   238  			} else {
   239  				r.flags[f.Name] = flag.UnNegatableBool()
   240  			}
   241  		default:
   242  			r.flags[f.Name] = flag.String()
   243  		}
   244  	}
   245  
   246  	return r.cmd, nil
   247  }
   248  
   249  func (r *RPC) configureProgressBar(fw inter.Framework, count int, expected int) {
   250  	if r.def.NoProgress {
   251  		return
   252  	}
   253  
   254  	width := fw.ProgressWidth()
   255  	if width == -1 {
   256  		fmt.Printf("\nInvoking %s#%s action\n\n", r.def.Request.Agent, r.def.Request.Action)
   257  		return
   258  	}
   259  
   260  	r.progressBar = uiprogress.AddBar(count).AppendCompleted().PrependElapsed()
   261  	r.progressBar.Width = width
   262  
   263  	fmt.Println()
   264  
   265  	r.progressBar.PrependFunc(func(b *uiprogress.Bar) string {
   266  		if b.Current() < expected {
   267  			return fw.Colorize("red", "%d / %d", b.Current(), count)
   268  		}
   269  
   270  		return fw.Colorize("green", "%d / %d", b.Current(), count)
   271  	})
   272  
   273  	uiprogress.Start()
   274  }
   275  
   276  func (r *RPC) setupFilter(fw inter.Framework) error {
   277  	var err error
   278  
   279  	if r.fo == nil {
   280  		r.fo = discovery.NewStandardOptions()
   281  	}
   282  
   283  	if r.def.Request.Filter != nil {
   284  		err = appbuilder.ProcessStdDiscoveryOptions(r.def.Request.Filter, r.arguments, r.flags, r.cfg)
   285  		if err != nil {
   286  			return err
   287  		}
   288  
   289  		r.fo.Merge(r.def.Request.Filter)
   290  	}
   291  
   292  	r.fo.SetDefaultsFromChoria(fw)
   293  
   294  	if r.def.AllNodesConfirmPrompt != "" && r.fo.NodesFile == "" {
   295  		f, err := r.fo.NewFilter(r.def.Request.Agent)
   296  		if err != nil {
   297  			return err
   298  		}
   299  		if f.Empty() {
   300  			ans := false
   301  			err := survey.AskOne(&survey.Confirm{Message: r.def.AllNodesConfirmPrompt, Default: false}, &ans)
   302  			if err != nil {
   303  				return err
   304  			}
   305  			if !ans {
   306  				return fmt.Errorf("aborted")
   307  			}
   308  		}
   309  	}
   310  
   311  	return nil
   312  }
   313  
   314  func (r *RPC) runCommand(_ *fisk.ParseContext) error {
   315  	var (
   316  		noisy   = !(r.json || r.senders || r.def.NoProgress || r.def.Transform != nil)
   317  		mu      = sync.Mutex{}
   318  		dt      time.Duration
   319  		targets []string
   320  	)
   321  
   322  	cfg, err := config.NewConfig(choria.UserConfig())
   323  	if err != nil {
   324  		return err
   325  	}
   326  
   327  	logger, ok := any(r.log).(*logrus.Logger)
   328  	if ok {
   329  		cfg.CustomLogger = logger
   330  	}
   331  
   332  	fw, err := choria.NewWithConfig(cfg)
   333  	if err != nil {
   334  		return err
   335  	}
   336  
   337  	log := fw.Logger(r.def.Name)
   338  
   339  	agent, err := client.New(fw, r.def.Request.Agent)
   340  	if err != nil {
   341  		return err
   342  	}
   343  
   344  	err = agent.ResolveDDL(r.ctx)
   345  	if err != nil {
   346  		return err
   347  	}
   348  
   349  	ddl := agent.DDL()
   350  	action, err := ddl.ActionInterface(r.def.Request.Action)
   351  	if err != nil {
   352  		return err
   353  	}
   354  
   355  	_, rpcInputs, opts, err := r.reqOptions(action)
   356  	if err != nil {
   357  		return err
   358  	}
   359  
   360  	results := &replyfmt.RPCResults{
   361  		Agent:   r.def.Request.Agent,
   362  		Action:  r.def.Request.Action,
   363  		Replies: []*replyfmt.RPCReply{},
   364  	}
   365  
   366  	opts = append(opts, client.ReplyHandler(func(pr protocol.Reply, reply *client.RPCReply) {
   367  		mu.Lock()
   368  		if reply != nil {
   369  			results.Replies = append(results.Replies, &replyfmt.RPCReply{Sender: pr.SenderID(), RPCReply: reply})
   370  			if r.progressBar != nil {
   371  				r.progressBar.Incr()
   372  			}
   373  		}
   374  		mu.Unlock()
   375  	}))
   376  
   377  	err = r.setupFilter(fw)
   378  	if err != nil {
   379  		return err
   380  	}
   381  
   382  	start := time.Now()
   383  	targets, dt, err = r.fo.Discover(r.ctx, fw, r.def.Request.Agent, true, noisy, log)
   384  	if err != nil {
   385  		return err
   386  	}
   387  	if len(targets) == 0 {
   388  		return fmt.Errorf("no nodes discovered")
   389  	}
   390  	opts = append(opts, client.Targets(targets))
   391  
   392  	if noisy {
   393  		if ddl.Metadata.Service {
   394  			r.configureProgressBar(fw, 1, 1)
   395  		} else {
   396  			r.configureProgressBar(fw, len(targets), len(targets))
   397  		}
   398  	}
   399  
   400  	if r.batch > 0 {
   401  		if r.batchSleep == 0 {
   402  			r.batchSleep = 1
   403  		}
   404  
   405  		opts = append(opts, client.InBatches(r.batch, r.batchSleep))
   406  	}
   407  
   408  	rpcres, err := agent.Do(r.ctx, r.def.Request.Action, rpcInputs, opts...)
   409  	if err != nil {
   410  		return err
   411  	}
   412  	results.Stats = rpcres.Stats()
   413  
   414  	if dt > 0 {
   415  		rpcres.Stats().OverrideDiscoveryTime(start, start.Add(dt))
   416  	}
   417  
   418  	if r.progressBar != nil {
   419  		uiprogress.Stop()
   420  		fmt.Println()
   421  	}
   422  
   423  	err = r.renderResults(fw, log, results, action)
   424  	if err != nil {
   425  		return err
   426  	}
   427  
   428  	return nil
   429  
   430  }
   431  
   432  func (r *RPC) transformResults(w io.Writer, results *replyfmt.RPCResults, action *agent.Action) error {
   433  	out := bytes.NewBuffer([]byte{})
   434  	err := results.RenderJSON(out, action)
   435  	if err != nil {
   436  		return err
   437  	}
   438  
   439  	res, err := r.def.Transform.TransformBytes(r.ctx, out.Bytes(), r.flags, r.arguments, r.b)
   440  	if err != nil {
   441  		return err
   442  	}
   443  
   444  	fmt.Fprintln(w, string(res))
   445  	return nil
   446  }
   447  
   448  func (r *RPC) renderResults(fw inter.Framework, log *logrus.Entry, results *replyfmt.RPCResults, action *agent.Action) (err error) {
   449  	switch {
   450  	case r.def.Transform != nil:
   451  		err = r.transformResults(os.Stdout, results, action)
   452  	case r.senders:
   453  		err = results.RenderNames(os.Stdout, r.json, false)
   454  	case r.table:
   455  		err = results.RenderTable(os.Stdout, action)
   456  	case r.json:
   457  		err = results.RenderJSON(os.Stdout, action)
   458  	default:
   459  		mode := replyfmt.DisplayDDL
   460  		switch r.display {
   461  		case "ok":
   462  			mode = replyfmt.DisplayOK
   463  		case "failed":
   464  			mode = replyfmt.DisplayFailed
   465  		case "all":
   466  			mode = replyfmt.DisplayAll
   467  		case "none":
   468  			mode = replyfmt.DisplayNone
   469  		}
   470  
   471  		err = results.RenderTXT(os.Stdout, action, false, false, mode, fw.Configuration().Color, log)
   472  	}
   473  
   474  	return err
   475  }
   476  
   477  func (r *RPC) reqOptions(action *agent.Action) (inputs map[string]string, rpcInputs map[string]any, opts []client.RequestOption, err error) {
   478  	opts = []client.RequestOption{}
   479  	inputs = map[string]string{}
   480  
   481  	for k, v := range r.def.Request.Params {
   482  		body, err := r.parseStateTemplate(v)
   483  		if err != nil {
   484  			return nil, nil, nil, err
   485  		}
   486  		if len(body) > 0 {
   487  			inputs[k] = body
   488  		}
   489  	}
   490  
   491  	filter := ""
   492  	dFlags := dereferenceArgsOrFlags(r.flags)
   493  	for _, flag := range r.def.Flags {
   494  		if dFlags[flag.Name] != "" {
   495  			if flag.ReplyFilter == "" {
   496  				continue
   497  			}
   498  
   499  			if filter != "" {
   500  				return nil, nil, nil, fmt.Errorf("only one filter flag can match")
   501  			}
   502  
   503  			body, err := r.parseStateTemplate(flag.ReplyFilter)
   504  			if err != nil {
   505  				return nil, nil, nil, err
   506  			}
   507  
   508  			filter = body
   509  			break
   510  		}
   511  	}
   512  
   513  	if filter != "" {
   514  		opts = append(opts, client.ReplyExprFilter(filter))
   515  	}
   516  
   517  	rpcInputs, _, err = action.ValidateAndConvertToDDLTypes(inputs)
   518  	if err != nil {
   519  		return nil, nil, nil, err
   520  	}
   521  
   522  	return inputs, rpcInputs, opts, nil
   523  }
   524  
// parseStateTemplate renders body with builder.ParseStateTemplate, giving it
// the command's arguments, flags and configuration.
func (r *RPC) parseStateTemplate(body string) (string, error) {
	return builder.ParseStateTemplate(body, r.arguments, r.flags, r.cfg)
}
   528  
   529  func dereferenceArgsOrFlags(input map[string]any) map[string]any {
   530  	res := map[string]any{}
   531  	for k, v := range input {
   532  		e := reflect.ValueOf(v).Elem()
   533  
   534  		// the only kinds of values we support
   535  		if e.Kind() == reflect.Bool {
   536  			res[k] = e.Bool()
   537  		} else {
   538  			res[k] = e.String()
   539  		}
   540  	}
   541  
   542  	return res
   543  }