github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/client/client.go (about) 1 // Copyright (c) 2018-2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package client 6 7 import ( 8 "context" 9 "encoding/json" 10 "fmt" 11 "sync" 12 13 "github.com/choria-io/go-choria/inter" 14 "github.com/expr-lang/expr/vm" 15 16 "github.com/choria-io/go-choria/config" 17 "github.com/choria-io/go-choria/providers/discovery/broadcast" 18 "github.com/choria-io/go-choria/providers/discovery/puppetdb" 19 20 cclient "github.com/choria-io/go-choria/client/client" 21 "github.com/choria-io/go-choria/protocol" 22 "github.com/choria-io/go-choria/providers/agent/mcorpc" 23 addl "github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent" 24 25 "github.com/sirupsen/logrus" 26 ) 27 28 // RPC is a MCollective compatible RPC client 29 type RPC struct { 30 fw inter.Framework 31 opts *RequestOptions 32 log *logrus.Entry 33 cfg *config.Config 34 dm string 35 36 agent string 37 38 mu *sync.Mutex 39 40 ddl *addl.DDL 41 42 // used for testing only 43 cl ChoriaClient 44 } 45 46 // RPCRequest is a basic RPC request 47 type RPCRequest struct { 48 Agent string `json:"agent"` 49 Action string `json:"action"` 50 Data json.RawMessage `json:"data"` 51 } 52 53 // RequestResult is the result of a request 54 type RequestResult interface { 55 Stats() *Stats 56 } 57 58 // Handler is a function that should handle each reply synchronously 59 type Handler func(protocol.Reply, *RPCReply) 60 61 // ChoriaClient implements the connection to the Choria network 62 type ChoriaClient interface { 63 Request(ctx context.Context, msg inter.Message, handler cclient.Handler) (err error) 64 } 65 66 // Connector is a connection to the choria network 67 type Connector interface { 68 QueueSubscribe(ctx context.Context, name string, subject string, group string, output chan inter.ConnectorMessage) error 69 Publish(msg inter.Message) error 70 } 71 72 // Option configures the RPC client 73 type Option func(r *RPC) 74 75 // DDL supplies a DDL when creating the client thus avoiding a disk search 76 func DDL(d *addl.DDL) Option { 77 return func(r *RPC) { 78 r.ddl = d 79 } 80 } 81 82 // DiscoveryMethod sets a specific discovery method 83 func DiscoveryMethod(dm string) Option { 84 return func(r *RPC) { 85 r.dm = dm 86 } 87 } 88 89 // New creates a new RPC request 90 // 91 // A DDL is required when one is not given using the DDL() option as argument 92 // attempts will be made to find it on the file system should this fail an error will be returned 93 func New(fw inter.Framework, agent string, opts ...Option) (rpc *RPC, err error) { 94 rpc = &RPC{ 95 fw: fw, 96 cfg: fw.Configuration(), 97 mu: &sync.Mutex{}, 98 log: fw.Logger("rpc"), 99 agent: agent, 100 dm: fw.Configuration().DefaultDiscoveryMethod, 101 } 102 103 for _, opt := range opts { 104 opt(rpc) 105 } 106 107 return rpc, nil 108 } 109 110 func (r *RPC) setOptions(opts ...RequestOption) (err error) { 111 r.opts, err = NewRequestOptions(r.fw, r.ddl) 112 if err != nil { 113 return err 114 } 115 116 for _, opt := range opts { 117 opt(r.opts) 118 } 119 120 if r.ddl.Metadata.Service { 121 r.opts.Workers = 1 122 r.opts.RequestType = inter.ServiceRequestMessageType 123 } 124 125 return nil 126 } 127 128 func (r *RPC) ResolveDDL(ctx context.Context) error { 129 r.mu.Lock() 130 defer r.mu.Unlock() 131 132 if r.ddl != nil { 133 return nil 134 } 135 136 ddl, err := r.resolveDDL(ctx) 137 if err != nil { 138 return err 139 } 140 141 r.ddl = ddl 142 143 return nil 144 } 145 146 func (r *RPC) resolveDDL(ctx context.Context) (*addl.DDL, error) { 147 if r.ddl != nil { 148 return r.ddl, nil 149 } 150 151 resolvers, err := r.fw.DDLResolvers() 152 if err != nil { 153 return nil, err 154 } 155 156 for _, resolver := range resolvers { 157 r.log.Debugf("Attempting to resolve agent DDL using %s", resolver) 158 159 data, err := resolver.DDLBytes(ctx, "agent", r.agent, r.fw) 160 if err == nil { 161 ddl, err := addl.NewFromBytes(data) 162 if err != nil { 163 return nil, err 164 } 165 r.ddl = ddl 166 167 return ddl, nil 168 } else { 169 r.log.Debugf("DDL Resolution failed: %s", err) 170 } 171 } 172 173 return nil, fmt.Errorf("could not resolve %s DDL in any resolver: %s", r.agent, err) 174 } 175 176 // DDL returns the active DDL for this client 177 func (r *RPC) DDL() *addl.DDL { 178 r.mu.Lock() 179 defer r.mu.Unlock() 180 181 return r.ddl 182 } 183 184 // Do perform a RPC request and optionally processes replies 185 // 186 // If a filter is supplied using the Filter() option and Targets() are not then discovery will be done for you 187 // using the broadcast method, should no nodes be discovered an error will be returned 188 func (r *RPC) Do(ctx context.Context, action string, payload any, opts ...RequestOption) (RequestResult, error) { 189 r.mu.Lock() 190 defer r.mu.Unlock() 191 192 _, err := r.resolveDDL(ctx) 193 if err != nil { 194 return nil, err 195 } 196 197 if r.ddl.Metadata.Name != r.agent { 198 return nil, fmt.Errorf("the DDL does not describe the %s agent", r.agent) 199 } 200 201 // we want to force the passing of options on every request 202 err = r.setOptions(opts...) 203 if err != nil { 204 return nil, err 205 } 206 207 dctx, cancel := context.WithCancel(ctx) 208 defer cancel() 209 210 if len(r.opts.Targets) == 0 && (r.opts.RequestType != inter.ServiceRequestMessageType || !r.ddl.Metadata.Service) { 211 err = r.discover(ctx) 212 if err != nil { 213 return nil, fmt.Errorf("discovery failed: %s", err) 214 } 215 } 216 217 discoveredCnt := len(r.opts.Targets) 218 msg, cl, err := r.setupMessage(dctx, action, payload, opts...) 219 if err != nil { 220 return nil, fmt.Errorf("could not configure message: %s", err) 221 } 222 223 if r.opts.DiscoveryEndCB != nil { 224 err = r.opts.DiscoveryEndCB(discoveredCnt, len(r.opts.Targets)) 225 if err != nil { 226 return nil, err 227 } 228 } 229 230 r.opts.totalStats.Start() 231 defer r.opts.totalStats.End() 232 233 r.opts.totalStats.SetAction(action) 234 r.opts.totalStats.SetAgent(r.agent) 235 236 switch r.opts.RequestType { 237 case inter.ServiceRequestMessageType: 238 err = r.doServiceRequest(dctx, msg, cl) 239 default: 240 err = r.doBatchedRequest(ctx, msg, cl) 241 } 242 243 return &RequestOptions{totalStats: r.opts.totalStats}, err 244 } 245 246 func (r *RPC) doBatchedRequest(ctx context.Context, msg inter.Message, cl ChoriaClient) error { 247 // the client is always batched, when batched mode is not request the size of 248 // the batch matches the size of the total targets and during setupMessage() 249 // an appropriate connection will be made 250 251 ctr := 0 252 253 return InGroups(r.opts.Targets, r.opts.BatchSize, func(nodes []string) error { 254 stats := NewStats() 255 stats.SetDiscoveredNodes(nodes) 256 msg.SetDiscoveredHosts(nodes) 257 258 stats.Start() 259 defer func(s *Stats) { 260 s.End() 261 r.opts.totalStats.Merge(s) 262 }(stats) 263 264 if ctr > 0 { 265 err := InterruptableSleep(ctx, r.opts.BatchSleep) 266 if err != nil { 267 return err 268 } 269 } 270 271 r.log.Debugf("Performing batched request %d for %d/%d nodes", ctr, len(nodes), len(r.opts.Targets)) 272 273 err := r.request(ctx, msg, cl, stats) 274 if err != nil { 275 return err 276 } 277 278 ctr++ 279 280 return nil 281 }) 282 } 283 284 func (r *RPC) doServiceRequest(ctx context.Context, msg inter.Message, cl ChoriaClient) error { 285 stats := NewStats() 286 287 var responded []string 288 handler := r.opts.Handler 289 r.opts.Handler = func(r protocol.Reply, rpc *RPCReply) { 290 responded = append(responded, r.SenderID()) 291 if handler != nil { 292 handler(r, rpc) 293 } 294 } 295 296 err := r.request(ctx, msg, cl, stats) 297 if len(responded) > 0 { 298 stats.SetDiscoveredNodes(responded) 299 r.opts.totalStats.SetDiscoveredNodes(responded) 300 stats.RecordReceived(responded[0]) 301 } 302 303 stats.End() 304 305 r.opts.totalStats.Merge(stats) 306 307 return err 308 } 309 310 func (r *RPC) discover(ctx context.Context) error { 311 if len(r.opts.Filter.Compound) > 0 { 312 r.dm = "choria" 313 } 314 if r.opts.DiscoveryStartCB != nil { 315 r.opts.DiscoveryStartCB() 316 } 317 318 r.opts.totalStats.StartDiscover() 319 defer r.opts.totalStats.EndDiscover() 320 321 if r.opts.Filter == nil { 322 r.opts.Filter = protocol.NewFilter() 323 } 324 325 r.opts.Filter.AddAgentFilter(r.agent) 326 327 var n []string 328 var err error 329 330 // TODO: other discovery options? honestly the magical discovery here should just go 331 switch r.dm { 332 case "choria": 333 pdb := puppetdb.New(r.fw) 334 n, err = pdb.Discover(ctx, puppetdb.Filter(r.opts.Filter), puppetdb.Timeout(r.opts.DiscoveryTimeout), puppetdb.Collective(r.opts.Collective)) 335 336 default: 337 b := broadcast.New(r.fw) 338 n, err = b.Discover(ctx, broadcast.Filter(r.opts.Filter), broadcast.Timeout(r.opts.DiscoveryTimeout), broadcast.Name(r.opts.ConnectionName), broadcast.Collective(r.opts.Collective)) 339 } 340 if err != nil { 341 return err 342 } 343 344 if len(n) == 0 { 345 return fmt.Errorf("no targets were discovered") 346 } 347 348 r.opts.Targets = n 349 350 return nil 351 } 352 353 func (r *RPC) setupMessage(ctx context.Context, action string, payload any, opts ...RequestOption) (msg inter.Message, cl ChoriaClient, err error) { 354 pj, err := json.Marshal(payload) 355 if err != nil { 356 return nil, nil, fmt.Errorf("could not encode payload: %s", err) 357 } 358 359 rpcreq := &RPCRequest{ 360 Agent: r.agent, 361 Action: action, 362 Data: pj, 363 } 364 365 rpcp, err := json.Marshal(rpcreq) 366 if err != nil { 367 return nil, nil, fmt.Errorf("could not encode request: %s", err) 368 } 369 370 msgType := inter.RequestMessageType 371 if r.ddl.Metadata.Service { 372 msgType = inter.ServiceRequestMessageType 373 r.opts.Workers = 1 374 } 375 376 msg, err = r.fw.NewMessage(rpcp, r.agent, r.cfg.MainCollective, msgType, nil) 377 if err != nil { 378 return nil, nil, err 379 } 380 381 err = r.opts.ConfigureMessage(msg) 382 if err != nil { 383 return nil, nil, fmt.Errorf("could not configure Message: %s", err) 384 } 385 386 cl = r.cl 387 388 if r.cl == nil { 389 if r.opts.BatchSize == len(r.opts.Targets) || !r.opts.ProcessReplies { 390 cl, err = r.unbatchedClient() 391 if err != nil { 392 return nil, nil, err 393 } 394 } else { 395 cl, err = r.batchedClient(ctx, msg.RequestID()) 396 if err != nil { 397 return nil, nil, err 398 } 399 } 400 } 401 402 return msg, cl, err 403 } 404 405 func (r *RPC) unbatchedClient() (cl ChoriaClient, err error) { 406 cl, err = cclient.New( 407 r.fw, 408 cclient.Receivers(r.opts.Workers), 409 cclient.Timeout(r.opts.Timeout), 410 cclient.OnPublishStart(r.opts.totalStats.StartPublish), 411 cclient.OnPublishFinish(r.opts.totalStats.EndPublish), 412 cclient.Name(r.opts.ConnectionName), 413 ) 414 if err != nil { 415 return nil, fmt.Errorf("could not setup client: %s", err) 416 } 417 418 return cl, nil 419 } 420 421 func (r *RPC) batchedClient(ctx context.Context, msgid string) (cl ChoriaClient, err error) { 422 conn, err := r.connectBatchedConnection(ctx, fmt.Sprintf("%s_%s_batched", r.opts.ConnectionName, msgid)) 423 if err != nil { 424 return nil, fmt.Errorf("could not connect batched network connection: %s", err) 425 } 426 427 cl, err = cclient.New( 428 r.fw, 429 cclient.Receivers(r.opts.Workers), 430 cclient.Timeout(r.opts.Timeout), 431 cclient.OnPublishStart(r.opts.totalStats.StartPublish), 432 cclient.OnPublishFinish(r.opts.totalStats.EndPublish), 433 cclient.Connection(conn), 434 cclient.Name(r.opts.ConnectionName), 435 ) 436 if err != nil { 437 return nil, fmt.Errorf("could not set up batched client: %s", err) 438 } 439 440 return cl, nil 441 } 442 443 // Reset removes the cached options, any further Do() calls need to specify full options 444 func (r *RPC) Reset() { 445 r.mu.Lock() 446 defer r.mu.Unlock() 447 448 r.opts = nil 449 r.cl = nil 450 } 451 452 func (r *RPC) request(ctx context.Context, msg inter.Message, cl ChoriaClient, stats *Stats) error { 453 rctx, cancel := context.WithCancel(ctx) 454 defer cancel() 455 456 err := cl.Request(rctx, msg, r.handlerFactory(rctx, cancel, stats)) 457 if err != nil { 458 return err 459 } 460 461 return nil 462 } 463 464 func (r *RPC) handlerFactory(_ context.Context, cancel context.CancelFunc, stats *Stats) cclient.Handler { 465 if !r.opts.ProcessReplies { 466 return nil 467 } 468 469 var prog *vm.Program 470 471 handler := func(ctx context.Context, rawmsg inter.ConnectorMessage) { 472 reply, err := r.fw.NewReplyFromTransportJSON(rawmsg.Data(), false) 473 if err != nil { 474 stats.FailedRequestInc() 475 r.log.Errorf("Could not process a reply: %s", err) 476 return 477 } 478 479 // defer because we do not do any discovery so recording the response here would mark it as unknown 480 if r.opts.RequestType != inter.ServiceRequestMessageType { 481 stats.RecordReceived(reply.SenderID()) 482 } 483 484 rpcreply, err := ParseReply(reply) 485 switch { 486 case err != nil: 487 stats.FailedRequestInc() 488 r.log.Errorf("Could not process reply from %s: %s", reply.SenderID(), err) 489 return 490 case rpcreply.Statuscode == mcorpc.OK: 491 stats.PassedRequestInc() 492 default: 493 stats.FailedRequestInc() 494 } 495 496 if r.opts.Handler != nil { 497 shouldShow := true 498 if r.opts.ReplyExprFilter != "" { 499 shouldShow, prog, err = rpcreply.MatchExpr(r.opts.ReplyExprFilter, prog) 500 if err != nil { 501 r.log.Errorf("Expr filter parsing failed in reply from %s: %s", reply.SenderID(), err) 502 } 503 } 504 505 if shouldShow { 506 r.opts.Handler(reply, rpcreply) 507 } else { 508 r.opts.Handler(reply, nil) 509 } 510 } 511 512 if stats.All() { 513 cancel() 514 return 515 } 516 } 517 518 return handler 519 } 520 521 func (r *RPC) connectBatchedConnection(ctx context.Context, name string) (Connector, error) { 522 connector, err := r.fw.NewConnector(ctx, r.fw.MiddlewareServers, name, r.log) 523 if err != nil { 524 return nil, err 525 } 526 527 closer := func() { 528 <-ctx.Done() 529 530 r.log.Debugf("Closing batched connection %s", name) 531 connector.Close() 532 connector.Close() 533 } 534 535 go closer() 536 537 return connector, nil 538 }