github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/appbuilder/rpc/rpc.go (about) 1 // Copyright (c) 2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package rpc 6 7 import ( 8 "bytes" 9 "context" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io" 14 "os" 15 "reflect" 16 "strings" 17 "sync" 18 "time" 19 20 "github.com/AlecAivazis/survey/v2" 21 "github.com/choria-io/appbuilder/builder" 22 "github.com/choria-io/fisk" 23 "github.com/choria-io/go-choria/choria" 24 "github.com/choria-io/go-choria/client/discovery" 25 "github.com/choria-io/go-choria/config" 26 "github.com/choria-io/go-choria/inter" 27 "github.com/choria-io/go-choria/protocol" 28 "github.com/choria-io/go-choria/providers/agent/mcorpc/client" 29 "github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent" 30 "github.com/choria-io/go-choria/providers/agent/mcorpc/replyfmt" 31 "github.com/choria-io/go-choria/providers/appbuilder" 32 "github.com/gosuri/uiprogress" 33 "github.com/sirupsen/logrus" 34 ) 35 36 type Flag struct { 37 builder.GenericFlag 38 ReplyFilter string `json:"reply_filter"` 39 } 40 41 type Request struct { 42 Agent string `json:"agent"` 43 Action string `json:"action"` 44 Params map[string]string `json:"inputs"` 45 Filter *discovery.StandardOptions `json:"filter"` 46 } 47 48 type Command struct { 49 StandardFilter bool `json:"std_filters"` 50 OutputFormatFlags bool `json:"output_format_flags"` 51 OutputFormat string `json:"output_format"` 52 Display string `json:"display"` 53 DisplayFlag bool `json:"display_flag"` 54 BatchFlags bool `json:"batch_flags"` 55 BatchSize int `json:"batch"` 56 BatchSleep int `json:"batch_sleep"` 57 NoProgress bool `json:"no_progress"` 58 AllNodesConfirmPrompt string `json:"all_nodes_confirm_prompt"` 59 Flags []Flag `json:"flags"` 60 Request Request `json:"request"` 61 Transform *builder.Transform `json:"transform"` 62 63 builder.GenericCommand 64 builder.GenericSubCommands 65 } 66 67 type RPC struct { 68 b *builder.AppBuilder 69 cmd *fisk.CmdClause 70 fo *discovery.StandardOptions 71 def *Command 72 cfg any 73 arguments map[string]any 74 flags map[string]any 75 senders bool 76 json bool 77 table bool 78 display string 79 batch int 80 batchSleep int 81 progressBar *uiprogress.Bar 82 log builder.Logger 83 ctx context.Context 84 } 85 86 func NewRPCCommand(b *builder.AppBuilder, j json.RawMessage, log builder.Logger) (builder.Command, error) { 87 rpc := &RPC{ 88 arguments: map[string]any{}, 89 flags: map[string]any{}, 90 def: &Command{}, 91 cfg: b.Configuration(), 92 ctx: b.Context(), 93 b: b, 94 log: log, 95 } 96 97 err := json.Unmarshal(j, rpc.def) 98 if err != nil { 99 return nil, err 100 } 101 102 return rpc, nil 103 } 104 105 func Register() error { 106 return builder.RegisterCommand("rpc", NewRPCCommand) 107 } 108 109 func MustRegister() { 110 builder.MustRegisterCommand("rpc", NewRPCCommand) 111 } 112 113 func (r *RPC) String() string { return fmt.Sprintf("%s (rpc)", r.def.Name) } 114 115 func (r *RPC) Validate(log builder.Logger) error { 116 if r.def.Type != "rpc" { 117 return fmt.Errorf("not a rpc command") 118 } 119 120 var errs []string 121 122 err := r.def.GenericCommand.Validate(log) 123 if err != nil { 124 errs = append(errs, err.Error()) 125 } 126 127 if r.def.Transform != nil { 128 err := r.def.Transform.Validate(log) 129 if err != nil { 130 errs = append(errs, err.Error()) 131 } 132 } 133 134 if r.def.Request.Agent == "" { 135 errs = append(errs, "agent is required") 136 } 137 if r.def.Request.Action == "" { 138 errs = append(errs, "action is required") 139 } 140 141 if len(errs) > 0 { 142 return errors.New(strings.Join(errs, ", ")) 143 } 144 145 return nil 146 } 147 148 func (r *RPC) SubCommands() []json.RawMessage { 149 return r.def.Commands 150 } 151 152 func (r *RPC) CreateCommand(app builder.KingpinCommand) (*fisk.CmdClause, error) { 153 r.cmd = builder.CreateGenericCommand(app, &r.def.GenericCommand, r.arguments, r.flags, r.b, r.runCommand) 154 155 switch { 156 case r.def.OutputFormatFlags && r.def.OutputFormat != "": 157 return nil, fmt.Errorf("only one of output_format_flags and output_format may be supplied to command %s", r.def.Name) 158 159 case r.def.OutputFormatFlags: 160 r.cmd.Flag("senders", "List only the names of matching nodes").BoolVar(&r.senders) 161 r.cmd.Flag("json", "Render results as JSON").BoolVar(&r.json) 162 r.cmd.Flag("table", "Render results as a table").BoolVar(&r.table) 163 164 case r.def.OutputFormat == "senders": 165 r.senders = true 166 167 case r.def.OutputFormat == "json": 168 r.json = true 169 170 case r.def.OutputFormat == "table": 171 r.table = true 172 173 case r.def.OutputFormat != "": 174 return nil, fmt.Errorf("invalid output format %q, valid formats are senders, json and table", r.def.OutputFormat) 175 } 176 177 if r.def.StandardFilter { 178 r.fo = discovery.NewStandardOptions() 179 r.fo.AddFilterFlags(r.cmd) 180 r.fo.AddFlatFileFlags(r.cmd) 181 r.fo.AddSelectionFlags(r.cmd) 182 } 183 184 switch { 185 case r.def.BatchFlags && r.def.BatchSize > 0: 186 return nil, fmt.Errorf("only one of batch_flags and batch may be supplied in command %s", r.def.Name) 187 188 case r.def.BatchFlags: 189 r.cmd.Flag("batch", "Do requests in batches").PlaceHolder("SIZE").IntVar(&r.batch) 190 r.cmd.Flag("batch-sleep", "Sleep time between batches").PlaceHolder("SECONDS").IntVar(&r.batchSleep) 191 192 case r.def.BatchSize > 0: 193 r.batch = r.def.BatchSize 194 r.batchSleep = r.def.BatchSleep 195 } 196 197 switch { 198 case r.def.DisplayFlag && r.def.Display != "": 199 return nil, fmt.Errorf("only one of display_flag and display may be supplied in command %s", r.def.Name) 200 201 case r.def.DisplayFlag: 202 r.cmd.Flag("display", "Display only a subset of results (ok, failed, all, none)").EnumVar(&r.display, "ok", "failed", "all", "none") 203 204 case r.def.Display != "": 205 r.display = r.def.Display 206 } 207 208 // because we define our own concept of a flag here we cant rely on flag processing from appbuilder 209 // so have to basically duplicate all this from appbuilder code. 210 for _, f := range r.def.Flags { 211 flag := r.cmd.Flag(f.Name, f.Description) 212 if f.Required { 213 flag.Required() 214 } 215 216 if f.PlaceHolder != "" { 217 flag.PlaceHolder(f.PlaceHolder) 218 } 219 220 if f.Default != nil { 221 flag.Default(fmt.Sprintf("%v", f.Default)) 222 } 223 224 if f.EnvVar != "" { 225 flag.Envar(f.EnvVar) 226 } 227 228 if f.Short != "" { 229 flag.Short([]rune(f.Short)[0]) 230 } 231 232 switch { 233 case len(f.Enum) > 0: 234 r.flags[f.Name] = flag.Enum(f.Enum...) 235 case f.Bool: 236 if f.Default == true || f.Default == "true" { 237 r.flags[f.Name] = flag.Bool() 238 } else { 239 r.flags[f.Name] = flag.UnNegatableBool() 240 } 241 default: 242 r.flags[f.Name] = flag.String() 243 } 244 } 245 246 return r.cmd, nil 247 } 248 249 func (r *RPC) configureProgressBar(fw inter.Framework, count int, expected int) { 250 if r.def.NoProgress { 251 return 252 } 253 254 width := fw.ProgressWidth() 255 if width == -1 { 256 fmt.Printf("\nInvoking %s#%s action\n\n", r.def.Request.Agent, r.def.Request.Action) 257 return 258 } 259 260 r.progressBar = uiprogress.AddBar(count).AppendCompleted().PrependElapsed() 261 r.progressBar.Width = width 262 263 fmt.Println() 264 265 r.progressBar.PrependFunc(func(b *uiprogress.Bar) string { 266 if b.Current() < expected { 267 return fw.Colorize("red", "%d / %d", b.Current(), count) 268 } 269 270 return fw.Colorize("green", "%d / %d", b.Current(), count) 271 }) 272 273 uiprogress.Start() 274 } 275 276 func (r *RPC) setupFilter(fw inter.Framework) error { 277 var err error 278 279 if r.fo == nil { 280 r.fo = discovery.NewStandardOptions() 281 } 282 283 if r.def.Request.Filter != nil { 284 err = appbuilder.ProcessStdDiscoveryOptions(r.def.Request.Filter, r.arguments, r.flags, r.cfg) 285 if err != nil { 286 return err 287 } 288 289 r.fo.Merge(r.def.Request.Filter) 290 } 291 292 r.fo.SetDefaultsFromChoria(fw) 293 294 if r.def.AllNodesConfirmPrompt != "" && r.fo.NodesFile == "" { 295 f, err := r.fo.NewFilter(r.def.Request.Agent) 296 if err != nil { 297 return err 298 } 299 if f.Empty() { 300 ans := false 301 err := survey.AskOne(&survey.Confirm{Message: r.def.AllNodesConfirmPrompt, Default: false}, &ans) 302 if err != nil { 303 return err 304 } 305 if !ans { 306 return fmt.Errorf("aborted") 307 } 308 } 309 } 310 311 return nil 312 } 313 314 func (r *RPC) runCommand(_ *fisk.ParseContext) error { 315 var ( 316 noisy = !(r.json || r.senders || r.def.NoProgress || r.def.Transform != nil) 317 mu = sync.Mutex{} 318 dt time.Duration 319 targets []string 320 ) 321 322 cfg, err := config.NewConfig(choria.UserConfig()) 323 if err != nil { 324 return err 325 } 326 327 logger, ok := any(r.log).(*logrus.Logger) 328 if ok { 329 cfg.CustomLogger = logger 330 } 331 332 fw, err := choria.NewWithConfig(cfg) 333 if err != nil { 334 return err 335 } 336 337 log := fw.Logger(r.def.Name) 338 339 agent, err := client.New(fw, r.def.Request.Agent) 340 if err != nil { 341 return err 342 } 343 344 err = agent.ResolveDDL(r.ctx) 345 if err != nil { 346 return err 347 } 348 349 ddl := agent.DDL() 350 action, err := ddl.ActionInterface(r.def.Request.Action) 351 if err != nil { 352 return err 353 } 354 355 _, rpcInputs, opts, err := r.reqOptions(action) 356 if err != nil { 357 return err 358 } 359 360 results := &replyfmt.RPCResults{ 361 Agent: r.def.Request.Agent, 362 Action: r.def.Request.Action, 363 Replies: []*replyfmt.RPCReply{}, 364 } 365 366 opts = append(opts, client.ReplyHandler(func(pr protocol.Reply, reply *client.RPCReply) { 367 mu.Lock() 368 if reply != nil { 369 results.Replies = append(results.Replies, &replyfmt.RPCReply{Sender: pr.SenderID(), RPCReply: reply}) 370 if r.progressBar != nil { 371 r.progressBar.Incr() 372 } 373 } 374 mu.Unlock() 375 })) 376 377 err = r.setupFilter(fw) 378 if err != nil { 379 return err 380 } 381 382 start := time.Now() 383 targets, dt, err = r.fo.Discover(r.ctx, fw, r.def.Request.Agent, true, noisy, log) 384 if err != nil { 385 return err 386 } 387 if len(targets) == 0 { 388 return fmt.Errorf("no nodes discovered") 389 } 390 opts = append(opts, client.Targets(targets)) 391 392 if noisy { 393 if ddl.Metadata.Service { 394 r.configureProgressBar(fw, 1, 1) 395 } else { 396 r.configureProgressBar(fw, len(targets), len(targets)) 397 } 398 } 399 400 if r.batch > 0 { 401 if r.batchSleep == 0 { 402 r.batchSleep = 1 403 } 404 405 opts = append(opts, client.InBatches(r.batch, r.batchSleep)) 406 } 407 408 rpcres, err := agent.Do(r.ctx, r.def.Request.Action, rpcInputs, opts...) 409 if err != nil { 410 return err 411 } 412 results.Stats = rpcres.Stats() 413 414 if dt > 0 { 415 rpcres.Stats().OverrideDiscoveryTime(start, start.Add(dt)) 416 } 417 418 if r.progressBar != nil { 419 uiprogress.Stop() 420 fmt.Println() 421 } 422 423 err = r.renderResults(fw, log, results, action) 424 if err != nil { 425 return err 426 } 427 428 return nil 429 430 } 431 432 func (r *RPC) transformResults(w io.Writer, results *replyfmt.RPCResults, action *agent.Action) error { 433 out := bytes.NewBuffer([]byte{}) 434 err := results.RenderJSON(out, action) 435 if err != nil { 436 return err 437 } 438 439 res, err := r.def.Transform.TransformBytes(r.ctx, out.Bytes(), r.flags, r.arguments, r.b) 440 if err != nil { 441 return err 442 } 443 444 fmt.Fprintln(w, string(res)) 445 return nil 446 } 447 448 func (r *RPC) renderResults(fw inter.Framework, log *logrus.Entry, results *replyfmt.RPCResults, action *agent.Action) (err error) { 449 switch { 450 case r.def.Transform != nil: 451 err = r.transformResults(os.Stdout, results, action) 452 case r.senders: 453 err = results.RenderNames(os.Stdout, r.json, false) 454 case r.table: 455 err = results.RenderTable(os.Stdout, action) 456 case r.json: 457 err = results.RenderJSON(os.Stdout, action) 458 default: 459 mode := replyfmt.DisplayDDL 460 switch r.display { 461 case "ok": 462 mode = replyfmt.DisplayOK 463 case "failed": 464 mode = replyfmt.DisplayFailed 465 case "all": 466 mode = replyfmt.DisplayAll 467 case "none": 468 mode = replyfmt.DisplayNone 469 } 470 471 err = results.RenderTXT(os.Stdout, action, false, false, mode, fw.Configuration().Color, log) 472 } 473 474 return err 475 } 476 477 func (r *RPC) reqOptions(action *agent.Action) (inputs map[string]string, rpcInputs map[string]any, opts []client.RequestOption, err error) { 478 opts = []client.RequestOption{} 479 inputs = map[string]string{} 480 481 for k, v := range r.def.Request.Params { 482 body, err := r.parseStateTemplate(v) 483 if err != nil { 484 return nil, nil, nil, err 485 } 486 if len(body) > 0 { 487 inputs[k] = body 488 } 489 } 490 491 filter := "" 492 dFlags := dereferenceArgsOrFlags(r.flags) 493 for _, flag := range r.def.Flags { 494 if dFlags[flag.Name] != "" { 495 if flag.ReplyFilter == "" { 496 continue 497 } 498 499 if filter != "" { 500 return nil, nil, nil, fmt.Errorf("only one filter flag can match") 501 } 502 503 body, err := r.parseStateTemplate(flag.ReplyFilter) 504 if err != nil { 505 return nil, nil, nil, err 506 } 507 508 filter = body 509 break 510 } 511 } 512 513 if filter != "" { 514 opts = append(opts, client.ReplyExprFilter(filter)) 515 } 516 517 rpcInputs, _, err = action.ValidateAndConvertToDDLTypes(inputs) 518 if err != nil { 519 return nil, nil, nil, err 520 } 521 522 return inputs, rpcInputs, opts, nil 523 } 524 525 func (r *RPC) parseStateTemplate(body string) (string, error) { 526 return builder.ParseStateTemplate(body, r.arguments, r.flags, r.cfg) 527 } 528 529 func dereferenceArgsOrFlags(input map[string]any) map[string]any { 530 res := map[string]any{} 531 for k, v := range input { 532 e := reflect.ValueOf(v).Elem() 533 534 // the only kinds of values we support 535 if e.Kind() == reflect.Bool { 536 res[k] = e.Bool() 537 } else { 538 res[k] = e.String() 539 } 540 } 541 542 return res 543 }