github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/remote/endpoint_reader.go (about)

     1  package remote
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"log/slog"
     7  
     8  	"google.golang.org/protobuf/proto"
     9  
    10  	"github.com/asynkron/protoactor-go/actor"
    11  	"golang.org/x/net/context"
    12  )
    13  
    14  type endpointReader struct {
    15  	suspended bool
    16  	remote    *Remote
    17  }
    18  
    19  func (s *endpointReader) mustEmbedUnimplementedRemotingServer() {
    20  	// TODO implement me
    21  	panic("implement me")
    22  }
    23  
    24  func (s *endpointReader) ListProcesses(ctx context.Context, request *ListProcessesRequest) (*ListProcessesResponse, error) {
    25  	panic("implement me")
    26  }
    27  
    28  func (s *endpointReader) GetProcessDiagnostics(ctx context.Context, request *GetProcessDiagnosticsRequest) (*GetProcessDiagnosticsResponse, error) {
    29  	panic("implement me")
    30  }
    31  
    32  func newEndpointReader(r *Remote) *endpointReader {
    33  	return &endpointReader{
    34  		remote: r,
    35  	}
    36  }
    37  
    38  func (s *endpointReader) Receive(stream Remoting_ReceiveServer) error {
    39  	disconnectChan := make(chan bool, 1)
    40  	s.remote.edpManager.endpointReaderConnections.Store(stream, disconnectChan)
    41  	defer func() {
    42  		close(disconnectChan)
    43  	}()
    44  
    45  	go func() {
    46  		// endpointManager sends true
    47  		// endpointReader sends false
    48  		if <-disconnectChan {
    49  			s.remote.Logger().Debug("EndpointReader is telling to remote that it's leaving")
    50  			err := stream.Send(&RemoteMessage{
    51  				MessageType: &RemoteMessage_DisconnectRequest{
    52  					DisconnectRequest: &DisconnectRequest{},
    53  				},
    54  			})
    55  			if err != nil {
    56  				s.remote.Logger().Error("EndpointReader failed to send disconnection message", slog.Any("error", err))
    57  			}
    58  		} else {
    59  			s.remote.edpManager.endpointReaderConnections.Delete(stream)
    60  			s.remote.Logger().Debug("EndpointReader removed active endpoint from endpointManager")
    61  		}
    62  	}()
    63  
    64  	for {
    65  		msg, err := stream.Recv()
    66  		switch {
    67  		case errors.Is(err, io.EOF):
    68  			s.remote.Logger().Info("EndpointReader stream closed")
    69  			disconnectChan <- false
    70  			return nil
    71  		case err != nil:
    72  			s.remote.Logger().Info("EndpointReader failed to read", slog.Any("error", err))
    73  			return err
    74  		case s.suspended:
    75  			continue
    76  		}
    77  
    78  		switch t := msg.MessageType.(type) {
    79  		case *RemoteMessage_ConnectRequest:
    80  			s.remote.Logger().Debug("EndpointReader received connect request", slog.Any("message", t.ConnectRequest))
    81  			c := t.ConnectRequest
    82  			_, err := s.OnConnectRequest(stream, c)
    83  			if err != nil {
    84  				s.remote.Logger().Error("EndpointReader failed to handle connect request", slog.Any("error", err))
    85  				return err
    86  			}
    87  		case *RemoteMessage_MessageBatch:
    88  			m := t.MessageBatch
    89  			err := s.onMessageBatch(m)
    90  			if err != nil {
    91  				return err
    92  			}
    93  		default:
    94  			{
    95  				s.remote.Logger().Warn("EndpointReader received unknown message type")
    96  			}
    97  		}
    98  	}
    99  }
   100  
   101  func (s *endpointReader) OnConnectRequest(stream Remoting_ReceiveServer, c *ConnectRequest) (bool, error) {
   102  	switch tt := c.ConnectionType.(type) {
   103  	case *ConnectRequest_ServerConnection:
   104  		{
   105  			sc := tt.ServerConnection
   106  			s.onServerConnection(stream, sc)
   107  		}
   108  	case *ConnectRequest_ClientConnection:
   109  		{
   110  			// TODO implement me
   111  			s.remote.Logger().Error("ClientConnection not implemented")
   112  		}
   113  	default:
   114  		s.remote.Logger().Error("EndpointReader received unknown connection type")
   115  		return true, nil
   116  	}
   117  	return false, nil
   118  }
   119  
   120  func (s *endpointReader) onMessageBatch(m *MessageBatch) error {
   121  	var (
   122  		sender *actor.PID
   123  		target *actor.PID
   124  	)
   125  
   126  	for _, envelope := range m.Envelopes {
   127  		data := envelope.MessageData
   128  
   129  		sender = deserializeSender(sender, envelope.Sender, envelope.SenderRequestId, m.Senders)
   130  		target = deserializeTarget(target, envelope.Target, envelope.TargetRequestId, m.Targets)
   131  		if target == nil {
   132  			s.remote.Logger().Error("EndpointReader received message with unknown target", slog.Int("target", int(envelope.Target)), slog.Int("targetRequestId", int(envelope.TargetRequestId)))
   133  			return errors.New("unknown target")
   134  		}
   135  
   136  		message, err := Deserialize(data, m.TypeNames[envelope.TypeId], envelope.SerializerId)
   137  		if err != nil {
   138  			s.remote.Logger().Error("EndpointReader failed to deserialize", slog.Any("error", err))
   139  			return err
   140  		}
   141  
   142  		// translate from on-the-wire representation to in-process representation
   143  		// this only applies to root level messages, and never on nested child messages
   144  		if v, ok := message.(RootSerialized); ok {
   145  			message, err = v.Deserialize()
   146  			if err != nil {
   147  				s.remote.Logger().Error("EndpointReader failed to deserialize", slog.Any("error", err))
   148  				return err
   149  			}
   150  		}
   151  
   152  		switch msg := message.(type) {
   153  		case *actor.Terminated:
   154  			rt := &remoteTerminate{
   155  				Watchee: msg.Who,
   156  				Watcher: target,
   157  			}
   158  			s.remote.edpManager.remoteTerminate(rt)
   159  		case actor.SystemMessage:
   160  			ref, _ := s.remote.actorSystem.ProcessRegistry.GetLocal(target.Id)
   161  			ref.SendSystemMessage(target, msg)
   162  		default:
   163  			var header map[string]string
   164  
   165  			// fast path
   166  			if sender == nil && envelope.MessageHeader == nil {
   167  				s.remote.actorSystem.Root.Send(target, message)
   168  				continue
   169  			}
   170  
   171  			// slow path
   172  			if envelope.MessageHeader != nil {
   173  				header = envelope.MessageHeader.HeaderData
   174  			}
   175  			localEnvelope := &actor.MessageEnvelope{
   176  				Header:  header,
   177  				Message: message,
   178  				Sender:  sender,
   179  			}
   180  			s.remote.actorSystem.Root.Send(target, localEnvelope)
   181  		}
   182  	}
   183  	return nil
   184  }
   185  
   186  func deserializeSender(pid *actor.PID, index int32, requestId uint32, arr []*actor.PID) *actor.PID {
   187  	if index == 0 {
   188  		pid = nil
   189  	} else {
   190  		pid = arr[index-1]
   191  
   192  		// if request id is used. make sure to clone the PID first, so we don't corrupt the lookup
   193  		if requestId > 0 {
   194  			pid, _ = proto.Clone(pid).(*actor.PID)
   195  			pid.RequestId = requestId
   196  		}
   197  	}
   198  	return pid
   199  }
   200  
   201  func deserializeTarget(pid *actor.PID, index int32, requestId uint32, arr []*actor.PID) *actor.PID {
   202  	pid = arr[index]
   203  
   204  	// if request id is used. make sure to clone the PID first, so we don't corrupt the lookup
   205  	if requestId > 0 {
   206  		pid, _ = proto.Clone(pid).(*actor.PID)
   207  		pid.RequestId = requestId
   208  	}
   209  
   210  	return pid
   211  }
   212  
   213  func (s *endpointReader) onServerConnection(stream Remoting_ReceiveServer, sc *ServerConnection) {
   214  	if s.remote.BlockList().IsBlocked(sc.SystemId) {
   215  		s.remote.Logger().Debug("EndpointReader is blocked")
   216  
   217  		err := stream.Send(
   218  			&RemoteMessage{
   219  				MessageType: &RemoteMessage_ConnectResponse{
   220  					ConnectResponse: &ConnectResponse{
   221  						Blocked:  true,
   222  						MemberId: s.remote.actorSystem.ID,
   223  					},
   224  				},
   225  			})
   226  		if err != nil {
   227  			s.remote.Logger().Error("EndpointReader failed to send ConnectResponse message", slog.Any("error", err))
   228  		}
   229  
   230  		address := sc.Address
   231  		systemID := sc.SystemId
   232  
   233  		// TODO
   234  		_ = address
   235  		_ = systemID
   236  	} else {
   237  		err := stream.Send(
   238  			&RemoteMessage{
   239  				MessageType: &RemoteMessage_ConnectResponse{
   240  					ConnectResponse: &ConnectResponse{
   241  						Blocked:  false,
   242  						MemberId: s.remote.actorSystem.ID,
   243  					},
   244  				},
   245  			})
   246  		if err != nil {
   247  			s.remote.Logger().Error("EndpointReader failed to send ConnectResponse message", slog.Any("error", err))
   248  		}
   249  	}
   250  }
   251  
   252  func (s *endpointReader) suspend(toSuspend bool) {
   253  	s.suspended = toSuspend
   254  	if toSuspend {
   255  		s.remote.Logger().Debug("Suspended EndpointReader")
   256  	}
   257  }