github.com/anycable/anycable-go@v1.5.1/rpc/rpc_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log/slog"
     7  	"testing"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/anycable/anycable-go/metrics"
    11  	"github.com/anycable/anycable-go/mocks"
    12  	pb "github.com/anycable/anycable-go/protos"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/require"
    16  	"google.golang.org/grpc/metadata"
    17  )
    18  
    19  type MockState struct {
    20  	ready  bool
    21  	closed bool
    22  }
    23  
    24  func (st MockState) Ready() error {
    25  	if st.ready {
    26  		return nil
    27  	}
    28  
    29  	return errors.New("not ready")
    30  }
    31  
    32  func (st MockState) Close() {
    33  }
    34  
    35  func (st MockState) SupportsActiveConns() bool {
    36  	return false
    37  }
    38  
    39  func (st MockState) ActiveConns() int {
    40  	return 0
    41  }
    42  
    43  func NewTestController() *Controller {
    44  	config := NewConfig()
    45  	metrics := metrics.NewMetrics(nil, 0, slog.Default())
    46  	controller, _ := NewController(metrics, &config, slog.Default())
    47  	barrier, _ := NewFixedSizeBarrier(5)
    48  	controller.barrier = barrier
    49  	controller.clientState = MockState{true, false}
    50  	return controller
    51  }
    52  
    53  func TestAuthenticate(t *testing.T) {
    54  	controller := NewTestController()
    55  	client := mocks.RPCClient{}
    56  	controller.client = &client
    57  
    58  	t.Run("Success", func(t *testing.T) {
    59  		url := "/cable-test"
    60  		headers := map[string]string{"cookie": "token=secret;"}
    61  
    62  		client.On("Connect", mock.Anything,
    63  			&pb.ConnectionRequest{
    64  				Env: &pb.Env{Url: url, Headers: headers},
    65  			}).Return(
    66  			&pb.ConnectionResponse{
    67  				Identifiers:   "user=john",
    68  				Transmissions: []string{"welcome"},
    69  				Status:        pb.Status_SUCCESS,
    70  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}},
    71  			}, nil)
    72  
    73  		res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers})
    74  		assert.Nil(t, err)
    75  		assert.Equal(t, []string{"welcome"}, res.Transmissions)
    76  		assert.Equal(t, "user=john", res.Identifier)
    77  		assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState)
    78  		assert.Empty(t, res.Broadcasts)
    79  	})
    80  
    81  	t.Run("Failure", func(t *testing.T) {
    82  		url := "/cable-test"
    83  		headers := map[string]string{"cookie": "token=invalid;"}
    84  
    85  		client.On("Connect", mock.Anything,
    86  			&pb.ConnectionRequest{
    87  				Env: &pb.Env{Url: url, Headers: headers},
    88  			}).Return(
    89  			&pb.ConnectionResponse{
    90  				Transmissions: []string{"unauthorized"},
    91  				Status:        pb.Status_FAILURE,
    92  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}},
    93  				ErrorMsg:      "Authentication failed",
    94  			}, nil)
    95  
    96  		res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers})
    97  		assert.Nil(t, err)
    98  		assert.Equal(t, []string{"unauthorized"}, res.Transmissions)
    99  		assert.Equal(t, "", res.Identifier)
   100  		assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState)
   101  		assert.Empty(t, res.Broadcasts)
   102  	})
   103  
   104  	t.Run("Error", func(t *testing.T) {
   105  		url := "/cable-test"
   106  		headers := map[string]string{"cookie": "token=exceptional;"}
   107  
   108  		client.On("Connect", mock.Anything,
   109  			&pb.ConnectionRequest{
   110  				Env: &pb.Env{Url: url, Headers: headers},
   111  			}).Return(
   112  			&pb.ConnectionResponse{
   113  				Status:   pb.Status_ERROR,
   114  				ErrorMsg: "Exception",
   115  			}, nil)
   116  
   117  		res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers})
   118  		assert.NotNil(t, err)
   119  		assert.Error(t, err, "Exception")
   120  		assert.Nil(t, res.Transmissions)
   121  		assert.Equal(t, "", res.Identifier)
   122  		assert.Nil(t, res.CState)
   123  		assert.Empty(t, res.Broadcasts)
   124  	})
   125  }
   126  
   127  func TestPerform(t *testing.T) {
   128  	controller := NewTestController()
   129  	client := mocks.RPCClient{}
   130  	controller.client = &client
   131  
   132  	t.Run("Success", func(t *testing.T) {
   133  		url := "/cable-test"
   134  		headers := map[string]string{"cookie": "token=secret;"}
   135  		cstate := map[string]string{"_s_": "id=42"}
   136  
   137  		client.On("Command", mock.Anything,
   138  			&pb.CommandMessage{
   139  				Command:               "message",
   140  				ConnectionIdentifiers: "ids",
   141  				Identifier:            "test_channel",
   142  				Data:                  "hello",
   143  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   144  			}).Return(
   145  			&pb.CommandResponse{
   146  				Status:        pb.Status_SUCCESS,
   147  				Streams:       []string{"chat_42"},
   148  				StopStreams:   true,
   149  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}},
   150  				Transmissions: []string{"message_sent"},
   151  			}, nil)
   152  
   153  		res, err := controller.Perform(
   154  			"42",
   155  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   156  			"ids", "test_channel", "hello",
   157  		)
   158  
   159  		assert.Nil(t, err)
   160  		assert.Equal(t, []string{"message_sent"}, res.Transmissions)
   161  		assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState)
   162  		assert.True(t, res.StopAllStreams)
   163  		assert.Equal(t, []string{"chat_42"}, res.Streams)
   164  		assert.Nil(t, res.StoppedStreams)
   165  		assert.Empty(t, res.Broadcasts)
   166  	})
   167  
   168  	t.Run("Failure", func(t *testing.T) {
   169  		url := "/cable-test"
   170  		headers := map[string]string{"cookie": "token=invalid;"}
   171  		cstate := map[string]string{"_s_": "id=42"}
   172  
   173  		client.On("Command", mock.Anything,
   174  			&pb.CommandMessage{
   175  				Command:               "message",
   176  				ConnectionIdentifiers: "ids",
   177  				Identifier:            "test_channel",
   178  				Data:                  "fail",
   179  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   180  			}).Return(
   181  			&pb.CommandResponse{
   182  				Status:        pb.Status_FAILURE,
   183  				Streams:       []string{"chat_42"},
   184  				StopStreams:   true,
   185  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}},
   186  				Transmissions: []string{"message_sent"},
   187  				ErrorMsg:      "Forbidden",
   188  			}, nil)
   189  
   190  		res, err := controller.Perform(
   191  			"42",
   192  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   193  			"ids", "test_channel", "fail",
   194  		)
   195  
   196  		assert.Nil(t, err)
   197  		assert.Equal(t, common.FAILURE, res.Status)
   198  		assert.Equal(t, []string{"message_sent"}, res.Transmissions)
   199  		assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState)
   200  		assert.True(t, res.StopAllStreams)
   201  		assert.Equal(t, []string{"chat_42"}, res.Streams)
   202  		assert.Nil(t, res.StoppedStreams)
   203  		assert.Empty(t, res.Broadcasts)
   204  	})
   205  
   206  	t.Run("Error", func(t *testing.T) {
   207  		url := "/cable-test"
   208  		headers := map[string]string{"cookie": "token=invalid;"}
   209  		cstate := map[string]string{"_s_": "id=42"}
   210  
   211  		client.On("Command", mock.Anything,
   212  			&pb.CommandMessage{
   213  				Command:               "message",
   214  				ConnectionIdentifiers: "ids",
   215  				Identifier:            "test_channel",
   216  				Data:                  "exception",
   217  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   218  			}).Return(
   219  			&pb.CommandResponse{
   220  				Status:      pb.Status_ERROR,
   221  				StopStreams: true,
   222  				ErrorMsg:    "Exception",
   223  			}, nil)
   224  
   225  		res, err := controller.Perform(
   226  			"42",
   227  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   228  			"ids", "test_channel", "exception",
   229  		)
   230  
   231  		assert.NotNil(t, err)
   232  		assert.Equal(t, common.ERROR, res.Status)
   233  		assert.Error(t, err, "Exception")
   234  		assert.Nil(t, res.Transmissions)
   235  		assert.True(t, res.StopAllStreams)
   236  		assert.Nil(t, res.Streams)
   237  		assert.Nil(t, res.StoppedStreams)
   238  		assert.Empty(t, res.Broadcasts)
   239  	})
   240  
   241  	t.Run("With stopped streams", func(t *testing.T) {
   242  		url := "/cable-test"
   243  		headers := map[string]string{"cookie": "token=secret;"}
   244  		cstate := map[string]string{"_s_": "id=42"}
   245  
   246  		client.On("Command", mock.Anything,
   247  			&pb.CommandMessage{
   248  				Command:               "message",
   249  				ConnectionIdentifiers: "ids",
   250  				Identifier:            "test_channel",
   251  				Data:                  "stop_stream",
   252  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   253  			}).Return(
   254  			&pb.CommandResponse{
   255  				Status:         pb.Status_SUCCESS,
   256  				StoppedStreams: []string{"chat_42"},
   257  				StopStreams:    false,
   258  				Env:            &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}},
   259  				Transmissions:  []string{"message_sent"},
   260  			}, nil)
   261  
   262  		res, err := controller.Perform(
   263  			"42",
   264  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   265  			"ids", "test_channel", "stop_stream",
   266  		)
   267  
   268  		assert.Nil(t, err)
   269  		assert.Equal(t, []string{"message_sent"}, res.Transmissions)
   270  		assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState)
   271  		assert.False(t, res.StopAllStreams)
   272  		assert.Equal(t, []string{"chat_42"}, res.StoppedStreams)
   273  		assert.Nil(t, res.Streams)
   274  		assert.Empty(t, res.Broadcasts)
   275  	})
   276  
   277  	t.Run("With channel state", func(t *testing.T) {
   278  		url := "/cable-test"
   279  		headers := map[string]string{"cookie": "token=secret;"}
   280  		istate := map[string]string{"room": "room:1"}
   281  
   282  		channels := make(map[string]map[string]string)
   283  		channels["test_channel"] = istate
   284  
   285  		client.On("Command", mock.Anything,
   286  			&pb.CommandMessage{
   287  				Command:               "message",
   288  				ConnectionIdentifiers: "ids",
   289  				Identifier:            "test_channel",
   290  				Data:                  "channel_state",
   291  				Env:                   &pb.Env{Url: url, Headers: headers, Istate: istate},
   292  			}).Return(
   293  			&pb.CommandResponse{
   294  				Status:         pb.Status_SUCCESS,
   295  				StoppedStreams: []string{"chat_42"},
   296  				StopStreams:    false,
   297  				Env:            &pb.EnvResponse{Istate: map[string]string{"count": "1"}},
   298  				Transmissions:  []string{"message_sent"},
   299  			}, nil)
   300  
   301  		res, err := controller.Perform(
   302  			"42",
   303  			&common.SessionEnv{URL: url, Headers: &headers, ChannelStates: &channels},
   304  			"ids", "test_channel", "channel_state",
   305  		)
   306  
   307  		assert.Nil(t, err)
   308  		assert.Equal(t, []string{"message_sent"}, res.Transmissions)
   309  		assert.Equal(t, map[string]string{"count": "1"}, res.IState)
   310  		assert.False(t, res.StopAllStreams)
   311  		assert.Equal(t, []string{"chat_42"}, res.StoppedStreams)
   312  		assert.Nil(t, res.Streams)
   313  		assert.Empty(t, res.Broadcasts)
   314  	})
   315  }
   316  
   317  func TestSubscribe(t *testing.T) {
   318  	controller := NewTestController()
   319  	client := mocks.RPCClient{}
   320  	controller.client = &client
   321  
   322  	t.Run("Success", func(t *testing.T) {
   323  		url := "/cable-test"
   324  		headers := map[string]string{"cookie": "token=secret;"}
   325  		cstate := map[string]string{"_s_": "id=42"}
   326  
   327  		client.On("Command", mock.Anything,
   328  			&pb.CommandMessage{
   329  				Command:               "subscribe",
   330  				ConnectionIdentifiers: "ids",
   331  				Identifier:            "test_channel",
   332  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   333  			}).Return(
   334  			&pb.CommandResponse{
   335  				Status:        pb.Status_SUCCESS,
   336  				Streams:       []string{"chat_42"},
   337  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}},
   338  				Transmissions: []string{"confirmed"},
   339  			}, nil)
   340  
   341  		res, err := controller.Subscribe(
   342  			"42",
   343  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   344  			"ids", "test_channel",
   345  		)
   346  
   347  		assert.Nil(t, err)
   348  		assert.Equal(t, []string{"confirmed"}, res.Transmissions)
   349  		assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState)
   350  		assert.False(t, res.StopAllStreams)
   351  		assert.Equal(t, []string{"chat_42"}, res.Streams)
   352  		assert.Nil(t, res.StoppedStreams)
   353  		assert.Empty(t, res.Broadcasts)
   354  	})
   355  
   356  	t.Run("Failure", func(t *testing.T) {
   357  		url := "/cable-test"
   358  		headers := map[string]string{"cookie": "token=secret;"}
   359  		cstate := map[string]string{"_s_": "id=42"}
   360  
   361  		client.On("Command", mock.Anything,
   362  			&pb.CommandMessage{
   363  				Command:               "subscribe",
   364  				ConnectionIdentifiers: "ids",
   365  				Identifier:            "fail_channel",
   366  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   367  			}).Return(
   368  			&pb.CommandResponse{
   369  				Status:        pb.Status_FAILURE,
   370  				ErrorMsg:      "Unauthorized",
   371  				Disconnect:    true,
   372  				Transmissions: []string{"rejected"},
   373  			}, nil)
   374  
   375  		res, err := controller.Subscribe(
   376  			"42",
   377  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   378  			"ids", "fail_channel",
   379  		)
   380  
   381  		assert.Nil(t, err)
   382  		assert.Equal(t, common.FAILURE, res.Status)
   383  		assert.Equal(t, []string{"rejected"}, res.Transmissions)
   384  		assert.True(t, res.Disconnect)
   385  		assert.Nil(t, res.StoppedStreams)
   386  		assert.Empty(t, res.Broadcasts)
   387  	})
   388  
   389  	t.Run("Error", func(t *testing.T) {
   390  		url := "/cable-test"
   391  		headers := map[string]string{"cookie": "token=secret;"}
   392  		cstate := map[string]string{"_s_": "id=42"}
   393  
   394  		client.On("Command", mock.Anything,
   395  			&pb.CommandMessage{
   396  				Command:               "subscribe",
   397  				ConnectionIdentifiers: "ids",
   398  				Identifier:            "error_channel",
   399  				Env:                   &pb.Env{Url: url, Headers: headers, Cstate: cstate},
   400  			}).Return(
   401  			&pb.CommandResponse{
   402  				Status:   pb.Status_ERROR,
   403  				ErrorMsg: "Exception",
   404  			}, nil)
   405  
   406  		res, err := controller.Subscribe(
   407  			"42",
   408  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate},
   409  			"ids", "error_channel",
   410  		)
   411  
   412  		assert.NotNil(t, err)
   413  		assert.Equal(t, common.ERROR, res.Status)
   414  	})
   415  }
   416  
   417  func TestDisconnect(t *testing.T) {
   418  	controller := NewTestController()
   419  	client := mocks.RPCClient{}
   420  	controller.client = &client
   421  
   422  	t.Run("Success", func(t *testing.T) {
   423  		url := "/cable-test"
   424  		headers := map[string]string{"cookie": "token=secret;"}
   425  		cstate := map[string]string{"_s_": "id=42"}
   426  		istate := map[string]string{"test_channel": "{\"room\":\"room:1\"}"}
   427  
   428  		channels := make(map[string]map[string]string)
   429  		channels["test_channel"] = map[string]string{"room": "room:1"}
   430  
   431  		client.On("Disconnect", mock.Anything,
   432  			&pb.DisconnectRequest{
   433  				Identifiers:   "ids",
   434  				Subscriptions: []string{"chat_42"},
   435  				Env:           &pb.Env{Url: url, Headers: headers, Cstate: cstate, Istate: istate},
   436  			}).Return(
   437  			&pb.DisconnectResponse{
   438  				Status: pb.Status_SUCCESS,
   439  			}, nil)
   440  
   441  		err := controller.Disconnect(
   442  			"42",
   443  			&common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate, ChannelStates: &channels},
   444  			"ids",
   445  			[]string{"chat_42"},
   446  		)
   447  		assert.Nil(t, err)
   448  	})
   449  }
   450  
   451  func TestCustomDialFun(t *testing.T) {
   452  	config := NewConfig()
   453  
   454  	service := mocks.RPCServer{}
   455  	stateHandler := MockState{true, false}
   456  
   457  	config.DialFun = NewInprocessServiceDialer(&service, stateHandler)
   458  
   459  	controller, err := NewController(metrics.NewMetrics(nil, 0, slog.Default()), &config, slog.Default())
   460  	require.NoError(t, err)
   461  	require.NoError(t, controller.Start())
   462  
   463  	t.Run("Connect", func(t *testing.T) {
   464  		url := "/cable-test"
   465  		headers := map[string]string{"cookie": "token=secret;"}
   466  
   467  		service.On("Connect", mock.Anything,
   468  			&pb.ConnectionRequest{
   469  				Env: &pb.Env{Url: url, Headers: headers},
   470  			}).Return(
   471  			&pb.ConnectionResponse{
   472  				Identifiers:   "user=john",
   473  				Transmissions: []string{"welcome"},
   474  				Status:        pb.Status_SUCCESS,
   475  				Env:           &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}},
   476  			}, nil)
   477  
   478  		res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers})
   479  		require.Nil(t, err)
   480  		assert.Equal(t, []string{"welcome"}, res.Transmissions)
   481  		assert.Equal(t, "user=john", res.Identifier)
   482  		assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState)
   483  		assert.Empty(t, res.Broadcasts)
   484  
   485  		call := service.Calls[0]
   486  		requestCtx, ok := call.Arguments[0].(context.Context)
   487  
   488  		require.True(t, ok)
   489  
   490  		md, ok := metadata.FromIncomingContext(requestCtx)
   491  		require.True(t, ok)
   492  
   493  		assert.Equal(t, []string{"42"}, md.Get("sid"))
   494  	})
   495  }