github.com/anycable/anycable-go@v1.5.1/rpc/rpc_test.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "errors" 6 "log/slog" 7 "testing" 8 9 "github.com/anycable/anycable-go/common" 10 "github.com/anycable/anycable-go/metrics" 11 "github.com/anycable/anycable-go/mocks" 12 pb "github.com/anycable/anycable-go/protos" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/mock" 15 "github.com/stretchr/testify/require" 16 "google.golang.org/grpc/metadata" 17 ) 18 19 type MockState struct { 20 ready bool 21 closed bool 22 } 23 24 func (st MockState) Ready() error { 25 if st.ready { 26 return nil 27 } 28 29 return errors.New("not ready") 30 } 31 32 func (st MockState) Close() { 33 } 34 35 func (st MockState) SupportsActiveConns() bool { 36 return false 37 } 38 39 func (st MockState) ActiveConns() int { 40 return 0 41 } 42 43 func NewTestController() *Controller { 44 config := NewConfig() 45 metrics := metrics.NewMetrics(nil, 0, slog.Default()) 46 controller, _ := NewController(metrics, &config, slog.Default()) 47 barrier, _ := NewFixedSizeBarrier(5) 48 controller.barrier = barrier 49 controller.clientState = MockState{true, false} 50 return controller 51 } 52 53 func TestAuthenticate(t *testing.T) { 54 controller := NewTestController() 55 client := mocks.RPCClient{} 56 controller.client = &client 57 58 t.Run("Success", func(t *testing.T) { 59 url := "/cable-test" 60 headers := map[string]string{"cookie": "token=secret;"} 61 62 client.On("Connect", mock.Anything, 63 &pb.ConnectionRequest{ 64 Env: &pb.Env{Url: url, Headers: headers}, 65 }).Return( 66 &pb.ConnectionResponse{ 67 Identifiers: "user=john", 68 Transmissions: []string{"welcome"}, 69 Status: pb.Status_SUCCESS, 70 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}}, 71 }, nil) 72 73 res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers}) 74 assert.Nil(t, err) 75 assert.Equal(t, []string{"welcome"}, res.Transmissions) 76 assert.Equal(t, "user=john", res.Identifier) 77 assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState) 78 assert.Empty(t, res.Broadcasts) 79 }) 80 81 t.Run("Failure", func(t *testing.T) { 82 url := "/cable-test" 83 headers := map[string]string{"cookie": "token=invalid;"} 84 85 client.On("Connect", mock.Anything, 86 &pb.ConnectionRequest{ 87 Env: &pb.Env{Url: url, Headers: headers}, 88 }).Return( 89 &pb.ConnectionResponse{ 90 Transmissions: []string{"unauthorized"}, 91 Status: pb.Status_FAILURE, 92 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}}, 93 ErrorMsg: "Authentication failed", 94 }, nil) 95 96 res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers}) 97 assert.Nil(t, err) 98 assert.Equal(t, []string{"unauthorized"}, res.Transmissions) 99 assert.Equal(t, "", res.Identifier) 100 assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState) 101 assert.Empty(t, res.Broadcasts) 102 }) 103 104 t.Run("Error", func(t *testing.T) { 105 url := "/cable-test" 106 headers := map[string]string{"cookie": "token=exceptional;"} 107 108 client.On("Connect", mock.Anything, 109 &pb.ConnectionRequest{ 110 Env: &pb.Env{Url: url, Headers: headers}, 111 }).Return( 112 &pb.ConnectionResponse{ 113 Status: pb.Status_ERROR, 114 ErrorMsg: "Exception", 115 }, nil) 116 117 res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers}) 118 assert.NotNil(t, err) 119 assert.Error(t, err, "Exception") 120 assert.Nil(t, res.Transmissions) 121 assert.Equal(t, "", res.Identifier) 122 assert.Nil(t, res.CState) 123 assert.Empty(t, res.Broadcasts) 124 }) 125 } 126 127 func TestPerform(t *testing.T) { 128 controller := NewTestController() 129 client := mocks.RPCClient{} 130 controller.client = &client 131 132 t.Run("Success", func(t *testing.T) { 133 url := "/cable-test" 134 headers := map[string]string{"cookie": "token=secret;"} 135 cstate := map[string]string{"_s_": "id=42"} 136 137 client.On("Command", mock.Anything, 138 &pb.CommandMessage{ 139 Command: "message", 140 ConnectionIdentifiers: "ids", 141 Identifier: "test_channel", 142 Data: "hello", 143 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 144 }).Return( 145 &pb.CommandResponse{ 146 Status: pb.Status_SUCCESS, 147 Streams: []string{"chat_42"}, 148 StopStreams: true, 149 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}}, 150 Transmissions: []string{"message_sent"}, 151 }, nil) 152 153 res, err := controller.Perform( 154 "42", 155 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 156 "ids", "test_channel", "hello", 157 ) 158 159 assert.Nil(t, err) 160 assert.Equal(t, []string{"message_sent"}, res.Transmissions) 161 assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState) 162 assert.True(t, res.StopAllStreams) 163 assert.Equal(t, []string{"chat_42"}, res.Streams) 164 assert.Nil(t, res.StoppedStreams) 165 assert.Empty(t, res.Broadcasts) 166 }) 167 168 t.Run("Failure", func(t *testing.T) { 169 url := "/cable-test" 170 headers := map[string]string{"cookie": "token=invalid;"} 171 cstate := map[string]string{"_s_": "id=42"} 172 173 client.On("Command", mock.Anything, 174 &pb.CommandMessage{ 175 Command: "message", 176 ConnectionIdentifiers: "ids", 177 Identifier: "test_channel", 178 Data: "fail", 179 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 180 }).Return( 181 &pb.CommandResponse{ 182 Status: pb.Status_FAILURE, 183 Streams: []string{"chat_42"}, 184 StopStreams: true, 185 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}}, 186 Transmissions: []string{"message_sent"}, 187 ErrorMsg: "Forbidden", 188 }, nil) 189 190 res, err := controller.Perform( 191 "42", 192 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 193 "ids", "test_channel", "fail", 194 ) 195 196 assert.Nil(t, err) 197 assert.Equal(t, common.FAILURE, res.Status) 198 assert.Equal(t, []string{"message_sent"}, res.Transmissions) 199 assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState) 200 assert.True(t, res.StopAllStreams) 201 assert.Equal(t, []string{"chat_42"}, res.Streams) 202 assert.Nil(t, res.StoppedStreams) 203 assert.Empty(t, res.Broadcasts) 204 }) 205 206 t.Run("Error", func(t *testing.T) { 207 url := "/cable-test" 208 headers := map[string]string{"cookie": "token=invalid;"} 209 cstate := map[string]string{"_s_": "id=42"} 210 211 client.On("Command", mock.Anything, 212 &pb.CommandMessage{ 213 Command: "message", 214 ConnectionIdentifiers: "ids", 215 Identifier: "test_channel", 216 Data: "exception", 217 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 218 }).Return( 219 &pb.CommandResponse{ 220 Status: pb.Status_ERROR, 221 StopStreams: true, 222 ErrorMsg: "Exception", 223 }, nil) 224 225 res, err := controller.Perform( 226 "42", 227 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 228 "ids", "test_channel", "exception", 229 ) 230 231 assert.NotNil(t, err) 232 assert.Equal(t, common.ERROR, res.Status) 233 assert.Error(t, err, "Exception") 234 assert.Nil(t, res.Transmissions) 235 assert.True(t, res.StopAllStreams) 236 assert.Nil(t, res.Streams) 237 assert.Nil(t, res.StoppedStreams) 238 assert.Empty(t, res.Broadcasts) 239 }) 240 241 t.Run("With stopped streams", func(t *testing.T) { 242 url := "/cable-test" 243 headers := map[string]string{"cookie": "token=secret;"} 244 cstate := map[string]string{"_s_": "id=42"} 245 246 client.On("Command", mock.Anything, 247 &pb.CommandMessage{ 248 Command: "message", 249 ConnectionIdentifiers: "ids", 250 Identifier: "test_channel", 251 Data: "stop_stream", 252 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 253 }).Return( 254 &pb.CommandResponse{ 255 Status: pb.Status_SUCCESS, 256 StoppedStreams: []string{"chat_42"}, 257 StopStreams: false, 258 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}}, 259 Transmissions: []string{"message_sent"}, 260 }, nil) 261 262 res, err := controller.Perform( 263 "42", 264 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 265 "ids", "test_channel", "stop_stream", 266 ) 267 268 assert.Nil(t, err) 269 assert.Equal(t, []string{"message_sent"}, res.Transmissions) 270 assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState) 271 assert.False(t, res.StopAllStreams) 272 assert.Equal(t, []string{"chat_42"}, res.StoppedStreams) 273 assert.Nil(t, res.Streams) 274 assert.Empty(t, res.Broadcasts) 275 }) 276 277 t.Run("With channel state", func(t *testing.T) { 278 url := "/cable-test" 279 headers := map[string]string{"cookie": "token=secret;"} 280 istate := map[string]string{"room": "room:1"} 281 282 channels := make(map[string]map[string]string) 283 channels["test_channel"] = istate 284 285 client.On("Command", mock.Anything, 286 &pb.CommandMessage{ 287 Command: "message", 288 ConnectionIdentifiers: "ids", 289 Identifier: "test_channel", 290 Data: "channel_state", 291 Env: &pb.Env{Url: url, Headers: headers, Istate: istate}, 292 }).Return( 293 &pb.CommandResponse{ 294 Status: pb.Status_SUCCESS, 295 StoppedStreams: []string{"chat_42"}, 296 StopStreams: false, 297 Env: &pb.EnvResponse{Istate: map[string]string{"count": "1"}}, 298 Transmissions: []string{"message_sent"}, 299 }, nil) 300 301 res, err := controller.Perform( 302 "42", 303 &common.SessionEnv{URL: url, Headers: &headers, ChannelStates: &channels}, 304 "ids", "test_channel", "channel_state", 305 ) 306 307 assert.Nil(t, err) 308 assert.Equal(t, []string{"message_sent"}, res.Transmissions) 309 assert.Equal(t, map[string]string{"count": "1"}, res.IState) 310 assert.False(t, res.StopAllStreams) 311 assert.Equal(t, []string{"chat_42"}, res.StoppedStreams) 312 assert.Nil(t, res.Streams) 313 assert.Empty(t, res.Broadcasts) 314 }) 315 } 316 317 func TestSubscribe(t *testing.T) { 318 controller := NewTestController() 319 client := mocks.RPCClient{} 320 controller.client = &client 321 322 t.Run("Success", func(t *testing.T) { 323 url := "/cable-test" 324 headers := map[string]string{"cookie": "token=secret;"} 325 cstate := map[string]string{"_s_": "id=42"} 326 327 client.On("Command", mock.Anything, 328 &pb.CommandMessage{ 329 Command: "subscribe", 330 ConnectionIdentifiers: "ids", 331 Identifier: "test_channel", 332 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 333 }).Return( 334 &pb.CommandResponse{ 335 Status: pb.Status_SUCCESS, 336 Streams: []string{"chat_42"}, 337 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "sentCount=1"}}, 338 Transmissions: []string{"confirmed"}, 339 }, nil) 340 341 res, err := controller.Subscribe( 342 "42", 343 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 344 "ids", "test_channel", 345 ) 346 347 assert.Nil(t, err) 348 assert.Equal(t, []string{"confirmed"}, res.Transmissions) 349 assert.Equal(t, map[string]string{"_s_": "sentCount=1"}, res.CState) 350 assert.False(t, res.StopAllStreams) 351 assert.Equal(t, []string{"chat_42"}, res.Streams) 352 assert.Nil(t, res.StoppedStreams) 353 assert.Empty(t, res.Broadcasts) 354 }) 355 356 t.Run("Failure", func(t *testing.T) { 357 url := "/cable-test" 358 headers := map[string]string{"cookie": "token=secret;"} 359 cstate := map[string]string{"_s_": "id=42"} 360 361 client.On("Command", mock.Anything, 362 &pb.CommandMessage{ 363 Command: "subscribe", 364 ConnectionIdentifiers: "ids", 365 Identifier: "fail_channel", 366 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 367 }).Return( 368 &pb.CommandResponse{ 369 Status: pb.Status_FAILURE, 370 ErrorMsg: "Unauthorized", 371 Disconnect: true, 372 Transmissions: []string{"rejected"}, 373 }, nil) 374 375 res, err := controller.Subscribe( 376 "42", 377 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 378 "ids", "fail_channel", 379 ) 380 381 assert.Nil(t, err) 382 assert.Equal(t, common.FAILURE, res.Status) 383 assert.Equal(t, []string{"rejected"}, res.Transmissions) 384 assert.True(t, res.Disconnect) 385 assert.Nil(t, res.StoppedStreams) 386 assert.Empty(t, res.Broadcasts) 387 }) 388 389 t.Run("Error", func(t *testing.T) { 390 url := "/cable-test" 391 headers := map[string]string{"cookie": "token=secret;"} 392 cstate := map[string]string{"_s_": "id=42"} 393 394 client.On("Command", mock.Anything, 395 &pb.CommandMessage{ 396 Command: "subscribe", 397 ConnectionIdentifiers: "ids", 398 Identifier: "error_channel", 399 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate}, 400 }).Return( 401 &pb.CommandResponse{ 402 Status: pb.Status_ERROR, 403 ErrorMsg: "Exception", 404 }, nil) 405 406 res, err := controller.Subscribe( 407 "42", 408 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate}, 409 "ids", "error_channel", 410 ) 411 412 assert.NotNil(t, err) 413 assert.Equal(t, common.ERROR, res.Status) 414 }) 415 } 416 417 func TestDisconnect(t *testing.T) { 418 controller := NewTestController() 419 client := mocks.RPCClient{} 420 controller.client = &client 421 422 t.Run("Success", func(t *testing.T) { 423 url := "/cable-test" 424 headers := map[string]string{"cookie": "token=secret;"} 425 cstate := map[string]string{"_s_": "id=42"} 426 istate := map[string]string{"test_channel": "{\"room\":\"room:1\"}"} 427 428 channels := make(map[string]map[string]string) 429 channels["test_channel"] = map[string]string{"room": "room:1"} 430 431 client.On("Disconnect", mock.Anything, 432 &pb.DisconnectRequest{ 433 Identifiers: "ids", 434 Subscriptions: []string{"chat_42"}, 435 Env: &pb.Env{Url: url, Headers: headers, Cstate: cstate, Istate: istate}, 436 }).Return( 437 &pb.DisconnectResponse{ 438 Status: pb.Status_SUCCESS, 439 }, nil) 440 441 err := controller.Disconnect( 442 "42", 443 &common.SessionEnv{URL: url, Headers: &headers, ConnectionState: &cstate, ChannelStates: &channels}, 444 "ids", 445 []string{"chat_42"}, 446 ) 447 assert.Nil(t, err) 448 }) 449 } 450 451 func TestCustomDialFun(t *testing.T) { 452 config := NewConfig() 453 454 service := mocks.RPCServer{} 455 stateHandler := MockState{true, false} 456 457 config.DialFun = NewInprocessServiceDialer(&service, stateHandler) 458 459 controller, err := NewController(metrics.NewMetrics(nil, 0, slog.Default()), &config, slog.Default()) 460 require.NoError(t, err) 461 require.NoError(t, controller.Start()) 462 463 t.Run("Connect", func(t *testing.T) { 464 url := "/cable-test" 465 headers := map[string]string{"cookie": "token=secret;"} 466 467 service.On("Connect", mock.Anything, 468 &pb.ConnectionRequest{ 469 Env: &pb.Env{Url: url, Headers: headers}, 470 }).Return( 471 &pb.ConnectionResponse{ 472 Identifiers: "user=john", 473 Transmissions: []string{"welcome"}, 474 Status: pb.Status_SUCCESS, 475 Env: &pb.EnvResponse{Cstate: map[string]string{"_s_": "test-session"}}, 476 }, nil) 477 478 res, err := controller.Authenticate("42", &common.SessionEnv{URL: url, Headers: &headers}) 479 require.Nil(t, err) 480 assert.Equal(t, []string{"welcome"}, res.Transmissions) 481 assert.Equal(t, "user=john", res.Identifier) 482 assert.Equal(t, map[string]string{"_s_": "test-session"}, res.CState) 483 assert.Empty(t, res.Broadcasts) 484 485 call := service.Calls[0] 486 requestCtx, ok := call.Arguments[0].(context.Context) 487 488 require.True(t, ok) 489 490 md, ok := metadata.FromIncomingContext(requestCtx) 491 require.True(t, ok) 492 493 assert.Equal(t, []string{"42"}, md.Get("sid")) 494 }) 495 }