github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/client/options.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package client
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/rand"
    11  	"regexp"
    12  	"strconv"
    13  	"time"
    14  
    15  	"github.com/choria-io/go-choria/inter"
    16  	"github.com/choria-io/go-choria/protocol"
    17  	"github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent"
    18  )
    19  
// RequestOptions are options for a RPC request.
//
// Instances are normally created using NewRequestOptions, which fills in the
// defaults mentioned below, and are then adjusted by applying RequestOption
// functions before ConfigureMessage applies them to a message.
type RequestOptions struct {
	BatchSize        int                         // targets per batch, set by InBatches; ConfigureMessage defaults 0 to len(Targets)
	BatchSleep       time.Duration               // set by InBatches from its sleep argument given in seconds
	Collective       string                      // collective set on the message; defaults to the configured main collective
	ConnectionName   string                      // prefix for connection names; defaults to the framework caller id
	DiscoveryTimeout time.Duration               // defaults to the configured discovery timeout
	Filter           *protocol.Filter            // filter set on request and direct request messages
	Handler          Handler                     // callback set by ReplyHandler
	ProcessReplies   bool                        // defaults to true; cleared by ReplyTo, Replies or a non standard reply target
	ProtocolVersion  protocol.ProtocolVersion    // protocol version set on the message
	Replies          chan inter.ConnectorMessage // custom reply channel set by Replies
	ReplyTo          string                      // custom reply target; when empty the message's own reply target is used
	RequestID        string                      // copied from the message by ConfigureMessage
	RequestType      string                      // message type; defaults to inter.DirectRequestMessageType
	Targets          []string                    // explicit targets; replaced by the limited target list in ConfigureMessage
	Timeout          time.Duration               // defaults to the configured discovery timeout plus the DDL timeout
	Workers          int                         // reply workers; defaults to 3, set to 1 by InBatches and for custom reply targets
	LimitSeed        int64                       // seed for the "random" limit method; negative values select a time based seed
	LimitMethod      string                      // "random" or "first"; defaults to the configured RPC limit method
	LimitSize        string                      // a count like "10" or a percentage like "10%"; empty disables limiting
	ReplyExprFilter  string                      // expression filter for replies, set by ReplyExprFilter
	DiscoveryStartCB DiscoveryStartFunc          // called before discovery starts
	DiscoveryEndCB   DiscoveryEndFunc            // called after discovery and node limiting

	// totalStats holds the merged statistics of all batches, it is what Stats() returns
	totalStats *Stats

	// fw is the framework the options were created with
	fw inter.Framework
}
    50  
// DiscoveryStartFunc gets called before discovery starts, it receives no
// arguments and returns nothing.
type DiscoveryStartFunc func()
    53  
// DiscoveryEndFunc gets called after discovery ends and includes the discovered node count
// and the count of nodes that will be targeted after limits were applied. Should this return
// an error the RPC call will terminate.
type DiscoveryEndFunc func(discovered int, limited int) error
    58  
// RequestOption is a function capable of setting an option, it receives the
// RequestOptions to adjust and modifies them in place.
type RequestOption func(*RequestOptions)
    61  
    62  // NewRequestOptions creates a initialized request options
    63  func NewRequestOptions(fw inter.Framework, ddl *agent.DDL) (*RequestOptions, error) {
    64  	cfg := fw.Configuration()
    65  
    66  	return &RequestOptions{
    67  		fw:              fw,
    68  		ProtocolVersion: fw.RequestProtocol(),
    69  		RequestType:     inter.DirectRequestMessageType,
    70  		Collective:      cfg.MainCollective,
    71  		ProcessReplies:  true,
    72  		Workers:         3,
    73  		ConnectionName:  fw.CallerID(),
    74  		totalStats:      NewStats(),
    75  		LimitMethod:     cfg.RPCLimitMethod,
    76  		LimitSeed:       time.Now().UnixNano(),
    77  		Filter:          protocol.NewFilter(),
    78  
    79  		// add discovery timeout to the agent timeout as that's basically an indication of
    80  		// network overhead, discovery being the smallest possible RPC request it's an indication
    81  		// of what peoples network behavior is like assuming discovery works
    82  		Timeout:          (time.Duration(cfg.DiscoveryTimeout) * time.Second) + ddl.Timeout(),
    83  		DiscoveryTimeout: time.Duration(cfg.DiscoveryTimeout) * time.Second,
    84  	}, nil
    85  }
    86  
    87  // ConfigureMessage configures a pre-made message object based on the settings contained
    88  func (o *RequestOptions) ConfigureMessage(msg inter.Message) (err error) {
    89  	o.totalStats.RequestID = msg.RequestID()
    90  	o.RequestID = msg.RequestID()
    91  
    92  	switch o.RequestType {
    93  	case inter.RequestMessageType, inter.DirectRequestMessageType:
    94  		if o.RequestType == inter.RequestMessageType && o.BatchSize > 0 {
    95  			return fmt.Errorf("batched mode requires %s mode", inter.DirectRequestMessageType)
    96  		}
    97  
    98  		if o.BatchSize == 0 {
    99  			o.BatchSize = len(o.Targets)
   100  		}
   101  
   102  		msg.SetFilter(o.Filter)
   103  
   104  		if len(o.Targets) > 0 {
   105  			limited, err := o.limitTargets(o.Targets)
   106  			if err != nil {
   107  				return fmt.Errorf("could not limit targets: %s", err)
   108  			}
   109  
   110  			o.Targets = limited
   111  			msg.SetDiscoveredHosts(limited)
   112  		} else {
   113  			limited, err := o.limitTargets(msg.DiscoveredHosts())
   114  			if err != nil {
   115  				return fmt.Errorf("could not limit targets: %s", err)
   116  			}
   117  
   118  			o.Targets = limited
   119  		}
   120  
   121  		o.totalStats.SetDiscoveredNodes(o.Targets)
   122  
   123  	case inter.ServiceRequestMessageType:
   124  		if len(o.Targets) > 0 {
   125  			return fmt.Errorf("service requests does not support custom targets")
   126  		}
   127  
   128  		if !o.Filter.Empty() {
   129  			return fmt.Errorf("service requests does not support filters")
   130  		}
   131  
   132  		msg.SetFilter(protocol.NewFilter())
   133  		msg.SetDiscoveredHosts([]string{})
   134  	}
   135  
   136  	err = msg.SetType(o.RequestType)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	msg.SetProtocolVersion(o.ProtocolVersion)
   142  
   143  	stdtarget := msg.ReplyTarget()
   144  	if o.ReplyTo == "" {
   145  		o.ReplyTo = stdtarget
   146  	}
   147  
   148  	// the reply target is such that we'd probably not receive replies
   149  	// so disable processing replies
   150  	if stdtarget != o.ReplyTo {
   151  		o.ProcessReplies = false
   152  		o.Workers = 1
   153  	}
   154  
   155  	err = msg.SetReplyTo(o.ReplyTo)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	err = msg.SetCollective(o.Collective)
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	// calculate a TTL for messages when we have batches and when using cached transports,
   166  	// we need to avoid 2FA interactions for the full duration of the message:
   167  	//
   168  	// (TTL + DiscoveryTimeout + Timeout) * batches
   169  	//
   170  	// We have to allow TTL per batch since the last batch will get it much
   171  	if msg.IsCachedTransport() && o.BatchSize != len(o.Targets) {
   172  		batches := int(math.Ceil(float64(len(o.Targets)) / float64(o.BatchSize)))
   173  
   174  		msg.SetTTL(batches * (msg.TTL() + int(o.DiscoveryTimeout.Seconds()) + int(o.Timeout.Seconds())))
   175  		if msg.TTL() > int((5 * time.Hour).Seconds()) {
   176  			return fmt.Errorf("cached transport TTL is unreasonably long")
   177  		}
   178  	}
   179  
   180  	return nil
   181  }
   182  
// Stats retrieves the stats for the completed request, these are the merged
// statistics of all batches. The returned pointer is the one held by the
// options, not a copy.
func (o *RequestOptions) Stats() *Stats {
	return o.totalStats
}
   187  
   188  // ReplyExprFilter filters reply by filter f, replies that match f will
   189  // not be recorded and will not be passed to any handlers - they will
   190  // count to received replies though as usual.
   191  //
   192  // When this filter matches a reply and a handler is set the handler will
   193  // be called using a nil 'rpcreply' allowing the handler to process progress
   194  // bars and more
   195  func ReplyExprFilter(f string) RequestOption {
   196  	return func(o *RequestOptions) {
   197  		o.ReplyExprFilter = f
   198  	}
   199  }
   200  
   201  // DiscoveryStartCB sets the function to be called before discovery starts
   202  func DiscoveryStartCB(h DiscoveryStartFunc) RequestOption {
   203  	return func(o *RequestOptions) {
   204  		o.DiscoveryStartCB = h
   205  	}
   206  }
   207  
   208  // DiscoveryEndCB sets the function to be called after discovery and node limiting
   209  func DiscoveryEndCB(h DiscoveryEndFunc) RequestOption {
   210  	return func(o *RequestOptions) {
   211  		o.DiscoveryEndCB = h
   212  	}
   213  }
   214  
   215  // ConnectionName sets the prefix used for various connection names
   216  //
   217  // Setting this when making many clients will minimize prometheus
   218  // metrics being created - 2 or 3 per client which with random generated
   219  // names will snowball over time
   220  func ConnectionName(n string) RequestOption {
   221  	return func(o *RequestOptions) {
   222  		o.ConnectionName = n
   223  	}
   224  }
   225  
   226  // Targets configures targets for a RPC request
   227  func Targets(t []string) RequestOption {
   228  	return func(o *RequestOptions) {
   229  		o.Targets = t
   230  	}
   231  }
   232  
   233  // Protocol sets the protocol version to use
   234  func Protocol(v protocol.ProtocolVersion) RequestOption {
   235  	return func(o *RequestOptions) {
   236  		o.ProtocolVersion = v
   237  	}
   238  }
   239  
   240  // DirectRequest force the request to be a direct request
   241  func DirectRequest() RequestOption {
   242  	return func(o *RequestOptions) {
   243  		o.RequestType = inter.DirectRequestMessageType
   244  	}
   245  }
   246  
   247  // BroadcastRequest for the request to be a broadcast mode
   248  //
   249  // **NOTE:** You need to ensure you have filters etc done
   250  func BroadcastRequest() RequestOption {
   251  	return func(o *RequestOptions) {
   252  		o.RequestType = inter.RequestMessageType
   253  	}
   254  }
   255  
   256  // ServiceRequest for the request to be directed at a specific service agent
   257  //
   258  // **Note**: does not support filters or targets
   259  func ServiceRequest() RequestOption {
   260  	return func(o *RequestOptions) {
   261  		o.RequestType = inter.ServiceRequestMessageType
   262  	}
   263  }
   264  
   265  // Workers configures the amount of workers used to process responses
   266  // this is ignored during batched mode as that is always done with a
   267  // single worker
   268  func Workers(w int) RequestOption {
   269  	return func(o *RequestOptions) {
   270  		o.Workers = w
   271  	}
   272  }
   273  
   274  // Collective sets the collective to target a message at
   275  func Collective(c string) RequestOption {
   276  	return func(o *RequestOptions) {
   277  		o.Collective = c
   278  	}
   279  }
   280  
   281  // ReplyTo sets a custom reply to, else the connector will determine it
   282  func ReplyTo(r string) RequestOption {
   283  	return func(o *RequestOptions) {
   284  		o.ReplyTo = r
   285  		o.ProcessReplies = false
   286  	}
   287  }
   288  
   289  // InBatches performs requests in batches
   290  func InBatches(size int, sleep int) RequestOption {
   291  	return func(o *RequestOptions) {
   292  		o.BatchSize = size
   293  		o.BatchSleep = time.Second * time.Duration(sleep)
   294  		o.Workers = 1
   295  	}
   296  }
   297  
   298  // Replies creates a custom channel for replies and will avoid processing them
   299  func Replies(r chan inter.ConnectorMessage) RequestOption {
   300  	return func(o *RequestOptions) {
   301  		o.Replies = r
   302  		o.ProcessReplies = false
   303  	}
   304  }
   305  
   306  // Timeout configures the request timeout
   307  func Timeout(t time.Duration) RequestOption {
   308  	return func(o *RequestOptions) {
   309  		o.Timeout = t
   310  	}
   311  }
   312  
   313  // DiscoveryTimeout configures the request discovery timeout, defaults to configured discovery timeout
   314  func DiscoveryTimeout(t time.Duration) RequestOption {
   315  	return func(o *RequestOptions) {
   316  		o.DiscoveryTimeout = t
   317  	}
   318  }
   319  
   320  // Filter sets the filter, if its set discovery will be done prior to performing requests
   321  func Filter(f *protocol.Filter) RequestOption {
   322  	return func(o *RequestOptions) {
   323  		o.Filter = f
   324  	}
   325  }
   326  
   327  // ReplyHandler configures a callback to be called for each message received
   328  func ReplyHandler(f Handler) RequestOption {
   329  	return func(o *RequestOptions) {
   330  		o.Handler = f
   331  	}
   332  }
   333  
   334  // LimitMethod configures the method to use when limiting targets - "random" or "first"
   335  func LimitMethod(m string) RequestOption {
   336  	return func(o *RequestOptions) {
   337  		o.LimitMethod = m
   338  	}
   339  }
   340  
   341  // LimitSize sets limits on the targets, either a number of a percentage like "10%"
   342  func LimitSize(s string) RequestOption {
   343  	return func(o *RequestOptions) {
   344  		o.LimitSize = s
   345  	}
   346  }
   347  
   348  // LimitSeed sets the random seed used to select targets when limiting and limit method is "random"
   349  func LimitSeed(s int64) RequestOption {
   350  	return func(o *RequestOptions) {
   351  		o.LimitSeed = s
   352  	}
   353  }
   354  
   355  func (o *RequestOptions) shuffleLimitedTargets(targets []string) []string {
   356  	if o.LimitMethod != "random" {
   357  		return targets
   358  	}
   359  
   360  	var shuffler *rand.Rand
   361  
   362  	if o.LimitSeed > -1 {
   363  		shuffler = rand.New(rand.NewSource(o.LimitSeed))
   364  	} else {
   365  		shuffler = rand.New(rand.NewSource(time.Now().UnixNano()))
   366  	}
   367  
   368  	shuffler.Shuffle(len(targets), func(i, j int) { targets[i], targets[j] = targets[j], targets[i] })
   369  
   370  	return targets
   371  }
   372  
   373  func (o *RequestOptions) limitTargets(targets []string) (limited []string, err error) {
   374  	if !(o.LimitMethod == "random" || o.LimitMethod == "first") {
   375  		return targets, fmt.Errorf("limit method '%s' is not valid, only 'random' or 'first' supported", o.LimitMethod)
   376  	}
   377  
   378  	if o.LimitSize == "" {
   379  		limited = make([]string, len(targets))
   380  		copy(limited, targets)
   381  
   382  		return limited, nil
   383  	}
   384  
   385  	pctRe := regexp.MustCompile(`^(\d+)%$`)
   386  	digitRe := regexp.MustCompile(`^(\d+)$`)
   387  
   388  	count := 0
   389  
   390  	if pctRe.MatchString(o.LimitSize) {
   391  		// already know its a number and it has a matching substring
   392  		pct, _ := strconv.Atoi(pctRe.FindStringSubmatch(o.LimitSize)[1])
   393  		count = int(float64(len(targets)) * (float64(pct) / 100))
   394  	} else if digitRe.MatchString(o.LimitSize) {
   395  		// already know its a number
   396  		count, _ = strconv.Atoi(o.LimitSize)
   397  	} else {
   398  		return limited, fmt.Errorf("could not parse limit as either number or percent")
   399  	}
   400  
   401  	if count <= 0 {
   402  		return limited, fmt.Errorf("no targets left after applying target limits of '%s'", o.LimitSize)
   403  	}
   404  
   405  	limited = make([]string, count)
   406  	copy(limited, o.shuffleLimitedTargets(targets))
   407  
   408  	return limited, err
   409  }