github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/router/consistent_hash_router_test.go (about)

     1  package router_test
     2  
     3  import (
     4  	"log/slog"
     5  	"strconv"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/asynkron/protoactor-go/actor"
    12  	"github.com/asynkron/protoactor-go/router"
    13  )
    14  
    15  var system = actor.NewActorSystem()
    16  
    17  type myMessage struct {
    18  	i   int32
    19  	pid *actor.PID
    20  }
    21  
    22  type getRoutees struct {
    23  	pid *actor.PID
    24  }
    25  
    26  func (m *myMessage) Hash() string {
    27  	i := atomic.LoadInt32(&m.i)
    28  	return strconv.Itoa(int(i))
    29  }
    30  
    31  var wait sync.WaitGroup
    32  
    33  type (
    34  	routerActor  struct{}
    35  	tellerActor  struct{}
    36  	managerActor struct {
    37  		set  []*actor.PID
    38  		rpid *actor.PID
    39  	}
    40  )
    41  
    42  func (state *routerActor) Receive(context actor.Context) {
    43  	switch msg := context.Message().(type) {
    44  	case *myMessage:
    45  		//context.Logger().Info("%v got message", slog.Any("self", context.Self()), slog.Int("msg", int(msg.i)))
    46  		atomic.AddInt32(&msg.i, 1)
    47  		wait.Done()
    48  	}
    49  }
    50  
    51  func (state *tellerActor) Receive(context actor.Context) {
    52  	switch msg := context.Message().(type) {
    53  	case *myMessage:
    54  		start := msg.i
    55  		for i := 0; i < 100; i++ {
    56  			context.Send(msg.pid, msg)
    57  			time.Sleep(10 * time.Millisecond)
    58  		}
    59  		if msg.i != start+100 {
    60  			context.Logger().Error("Expected to send 100 messages", slog.Int("start", int(start)), slog.Int("end", int(msg.i)))
    61  		}
    62  	}
    63  }
    64  
    65  func (state *managerActor) Receive(context actor.Context) {
    66  	switch msg := context.Message().(type) {
    67  	case *router.Routees:
    68  		state.set = msg.PIDs
    69  		for i, v := range state.set {
    70  			if i%2 == 0 {
    71  				context.Send(state.rpid, &router.RemoveRoutee{PID: v})
    72  				// log.Println(v)
    73  			} else {
    74  				props := actor.PropsFromProducer(func() actor.Actor { return &routerActor{} })
    75  				pid := context.Spawn(props)
    76  				context.Send(state.rpid, &router.AddRoutee{PID: pid})
    77  				// log.Println(v)
    78  			}
    79  		}
    80  		context.Send(context.Self(), &getRoutees{state.rpid})
    81  	case *getRoutees:
    82  		state.rpid = msg.pid
    83  		context.Request(msg.pid, &router.GetRoutees{})
    84  	}
    85  }
    86  
    87  func TestConcurrency(t *testing.T) {
    88  	if testing.Short() {
    89  		t.SkipNow()
    90  	}
    91  
    92  	wait.Add(100 * 1000)
    93  	rpid := system.Root.Spawn(router.NewConsistentHashPool(100).Configure(actor.WithProducer(func() actor.Actor { return &routerActor{} })))
    94  
    95  	props := actor.PropsFromProducer(func() actor.Actor { return &tellerActor{} })
    96  	for i := 0; i < 1000; i++ {
    97  		pid := system.Root.Spawn(props)
    98  		system.Root.Send(pid, &myMessage{int32(i), rpid})
    99  	}
   100  
   101  	props = actor.PropsFromProducer(func() actor.Actor { return &managerActor{} })
   102  	pid := system.Root.Spawn(props)
   103  	system.Root.Send(pid, &getRoutees{rpid})
   104  
   105  	// Implementing the timeout
   106  	timeout := time.After(5 * time.Second)
   107  	done := make(chan bool)
   108  	go func() {
   109  		wait.Wait()
   110  		done <- true
   111  	}()
   112  
   113  	select {
   114  	case <-timeout:
   115  		t.Fatal("Test timed out")
   116  	case <-done:
   117  		// Test completed within timeout
   118  	}
   119  
   120  }