github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/client/client.go (about)

     1  // Copyright (c) 2018-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package client
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"sync"
    12  
    13  	"github.com/choria-io/go-choria/inter"
    14  	"github.com/expr-lang/expr/vm"
    15  
    16  	"github.com/choria-io/go-choria/config"
    17  	"github.com/choria-io/go-choria/providers/discovery/broadcast"
    18  	"github.com/choria-io/go-choria/providers/discovery/puppetdb"
    19  
    20  	cclient "github.com/choria-io/go-choria/client/client"
    21  	"github.com/choria-io/go-choria/protocol"
    22  	"github.com/choria-io/go-choria/providers/agent/mcorpc"
    23  	addl "github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent"
    24  
    25  	"github.com/sirupsen/logrus"
    26  )
    27  
    28  // RPC is a MCollective compatible RPC client
    29  type RPC struct {
    30  	fw   inter.Framework
    31  	opts *RequestOptions
    32  	log  *logrus.Entry
    33  	cfg  *config.Config
    34  	dm   string
    35  
    36  	agent string
    37  
    38  	mu *sync.Mutex
    39  
    40  	ddl *addl.DDL
    41  
    42  	// used for testing only
    43  	cl ChoriaClient
    44  }
    45  
    46  // RPCRequest is a basic RPC request
    47  type RPCRequest struct {
    48  	Agent  string          `json:"agent"`
    49  	Action string          `json:"action"`
    50  	Data   json.RawMessage `json:"data"`
    51  }
    52  
    53  // RequestResult is the result of a request
    54  type RequestResult interface {
    55  	Stats() *Stats
    56  }
    57  
    58  // Handler is a function that should handle each reply synchronously
    59  type Handler func(protocol.Reply, *RPCReply)
    60  
    61  // ChoriaClient implements the connection to the Choria network
    62  type ChoriaClient interface {
    63  	Request(ctx context.Context, msg inter.Message, handler cclient.Handler) (err error)
    64  }
    65  
    66  // Connector is a connection to the choria network
    67  type Connector interface {
    68  	QueueSubscribe(ctx context.Context, name string, subject string, group string, output chan inter.ConnectorMessage) error
    69  	Publish(msg inter.Message) error
    70  }
    71  
    72  // Option configures the RPC client
    73  type Option func(r *RPC)
    74  
    75  // DDL supplies a DDL when creating the client thus avoiding a disk search
    76  func DDL(d *addl.DDL) Option {
    77  	return func(r *RPC) {
    78  		r.ddl = d
    79  	}
    80  }
    81  
    82  // DiscoveryMethod sets a specific discovery method
    83  func DiscoveryMethod(dm string) Option {
    84  	return func(r *RPC) {
    85  		r.dm = dm
    86  	}
    87  }
    88  
    89  // New creates a new RPC request
    90  //
    91  // A DDL is required when one is not given using the DDL() option as argument
    92  // attempts will be made to find it on the file system should this fail an error will be returned
    93  func New(fw inter.Framework, agent string, opts ...Option) (rpc *RPC, err error) {
    94  	rpc = &RPC{
    95  		fw:    fw,
    96  		cfg:   fw.Configuration(),
    97  		mu:    &sync.Mutex{},
    98  		log:   fw.Logger("rpc"),
    99  		agent: agent,
   100  		dm:    fw.Configuration().DefaultDiscoveryMethod,
   101  	}
   102  
   103  	for _, opt := range opts {
   104  		opt(rpc)
   105  	}
   106  
   107  	return rpc, nil
   108  }
   109  
   110  func (r *RPC) setOptions(opts ...RequestOption) (err error) {
   111  	r.opts, err = NewRequestOptions(r.fw, r.ddl)
   112  	if err != nil {
   113  		return err
   114  	}
   115  
   116  	for _, opt := range opts {
   117  		opt(r.opts)
   118  	}
   119  
   120  	if r.ddl.Metadata.Service {
   121  		r.opts.Workers = 1
   122  		r.opts.RequestType = inter.ServiceRequestMessageType
   123  	}
   124  
   125  	return nil
   126  }
   127  
   128  func (r *RPC) ResolveDDL(ctx context.Context) error {
   129  	r.mu.Lock()
   130  	defer r.mu.Unlock()
   131  
   132  	if r.ddl != nil {
   133  		return nil
   134  	}
   135  
   136  	ddl, err := r.resolveDDL(ctx)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	r.ddl = ddl
   142  
   143  	return nil
   144  }
   145  
   146  func (r *RPC) resolveDDL(ctx context.Context) (*addl.DDL, error) {
   147  	if r.ddl != nil {
   148  		return r.ddl, nil
   149  	}
   150  
   151  	resolvers, err := r.fw.DDLResolvers()
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	for _, resolver := range resolvers {
   157  		r.log.Debugf("Attempting to resolve agent DDL using %s", resolver)
   158  
   159  		data, err := resolver.DDLBytes(ctx, "agent", r.agent, r.fw)
   160  		if err == nil {
   161  			ddl, err := addl.NewFromBytes(data)
   162  			if err != nil {
   163  				return nil, err
   164  			}
   165  			r.ddl = ddl
   166  
   167  			return ddl, nil
   168  		} else {
   169  			r.log.Debugf("DDL Resolution failed: %s", err)
   170  		}
   171  	}
   172  
   173  	return nil, fmt.Errorf("could not resolve %s DDL in any resolver: %s", r.agent, err)
   174  }
   175  
   176  // DDL returns the active DDL for this client
   177  func (r *RPC) DDL() *addl.DDL {
   178  	r.mu.Lock()
   179  	defer r.mu.Unlock()
   180  
   181  	return r.ddl
   182  }
   183  
   184  // Do perform a RPC request and optionally processes replies
   185  //
   186  // If a filter is supplied using the Filter() option and Targets() are not then discovery will be done for you
   187  // using the broadcast method, should no nodes be discovered an error will be returned
   188  func (r *RPC) Do(ctx context.Context, action string, payload any, opts ...RequestOption) (RequestResult, error) {
   189  	r.mu.Lock()
   190  	defer r.mu.Unlock()
   191  
   192  	_, err := r.resolveDDL(ctx)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  
   197  	if r.ddl.Metadata.Name != r.agent {
   198  		return nil, fmt.Errorf("the DDL does not describe the %s agent", r.agent)
   199  	}
   200  
   201  	// we want to force the passing of options on every request
   202  	err = r.setOptions(opts...)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	dctx, cancel := context.WithCancel(ctx)
   208  	defer cancel()
   209  
   210  	if len(r.opts.Targets) == 0 && (r.opts.RequestType != inter.ServiceRequestMessageType || !r.ddl.Metadata.Service) {
   211  		err = r.discover(ctx)
   212  		if err != nil {
   213  			return nil, fmt.Errorf("discovery failed: %s", err)
   214  		}
   215  	}
   216  
   217  	discoveredCnt := len(r.opts.Targets)
   218  	msg, cl, err := r.setupMessage(dctx, action, payload, opts...)
   219  	if err != nil {
   220  		return nil, fmt.Errorf("could not configure message: %s", err)
   221  	}
   222  
   223  	if r.opts.DiscoveryEndCB != nil {
   224  		err = r.opts.DiscoveryEndCB(discoveredCnt, len(r.opts.Targets))
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  	}
   229  
   230  	r.opts.totalStats.Start()
   231  	defer r.opts.totalStats.End()
   232  
   233  	r.opts.totalStats.SetAction(action)
   234  	r.opts.totalStats.SetAgent(r.agent)
   235  
   236  	switch r.opts.RequestType {
   237  	case inter.ServiceRequestMessageType:
   238  		err = r.doServiceRequest(dctx, msg, cl)
   239  	default:
   240  		err = r.doBatchedRequest(ctx, msg, cl)
   241  	}
   242  
   243  	return &RequestOptions{totalStats: r.opts.totalStats}, err
   244  }
   245  
   246  func (r *RPC) doBatchedRequest(ctx context.Context, msg inter.Message, cl ChoriaClient) error {
   247  	// the client is always batched, when batched mode is not request the size of
   248  	// the batch matches the size of the total targets and during setupMessage()
   249  	// an appropriate connection will be made
   250  
   251  	ctr := 0
   252  
   253  	return InGroups(r.opts.Targets, r.opts.BatchSize, func(nodes []string) error {
   254  		stats := NewStats()
   255  		stats.SetDiscoveredNodes(nodes)
   256  		msg.SetDiscoveredHosts(nodes)
   257  
   258  		stats.Start()
   259  		defer func(s *Stats) {
   260  			s.End()
   261  			r.opts.totalStats.Merge(s)
   262  		}(stats)
   263  
   264  		if ctr > 0 {
   265  			err := InterruptableSleep(ctx, r.opts.BatchSleep)
   266  			if err != nil {
   267  				return err
   268  			}
   269  		}
   270  
   271  		r.log.Debugf("Performing batched request %d for %d/%d nodes", ctr, len(nodes), len(r.opts.Targets))
   272  
   273  		err := r.request(ctx, msg, cl, stats)
   274  		if err != nil {
   275  			return err
   276  		}
   277  
   278  		ctr++
   279  
   280  		return nil
   281  	})
   282  }
   283  
   284  func (r *RPC) doServiceRequest(ctx context.Context, msg inter.Message, cl ChoriaClient) error {
   285  	stats := NewStats()
   286  
   287  	var responded []string
   288  	handler := r.opts.Handler
   289  	r.opts.Handler = func(r protocol.Reply, rpc *RPCReply) {
   290  		responded = append(responded, r.SenderID())
   291  		if handler != nil {
   292  			handler(r, rpc)
   293  		}
   294  	}
   295  
   296  	err := r.request(ctx, msg, cl, stats)
   297  	if len(responded) > 0 {
   298  		stats.SetDiscoveredNodes(responded)
   299  		r.opts.totalStats.SetDiscoveredNodes(responded)
   300  		stats.RecordReceived(responded[0])
   301  	}
   302  
   303  	stats.End()
   304  
   305  	r.opts.totalStats.Merge(stats)
   306  
   307  	return err
   308  }
   309  
   310  func (r *RPC) discover(ctx context.Context) error {
   311  	if len(r.opts.Filter.Compound) > 0 {
   312  		r.dm = "choria"
   313  	}
   314  	if r.opts.DiscoveryStartCB != nil {
   315  		r.opts.DiscoveryStartCB()
   316  	}
   317  
   318  	r.opts.totalStats.StartDiscover()
   319  	defer r.opts.totalStats.EndDiscover()
   320  
   321  	if r.opts.Filter == nil {
   322  		r.opts.Filter = protocol.NewFilter()
   323  	}
   324  
   325  	r.opts.Filter.AddAgentFilter(r.agent)
   326  
   327  	var n []string
   328  	var err error
   329  
   330  	// TODO: other discovery options? honestly the magical discovery here should just go
   331  	switch r.dm {
   332  	case "choria":
   333  		pdb := puppetdb.New(r.fw)
   334  		n, err = pdb.Discover(ctx, puppetdb.Filter(r.opts.Filter), puppetdb.Timeout(r.opts.DiscoveryTimeout), puppetdb.Collective(r.opts.Collective))
   335  
   336  	default:
   337  		b := broadcast.New(r.fw)
   338  		n, err = b.Discover(ctx, broadcast.Filter(r.opts.Filter), broadcast.Timeout(r.opts.DiscoveryTimeout), broadcast.Name(r.opts.ConnectionName), broadcast.Collective(r.opts.Collective))
   339  	}
   340  	if err != nil {
   341  		return err
   342  	}
   343  
   344  	if len(n) == 0 {
   345  		return fmt.Errorf("no targets were discovered")
   346  	}
   347  
   348  	r.opts.Targets = n
   349  
   350  	return nil
   351  }
   352  
   353  func (r *RPC) setupMessage(ctx context.Context, action string, payload any, opts ...RequestOption) (msg inter.Message, cl ChoriaClient, err error) {
   354  	pj, err := json.Marshal(payload)
   355  	if err != nil {
   356  		return nil, nil, fmt.Errorf("could not encode payload: %s", err)
   357  	}
   358  
   359  	rpcreq := &RPCRequest{
   360  		Agent:  r.agent,
   361  		Action: action,
   362  		Data:   pj,
   363  	}
   364  
   365  	rpcp, err := json.Marshal(rpcreq)
   366  	if err != nil {
   367  		return nil, nil, fmt.Errorf("could not encode request: %s", err)
   368  	}
   369  
   370  	msgType := inter.RequestMessageType
   371  	if r.ddl.Metadata.Service {
   372  		msgType = inter.ServiceRequestMessageType
   373  		r.opts.Workers = 1
   374  	}
   375  
   376  	msg, err = r.fw.NewMessage(rpcp, r.agent, r.cfg.MainCollective, msgType, nil)
   377  	if err != nil {
   378  		return nil, nil, err
   379  	}
   380  
   381  	err = r.opts.ConfigureMessage(msg)
   382  	if err != nil {
   383  		return nil, nil, fmt.Errorf("could not configure Message: %s", err)
   384  	}
   385  
   386  	cl = r.cl
   387  
   388  	if r.cl == nil {
   389  		if r.opts.BatchSize == len(r.opts.Targets) || !r.opts.ProcessReplies {
   390  			cl, err = r.unbatchedClient()
   391  			if err != nil {
   392  				return nil, nil, err
   393  			}
   394  		} else {
   395  			cl, err = r.batchedClient(ctx, msg.RequestID())
   396  			if err != nil {
   397  				return nil, nil, err
   398  			}
   399  		}
   400  	}
   401  
   402  	return msg, cl, err
   403  }
   404  
   405  func (r *RPC) unbatchedClient() (cl ChoriaClient, err error) {
   406  	cl, err = cclient.New(
   407  		r.fw,
   408  		cclient.Receivers(r.opts.Workers),
   409  		cclient.Timeout(r.opts.Timeout),
   410  		cclient.OnPublishStart(r.opts.totalStats.StartPublish),
   411  		cclient.OnPublishFinish(r.opts.totalStats.EndPublish),
   412  		cclient.Name(r.opts.ConnectionName),
   413  	)
   414  	if err != nil {
   415  		return nil, fmt.Errorf("could not setup client: %s", err)
   416  	}
   417  
   418  	return cl, nil
   419  }
   420  
   421  func (r *RPC) batchedClient(ctx context.Context, msgid string) (cl ChoriaClient, err error) {
   422  	conn, err := r.connectBatchedConnection(ctx, fmt.Sprintf("%s_%s_batched", r.opts.ConnectionName, msgid))
   423  	if err != nil {
   424  		return nil, fmt.Errorf("could not connect batched network connection: %s", err)
   425  	}
   426  
   427  	cl, err = cclient.New(
   428  		r.fw,
   429  		cclient.Receivers(r.opts.Workers),
   430  		cclient.Timeout(r.opts.Timeout),
   431  		cclient.OnPublishStart(r.opts.totalStats.StartPublish),
   432  		cclient.OnPublishFinish(r.opts.totalStats.EndPublish),
   433  		cclient.Connection(conn),
   434  		cclient.Name(r.opts.ConnectionName),
   435  	)
   436  	if err != nil {
   437  		return nil, fmt.Errorf("could not set up batched client: %s", err)
   438  	}
   439  
   440  	return cl, nil
   441  }
   442  
   443  // Reset removes the cached options, any further Do() calls need to specify full options
   444  func (r *RPC) Reset() {
   445  	r.mu.Lock()
   446  	defer r.mu.Unlock()
   447  
   448  	r.opts = nil
   449  	r.cl = nil
   450  }
   451  
   452  func (r *RPC) request(ctx context.Context, msg inter.Message, cl ChoriaClient, stats *Stats) error {
   453  	rctx, cancel := context.WithCancel(ctx)
   454  	defer cancel()
   455  
   456  	err := cl.Request(rctx, msg, r.handlerFactory(rctx, cancel, stats))
   457  	if err != nil {
   458  		return err
   459  	}
   460  
   461  	return nil
   462  }
   463  
// handlerFactory builds the cclient.Handler used for a single request.
//
// It returns nil when r.opts.ProcessReplies is false. Otherwise, the handler:
//   - decodes each received message into a protocol.Reply and then an
//     *RPCReply;
//   - counts the reply in stats as passed (Statuscode mcorpc.OK) or failed
//     (a decode error or any other status code);
//   - passes it to r.opts.Handler when one is set.
//
// When r.opts.ReplyExprFilter is set, replies that do not match it reach the
// Handler with a nil *RPCReply. Replies that fail to decode never reach the
// Handler. As soon as stats.All() reports true, cancel is called to end the
// request early. The context parameter is unused.
//
// NOTE(review): prog is written from inside the handler without locking. If
// the client's receivers (cclient.Receivers(r.opts.Workers)) invoke this
// handler concurrently, that is a data race. Confirm against cclient and
// guard prog if so.
func (r *RPC) handlerFactory(_ context.Context, cancel context.CancelFunc, stats *Stats) cclient.Handler {
	if !r.opts.ProcessReplies {
		return nil
	}

	// Carries the program returned by MatchExpr between replies; it appears
	// to cache the compiled ReplyExprFilter so it is not rebuilt each time.
	var prog *vm.Program

	handler := func(ctx context.Context, rawmsg inter.ConnectorMessage) {
		reply, err := r.fw.NewReplyFromTransportJSON(rawmsg.Data(), false)
		if err != nil {
			stats.FailedRequestInc()
			r.log.Errorf("Could not process a reply: %s", err)
			return
		}

		// Service requests do no discovery, so recording the sender here would
		// mark it as unknown; doServiceRequest records the responders after the
		// request completes instead.
		if r.opts.RequestType != inter.ServiceRequestMessageType {
			stats.RecordReceived(reply.SenderID())
		}

		rpcreply, err := ParseReply(reply)
		switch {
		case err != nil:
			stats.FailedRequestInc()
			r.log.Errorf("Could not process reply from %s: %s", reply.SenderID(), err)
			return
		case rpcreply.Statuscode == mcorpc.OK:
			stats.PassedRequestInc()
		default:
			stats.FailedRequestInc()
		}

		if r.opts.Handler != nil {
			shouldShow := true
			if r.opts.ReplyExprFilter != "" {
				// On an evaluation error the failure is logged and shouldShow keeps
				// whatever MatchExpr returned alongside the error.
				shouldShow, prog, err = rpcreply.MatchExpr(r.opts.ReplyExprFilter, prog)
				if err != nil {
					r.log.Errorf("Expr filter parsing failed in reply from %s: %s", reply.SenderID(), err)
				}
			}

			// filtered-out replies are still delivered, but without their body
			if shouldShow {
				r.opts.Handler(reply, rpcreply)
			} else {
				r.opts.Handler(reply, nil)
			}
		}

		// stop waiting for further replies once stats reports completion
		if stats.All() {
			cancel()
			return
		}
	}

	return handler
}
   520  
   521  func (r *RPC) connectBatchedConnection(ctx context.Context, name string) (Connector, error) {
   522  	connector, err := r.fw.NewConnector(ctx, r.fw.MiddlewareServers, name, r.log)
   523  	if err != nil {
   524  		return nil, err
   525  	}
   526  
   527  	closer := func() {
   528  		<-ctx.Done()
   529  
   530  		r.log.Debugf("Closing batched connection %s", name)
   531  		connector.Close()
   532  		connector.Close()
   533  	}
   534  
   535  	go closer()
   536  
   537  	return connector, nil
   538  }