github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/router/roundrobin_router.go (about)

     1  package router
     2  
     3  import (
     4  	"sync/atomic"
     5  
     6  	"github.com/asynkron/protoactor-go/actor"
     7  )
     8  
     9  type roundRobinGroupRouter struct {
    10  	GroupRouter
    11  }
    12  
    13  type roundRobinPoolRouter struct {
    14  	PoolRouter
    15  }
    16  
    17  type roundRobinState struct {
    18  	index   int32
    19  	routees *actor.PIDSet
    20  	sender  actor.SenderContext
    21  }
    22  
    23  func (state *roundRobinState) SetSender(sender actor.SenderContext) {
    24  	state.sender = sender
    25  }
    26  
    27  func (state *roundRobinState) SetRoutees(routees *actor.PIDSet) {
    28  	state.routees = routees
    29  }
    30  
    31  func (state *roundRobinState) GetRoutees() *actor.PIDSet {
    32  	return state.routees
    33  }
    34  
    35  func (state *roundRobinState) RouteMessage(message interface{}) {
    36  	pid := roundRobinRoutee(&state.index, state.routees)
    37  	state.sender.Send(pid, message)
    38  }
    39  
    40  func NewRoundRobinPool(size int, opts ...actor.PropsOption) *actor.Props {
    41  	return (&actor.Props{}).
    42  		Configure(actor.WithSpawnFunc(spawner(&roundRobinPoolRouter{PoolRouter{PoolSize: size}}))).
    43  		Configure(opts...)
    44  }
    45  
    46  func NewRoundRobinGroup(routees ...*actor.PID) *actor.Props {
    47  	return (&actor.Props{}).Configure(actor.WithSpawnFunc(spawner(&roundRobinGroupRouter{GroupRouter{Routees: actor.NewPIDSet(routees...)}})))
    48  }
    49  
    50  func (config *roundRobinPoolRouter) CreateRouterState() State {
    51  	return &roundRobinState{}
    52  }
    53  
    54  func (config *roundRobinGroupRouter) CreateRouterState() State {
    55  	return &roundRobinState{}
    56  }
    57  
    58  func roundRobinRoutee(index *int32, routees *actor.PIDSet) *actor.PID {
    59  	i := int(atomic.AddInt32(index, 1))
    60  	if i < 0 {
    61  		*index = 0
    62  		i = 0
    63  	}
    64  	mod := routees.Len()
    65  	routee := routees.Get(i % mod)
    66  	return routee
    67  }