github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/remote/endpoint_watcher.go (about)

     1  package remote
     2  
     3  import (
     4  	"github.com/asynkron/protoactor-go/actor"
     5  	"log/slog"
     6  )
     7  
     8  func newEndpointWatcher(remote *Remote, address string) actor.Producer {
     9  	return func() actor.Actor {
    10  		watcher := &endpointWatcher{
    11  			behavior: actor.NewBehavior(),
    12  			address:  address,
    13  			remote:   remote,
    14  		}
    15  		watcher.behavior.Become(watcher.connected)
    16  		return watcher
    17  	}
    18  }
    19  
    20  type endpointWatcher struct {
    21  	behavior actor.Behavior
    22  	address  string
    23  	watched  map[string]*actor.PIDSet // key is the watching PID string, value is the watched PID
    24  	remote   *Remote
    25  }
    26  
    27  func (state *endpointWatcher) initialize() {
    28  	state.remote.Logger().Info("Started EndpointWatcher", slog.String("address", state.address))
    29  	state.watched = make(map[string]*actor.PIDSet)
    30  }
    31  
    32  func (state *endpointWatcher) Receive(ctx actor.Context) {
    33  	state.behavior.Receive(ctx)
    34  }
    35  
    36  func (state *endpointWatcher) connected(ctx actor.Context) {
    37  	switch msg := ctx.Message().(type) {
    38  	case *actor.Started:
    39  		state.initialize()
    40  
    41  	case *remoteTerminate:
    42  		// delete the watch entries
    43  		if pidSet, ok := state.watched[msg.Watcher.Id]; ok {
    44  			pidSet.Remove(msg.Watchee)
    45  			if pidSet.Len() == 0 {
    46  				delete(state.watched, msg.Watcher.Id)
    47  			}
    48  		}
    49  
    50  		terminated := &actor.Terminated{
    51  			Who: msg.Watchee,
    52  			Why: actor.TerminatedReason_Stopped,
    53  		}
    54  		ref, ok := state.remote.actorSystem.ProcessRegistry.GetLocal(msg.Watcher.Id)
    55  		if ok {
    56  			ref.SendSystemMessage(msg.Watcher, terminated)
    57  		}
    58  	case *EndpointConnectedEvent:
    59  		// Already connected, pass
    60  	case *EndpointTerminatedEvent:
    61  		state.remote.Logger().Info("EndpointWatcher handling terminated",
    62  			slog.String("address", state.address), slog.Int("watched", len(state.watched)))
    63  
    64  		for id, pidSet := range state.watched {
    65  			// try to find the watcher ExtensionID in the local actor registry
    66  			ref, ok := state.remote.actorSystem.ProcessRegistry.GetLocal(id)
    67  			if ok {
    68  				pidSet.ForEach(func(i int, pid *actor.PID) {
    69  					// create a terminated event for the Watched actor
    70  					terminated := &actor.Terminated{
    71  						Who: pid,
    72  						Why: actor.TerminatedReason_AddressTerminated,
    73  					}
    74  
    75  					watcher := state.remote.actorSystem.NewLocalPID(id)
    76  					// send the address Terminated event to the Watcher
    77  					ref.SendSystemMessage(watcher, terminated)
    78  				})
    79  			}
    80  		}
    81  
    82  		// Clear watcher's map
    83  		state.watched = make(map[string]*actor.PIDSet)
    84  		state.behavior.Become(state.terminated)
    85  		ctx.Stop(ctx.Self())
    86  
    87  	case *remoteWatch:
    88  		// add watchee to watcher's map
    89  		if pidSet, ok := state.watched[msg.Watcher.Id]; ok {
    90  			pidSet.Add(msg.Watchee)
    91  		} else {
    92  			state.watched[msg.Watcher.Id] = actor.NewPIDSet(msg.Watchee)
    93  		}
    94  
    95  		// recreate the Watch command
    96  		w := &actor.Watch{
    97  			Watcher: msg.Watcher,
    98  		}
    99  
   100  		// pass it off to the remote PID
   101  		state.remote.SendMessage(msg.Watchee, nil, w, nil, -1)
   102  
   103  	case *remoteUnwatch:
   104  		// delete the watch entries
   105  		if pidSet, ok := state.watched[msg.Watcher.Id]; ok {
   106  			pidSet.Remove(msg.Watchee)
   107  			if pidSet.Len() == 0 {
   108  				delete(state.watched, msg.Watcher.Id)
   109  			}
   110  		}
   111  
   112  		// recreate the Unwatch command
   113  		uw := &actor.Unwatch{
   114  			Watcher: msg.Watcher,
   115  		}
   116  
   117  		// pass it off to the remote PID
   118  		state.remote.SendMessage(msg.Watchee, nil, uw, nil, -1)
   119  	case actor.SystemMessage, actor.AutoReceiveMessage:
   120  		// ignore
   121  	default:
   122  		state.remote.Logger().Error("EndpointWatcher received unknown message", slog.String("address", state.address), slog.Any("message", msg))
   123  	}
   124  }
   125  
   126  func (state *endpointWatcher) terminated(ctx actor.Context) {
   127  	switch msg := ctx.Message().(type) {
   128  	case *remoteWatch:
   129  		// try to find the watcher ExtensionID in the local actor registry
   130  		ref, ok := state.remote.actorSystem.ProcessRegistry.GetLocal(msg.Watcher.Id)
   131  
   132  		if ok {
   133  			// create a terminated event for the Watched actor
   134  			terminated := &actor.Terminated{
   135  				Who: msg.Watchee,
   136  				Why: actor.TerminatedReason_AddressTerminated,
   137  			}
   138  			// send the address Terminated event to the Watcher
   139  			ref.SendSystemMessage(msg.Watcher, terminated)
   140  		}
   141  	case *EndpointConnectedEvent:
   142  		state.remote.Logger().Info("EndpointWatcher handling restart", slog.String("address", state.address))
   143  		state.behavior.Become(state.connected)
   144  	case *remoteTerminate, *EndpointTerminatedEvent, *remoteUnwatch:
   145  		// pass
   146  		state.remote.Logger().Error("EndpointWatcher receive message for already terminated endpoint", slog.String("address", state.address), slog.Any("message", msg))
   147  	case actor.SystemMessage, actor.AutoReceiveMessage:
   148  		// ignore
   149  	default:
   150  		state.remote.Logger().Error("EndpointWatcher received unknown message", slog.String("address", state.address), slog.Any("message", msg))
   151  	}
   152  }