github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/client/options.go (about) 1 // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package client 6 7 import ( 8 "fmt" 9 "math" 10 "math/rand" 11 "regexp" 12 "strconv" 13 "time" 14 15 "github.com/choria-io/go-choria/inter" 16 "github.com/choria-io/go-choria/protocol" 17 "github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent" 18 ) 19 20 // RequestOptions are options for a RPC request 21 type RequestOptions struct { 22 BatchSize int 23 BatchSleep time.Duration 24 Collective string 25 ConnectionName string 26 DiscoveryTimeout time.Duration 27 Filter *protocol.Filter 28 Handler Handler 29 ProcessReplies bool 30 ProtocolVersion protocol.ProtocolVersion 31 Replies chan inter.ConnectorMessage 32 ReplyTo string 33 RequestID string 34 RequestType string 35 Targets []string 36 Timeout time.Duration 37 Workers int 38 LimitSeed int64 39 LimitMethod string 40 LimitSize string 41 ReplyExprFilter string 42 DiscoveryStartCB DiscoveryStartFunc 43 DiscoveryEndCB DiscoveryEndFunc 44 45 // merged of all batches 46 totalStats *Stats 47 48 fw inter.Framework 49 } 50 51 // DiscoveryStartFunc gets called before discovery starts 52 type DiscoveryStartFunc func() 53 54 // DiscoveryEndFunc gets called after discovery ends and include the discovered node count 55 // and what count of nodes will be targeted after limits were applied should this return 56 // error the RPC call will terminate 57 type DiscoveryEndFunc func(discovered int, limited int) error 58 59 // RequestOption is a function capable of setting an option 60 type RequestOption func(*RequestOptions) 61 62 // NewRequestOptions creates a initialized request options 63 func NewRequestOptions(fw inter.Framework, ddl *agent.DDL) (*RequestOptions, error) { 64 cfg := fw.Configuration() 65 66 return &RequestOptions{ 67 fw: fw, 68 ProtocolVersion: fw.RequestProtocol(), 69 RequestType: inter.DirectRequestMessageType, 70 Collective: cfg.MainCollective, 71 ProcessReplies: true, 72 Workers: 3, 73 ConnectionName: fw.CallerID(), 74 totalStats: NewStats(), 75 LimitMethod: cfg.RPCLimitMethod, 76 LimitSeed: time.Now().UnixNano(), 77 Filter: protocol.NewFilter(), 78 79 // add discovery timeout to the agent timeout as that's basically an indication of 80 // network overhead, discovery being the smallest possible RPC request it's an indication 81 // of what peoples network behavior is like assuming discovery works 82 Timeout: (time.Duration(cfg.DiscoveryTimeout) * time.Second) + ddl.Timeout(), 83 DiscoveryTimeout: time.Duration(cfg.DiscoveryTimeout) * time.Second, 84 }, nil 85 } 86 87 // ConfigureMessage configures a pre-made message object based on the settings contained 88 func (o *RequestOptions) ConfigureMessage(msg inter.Message) (err error) { 89 o.totalStats.RequestID = msg.RequestID() 90 o.RequestID = msg.RequestID() 91 92 switch o.RequestType { 93 case inter.RequestMessageType, inter.DirectRequestMessageType: 94 if o.RequestType == inter.RequestMessageType && o.BatchSize > 0 { 95 return fmt.Errorf("batched mode requires %s mode", inter.DirectRequestMessageType) 96 } 97 98 if o.BatchSize == 0 { 99 o.BatchSize = len(o.Targets) 100 } 101 102 msg.SetFilter(o.Filter) 103 104 if len(o.Targets) > 0 { 105 limited, err := o.limitTargets(o.Targets) 106 if err != nil { 107 return fmt.Errorf("could not limit targets: %s", err) 108 } 109 110 o.Targets = limited 111 msg.SetDiscoveredHosts(limited) 112 } else { 113 limited, err := o.limitTargets(msg.DiscoveredHosts()) 114 if err != nil { 115 return fmt.Errorf("could not limit targets: %s", err) 116 } 117 118 o.Targets = limited 119 } 120 121 o.totalStats.SetDiscoveredNodes(o.Targets) 122 123 case inter.ServiceRequestMessageType: 124 if len(o.Targets) > 0 { 125 return fmt.Errorf("service requests does not support custom targets") 126 } 127 128 if !o.Filter.Empty() { 129 return fmt.Errorf("service requests does not support filters") 130 } 131 132 msg.SetFilter(protocol.NewFilter()) 133 msg.SetDiscoveredHosts([]string{}) 134 } 135 136 err = msg.SetType(o.RequestType) 137 if err != nil { 138 return err 139 } 140 141 msg.SetProtocolVersion(o.ProtocolVersion) 142 143 stdtarget := msg.ReplyTarget() 144 if o.ReplyTo == "" { 145 o.ReplyTo = stdtarget 146 } 147 148 // the reply target is such that we'd probably not receive replies 149 // so disable processing replies 150 if stdtarget != o.ReplyTo { 151 o.ProcessReplies = false 152 o.Workers = 1 153 } 154 155 err = msg.SetReplyTo(o.ReplyTo) 156 if err != nil { 157 return err 158 } 159 160 err = msg.SetCollective(o.Collective) 161 if err != nil { 162 return err 163 } 164 165 // calculate a TTL for messages when we have batches and when using cached transports, 166 // we need to avoid 2FA interactions for the full duration of the message: 167 // 168 // (TTL + DiscoveryTimeout + Timeout) * batches 169 // 170 // We have to allow TTL per batch since the last batch will get it much 171 if msg.IsCachedTransport() && o.BatchSize != len(o.Targets) { 172 batches := int(math.Ceil(float64(len(o.Targets)) / float64(o.BatchSize))) 173 174 msg.SetTTL(batches * (msg.TTL() + int(o.DiscoveryTimeout.Seconds()) + int(o.Timeout.Seconds()))) 175 if msg.TTL() > int((5 * time.Hour).Seconds()) { 176 return fmt.Errorf("cached transport TTL is unreasonably long") 177 } 178 } 179 180 return nil 181 } 182 183 // Stats retrieves the stats for the completed request 184 func (o *RequestOptions) Stats() *Stats { 185 return o.totalStats 186 } 187 188 // ReplyExprFilter filters reply by filter f, replies that match f will 189 // not be recorded and will not be passed to any handlers - they will 190 // count to received replies though as usual. 191 // 192 // When this filter matches a reply and a handler is set the handler will 193 // be called using a nil 'rpcreply' allowing the handler to process progress 194 // bars and more 195 func ReplyExprFilter(f string) RequestOption { 196 return func(o *RequestOptions) { 197 o.ReplyExprFilter = f 198 } 199 } 200 201 // DiscoveryStartCB sets the function to be called before discovery starts 202 func DiscoveryStartCB(h DiscoveryStartFunc) RequestOption { 203 return func(o *RequestOptions) { 204 o.DiscoveryStartCB = h 205 } 206 } 207 208 // DiscoveryEndCB sets the function to be called after discovery and node limiting 209 func DiscoveryEndCB(h DiscoveryEndFunc) RequestOption { 210 return func(o *RequestOptions) { 211 o.DiscoveryEndCB = h 212 } 213 } 214 215 // ConnectionName sets the prefix used for various connection names 216 // 217 // Setting this when making many clients will minimize prometheus 218 // metrics being created - 2 or 3 per client which with random generated 219 // names will snowball over time 220 func ConnectionName(n string) RequestOption { 221 return func(o *RequestOptions) { 222 o.ConnectionName = n 223 } 224 } 225 226 // Targets configures targets for a RPC request 227 func Targets(t []string) RequestOption { 228 return func(o *RequestOptions) { 229 o.Targets = t 230 } 231 } 232 233 // Protocol sets the protocol version to use 234 func Protocol(v protocol.ProtocolVersion) RequestOption { 235 return func(o *RequestOptions) { 236 o.ProtocolVersion = v 237 } 238 } 239 240 // DirectRequest force the request to be a direct request 241 func DirectRequest() RequestOption { 242 return func(o *RequestOptions) { 243 o.RequestType = inter.DirectRequestMessageType 244 } 245 } 246 247 // BroadcastRequest for the request to be a broadcast mode 248 // 249 // **NOTE:** You need to ensure you have filters etc done 250 func BroadcastRequest() RequestOption { 251 return func(o *RequestOptions) { 252 o.RequestType = inter.RequestMessageType 253 } 254 } 255 256 // ServiceRequest for the request to be directed at a specific service agent 257 // 258 // **Note**: does not support filters or targets 259 func ServiceRequest() RequestOption { 260 return func(o *RequestOptions) { 261 o.RequestType = inter.ServiceRequestMessageType 262 } 263 } 264 265 // Workers configures the amount of workers used to process responses 266 // this is ignored during batched mode as that is always done with a 267 // single worker 268 func Workers(w int) RequestOption { 269 return func(o *RequestOptions) { 270 o.Workers = w 271 } 272 } 273 274 // Collective sets the collective to target a message at 275 func Collective(c string) RequestOption { 276 return func(o *RequestOptions) { 277 o.Collective = c 278 } 279 } 280 281 // ReplyTo sets a custom reply to, else the connector will determine it 282 func ReplyTo(r string) RequestOption { 283 return func(o *RequestOptions) { 284 o.ReplyTo = r 285 o.ProcessReplies = false 286 } 287 } 288 289 // InBatches performs requests in batches 290 func InBatches(size int, sleep int) RequestOption { 291 return func(o *RequestOptions) { 292 o.BatchSize = size 293 o.BatchSleep = time.Second * time.Duration(sleep) 294 o.Workers = 1 295 } 296 } 297 298 // Replies creates a custom channel for replies and will avoid processing them 299 func Replies(r chan inter.ConnectorMessage) RequestOption { 300 return func(o *RequestOptions) { 301 o.Replies = r 302 o.ProcessReplies = false 303 } 304 } 305 306 // Timeout configures the request timeout 307 func Timeout(t time.Duration) RequestOption { 308 return func(o *RequestOptions) { 309 o.Timeout = t 310 } 311 } 312 313 // DiscoveryTimeout configures the request discovery timeout, defaults to configured discovery timeout 314 func DiscoveryTimeout(t time.Duration) RequestOption { 315 return func(o *RequestOptions) { 316 o.DiscoveryTimeout = t 317 } 318 } 319 320 // Filter sets the filter, if its set discovery will be done prior to performing requests 321 func Filter(f *protocol.Filter) RequestOption { 322 return func(o *RequestOptions) { 323 o.Filter = f 324 } 325 } 326 327 // ReplyHandler configures a callback to be called for each message received 328 func ReplyHandler(f Handler) RequestOption { 329 return func(o *RequestOptions) { 330 o.Handler = f 331 } 332 } 333 334 // LimitMethod configures the method to use when limiting targets - "random" or "first" 335 func LimitMethod(m string) RequestOption { 336 return func(o *RequestOptions) { 337 o.LimitMethod = m 338 } 339 } 340 341 // LimitSize sets limits on the targets, either a number of a percentage like "10%" 342 func LimitSize(s string) RequestOption { 343 return func(o *RequestOptions) { 344 o.LimitSize = s 345 } 346 } 347 348 // LimitSeed sets the random seed used to select targets when limiting and limit method is "random" 349 func LimitSeed(s int64) RequestOption { 350 return func(o *RequestOptions) { 351 o.LimitSeed = s 352 } 353 } 354 355 func (o *RequestOptions) shuffleLimitedTargets(targets []string) []string { 356 if o.LimitMethod != "random" { 357 return targets 358 } 359 360 var shuffler *rand.Rand 361 362 if o.LimitSeed > -1 { 363 shuffler = rand.New(rand.NewSource(o.LimitSeed)) 364 } else { 365 shuffler = rand.New(rand.NewSource(time.Now().UnixNano())) 366 } 367 368 shuffler.Shuffle(len(targets), func(i, j int) { targets[i], targets[j] = targets[j], targets[i] }) 369 370 return targets 371 } 372 373 func (o *RequestOptions) limitTargets(targets []string) (limited []string, err error) { 374 if !(o.LimitMethod == "random" || o.LimitMethod == "first") { 375 return targets, fmt.Errorf("limit method '%s' is not valid, only 'random' or 'first' supported", o.LimitMethod) 376 } 377 378 if o.LimitSize == "" { 379 limited = make([]string, len(targets)) 380 copy(limited, targets) 381 382 return limited, nil 383 } 384 385 pctRe := regexp.MustCompile(`^(\d+)%$`) 386 digitRe := regexp.MustCompile(`^(\d+)$`) 387 388 count := 0 389 390 if pctRe.MatchString(o.LimitSize) { 391 // already know its a number and it has a matching substring 392 pct, _ := strconv.Atoi(pctRe.FindStringSubmatch(o.LimitSize)[1]) 393 count = int(float64(len(targets)) * (float64(pct) / 100)) 394 } else if digitRe.MatchString(o.LimitSize) { 395 // already know its a number 396 count, _ = strconv.Atoi(o.LimitSize) 397 } else { 398 return limited, fmt.Errorf("could not parse limit as either number or percent") 399 } 400 401 if count <= 0 { 402 return limited, fmt.Errorf("no targets left after applying target limits of '%s'", o.LimitSize) 403 } 404 405 limited = make([]string, count) 406 copy(limited, o.shuffleLimitedTargets(targets)) 407 408 return limited, err 409 }