github.com/anycable/anycable-go@v1.5.1/node/session_test.go (about) 1 package node 2 3 import ( 4 "sync" 5 "testing" 6 7 "github.com/anycable/anycable-go/common" 8 "github.com/anycable/anycable-go/ws" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 ) 12 13 func TestSendRaceConditions(t *testing.T) { 14 node := NewMockNode() 15 var wg sync.WaitGroup 16 17 for i := 1; i <= 10; i++ { 18 session := NewMockSession("123", node) 19 20 go func() { 21 for { 22 session.conn.Read() // nolint:errcheck 23 } 24 }() 25 26 wg.Add(2) 27 go func() { 28 go func() { 29 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("hi!")}) 30 wg.Done() 31 }() 32 33 go func() { 34 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("bye")}) 35 wg.Done() 36 }() 37 }() 38 39 wg.Add(2) 40 go func() { 41 go func() { 42 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("bye")}) 43 wg.Done() 44 }() 45 46 go func() { 47 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("why")}) 48 wg.Done() 49 }() 50 }() 51 } 52 53 wg.Wait() 54 } 55 56 func TestSessionSend(t *testing.T) { 57 node := NewMockNode() 58 session := NewMockSession("123", node) 59 60 go func() { 61 for i := 1; i <= 10; i++ { 62 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("bye")}) 63 } 64 }() 65 66 for i := 1; i <= 10; i++ { 67 _, err := session.conn.Read() 68 assert.Nil(t, err) 69 } 70 } 71 72 func TestSessionDisconnect(t *testing.T) { 73 node := NewMockNode() 74 session := NewMockSession("123", node) 75 session.closed = false 76 session.Connected = true 77 78 go func() { 79 session.sendFrame(&ws.SentFrame{FrameType: ws.TextFrame, Payload: []byte("bye")}) 80 session.Disconnect("test", 1042) 81 }() 82 83 // Message frame 84 _, err := session.conn.Read() 85 assert.Nil(t, err) 86 87 // Close frame 88 _, err = session.conn.Read() 89 assert.Nil(t, err) 90 } 91 92 func TestMergeEnv(t *testing.T) { 93 node := NewMockNode() 94 session := NewMockSession("123", node) 95 96 istate := map[string]map[string]string{ 97 "test_channel": { 98 "foo": "bar", 99 "a": "z", 100 }, 101 } 102 cstate := map[string]string{"_s_": "id=42"} 103 origEnv := common.SessionEnv{ChannelStates: &istate, ConnectionState: &cstate} 104 105 session.SetEnv(&origEnv) 106 107 istate2 := map[string]map[string]string{ 108 "test_channel": { 109 "foo": "baz", 110 }, 111 "another_channel": { 112 "wasting": "time", 113 }, 114 } 115 116 env := common.SessionEnv{ChannelStates: &istate2} 117 118 cstate2 := map[string]string{"red": "end of silence"} 119 120 env2 := common.SessionEnv{ConnectionState: &cstate2} 121 122 var wg sync.WaitGroup 123 124 wg.Add(2) 125 126 go func() { 127 session.MergeEnv(&env) 128 wg.Done() 129 }() 130 131 go func() { 132 session.MergeEnv(&env2) 133 wg.Done() 134 }() 135 136 wg.Wait() 137 138 assert.Equal(t, &origEnv, session.GetEnv()) 139 140 assert.Equal(t, "id=42", origEnv.GetConnectionStateField("_s_")) 141 assert.Equal(t, "end of silence", origEnv.GetConnectionStateField("red")) 142 143 assert.Equal(t, "baz", origEnv.GetChannelStateField("test_channel", "foo")) 144 assert.Equal(t, "z", origEnv.GetChannelStateField("test_channel", "a")) 145 assert.Equal(t, "time", origEnv.GetChannelStateField("another_channel", "wasting")) 146 } 147 148 func TestCacheEntry(t *testing.T) { 149 session := Session{} 150 151 session.subscriptions = NewSubscriptionState() 152 session.subscriptions.AddChannel("chat_1") 153 session.subscriptions.AddChannel("presence_1") 154 155 session.subscriptions.AddChannelStream("chat_1", "a") 156 session.subscriptions.AddChannelStream("chat_1", "b") 157 session.subscriptions.AddChannelStream("presence_1", "z") 158 159 session.env = common.NewSessionEnv("/cable", nil) 160 session.SetIdentifiers("plastilin") 161 session.env.MergeConnectionState(&map[string]string{"tenant": "x", "locale": "it"}) 162 session.env.MergeChannelState("chat_1", &map[string]string{"presence": "on"}) 163 164 session.MarkDisconnectable(true) 165 166 cached, err := session.ToCacheEntry() 167 require.NoError(t, err) 168 169 new_session := Session{} 170 new_session.subscriptions = NewSubscriptionState() 171 new_session.env = common.NewSessionEnv("/cable", nil) 172 173 err = new_session.RestoreFromCache(cached) 174 require.NoError(t, err) 175 176 assert.Equal(t, "plastilin", new_session.GetIdentifiers()) 177 178 assert.Contains(t, new_session.subscriptions.Channels(), "chat_1") 179 assert.Contains(t, new_session.subscriptions.Channels(), "presence_1") 180 assert.Contains(t, new_session.subscriptions.StreamsFor("chat_1"), "a") 181 assert.Contains(t, new_session.subscriptions.StreamsFor("chat_1"), "b") 182 assert.Contains(t, new_session.subscriptions.StreamsFor("presence_1"), "z") 183 184 assert.Equal(t, "x", new_session.env.GetConnectionStateField("tenant")) 185 assert.Equal(t, "it", new_session.env.GetConnectionStateField("locale")) 186 assert.Equal(t, "on", new_session.env.GetChannelStateField("chat_1", "presence")) 187 188 assert.True(t, new_session.IsDisconnectable()) 189 } 190 191 func TestCacheEntryEmptySession(t *testing.T) { 192 session := Session{} 193 session.subscriptions = NewSubscriptionState() 194 session.env = common.NewSessionEnv("/cable", nil) 195 196 cached, err := session.ToCacheEntry() 197 require.NoError(t, err) 198 199 new_session := Session{} 200 new_session.subscriptions = NewSubscriptionState() 201 new_session.env = common.NewSessionEnv("/cable", nil) 202 203 err = new_session.RestoreFromCache(cached) 204 require.NoError(t, err) 205 } 206 207 func TestMarkDisconnectable(t *testing.T) { 208 session := Session{} 209 210 session.MarkDisconnectable(false) 211 212 assert.False(t, session.IsDisconnectable()) 213 214 session.MarkDisconnectable(true) 215 216 assert.True(t, session.IsDisconnectable()) 217 218 session.MarkDisconnectable(false) 219 220 assert.True(t, session.IsDisconnectable()) 221 }