github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/api4/websocket_test.go (about) 1 // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. 2 // See LICENSE.txt for license information. 3 4 package api4 5 6 import ( 7 "fmt" 8 "net/http" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/gorilla/websocket" 14 "github.com/stretchr/testify/require" 15 16 "github.com/masterhung0112/hk_server/v5/model" 17 ) 18 19 func TestWebSocketTrailingSlash(t *testing.T) { 20 th := Setup(t) 21 defer th.TearDown() 22 23 url := fmt.Sprintf("ws://localhost:%v", th.App.Srv().ListenAddr.Port) 24 _, _, err := websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket/", nil) 25 require.NoError(t, err) 26 } 27 28 func TestWebSocketEvent(t *testing.T) { 29 th := Setup(t).InitBasic() 30 defer th.TearDown() 31 32 WebSocketClient, err := th.CreateWebSocketClient() 33 require.Nil(t, err) 34 defer WebSocketClient.Close() 35 36 WebSocketClient.Listen() 37 38 resp := <-WebSocketClient.ResponseChannel 39 require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge") 40 41 omitUser := make(map[string]bool, 1) 42 omitUser["somerandomid"] = true 43 evt1 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", th.BasicChannel.Id, "", omitUser) 44 evt1.Add("user_id", "somerandomid") 45 th.App.Publish(evt1) 46 47 time.Sleep(300 * time.Millisecond) 48 49 stop := make(chan bool) 50 eventHit := false 51 52 go func() { 53 for { 54 select { 55 case resp := <-WebSocketClient.EventChannel: 56 if resp.EventType() == model.WEBSOCKET_EVENT_TYPING && resp.GetData()["user_id"].(string) == "somerandomid" { 57 eventHit = true 58 } 59 case <-stop: 60 return 61 } 62 } 63 }() 64 65 time.Sleep(400 * time.Millisecond) 66 67 stop <- true 68 69 require.True(t, eventHit, "did not receive typing event") 70 71 evt2 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", "somerandomid", "", nil) 72 th.App.Publish(evt2) 73 time.Sleep(300 * time.Millisecond) 74 75 eventHit = false 76 77 go func() { 78 for { 79 select { 80 case resp := <-WebSocketClient.EventChannel: 81 if resp.EventType() == model.WEBSOCKET_EVENT_TYPING { 82 eventHit = true 83 } 84 case <-stop: 85 return 86 } 87 } 88 }() 89 90 time.Sleep(400 * time.Millisecond) 91 92 stop <- true 93 94 require.False(t, eventHit, "got typing event for bad channel id") 95 } 96 97 func TestCreateDirectChannelWithSocket(t *testing.T) { 98 th := Setup(t).InitBasic() 99 defer th.TearDown() 100 101 Client := th.Client 102 user2 := th.BasicUser2 103 104 users := make([]*model.User, 0) 105 users = append(users, user2) 106 107 for i := 0; i < 10; i++ { 108 users = append(users, th.CreateUser()) 109 } 110 111 WebSocketClient, err := th.CreateWebSocketClient() 112 require.Nil(t, err) 113 defer WebSocketClient.Close() 114 WebSocketClient.Listen() 115 116 resp := <-WebSocketClient.ResponseChannel 117 require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge") 118 119 wsr := <-WebSocketClient.EventChannel 120 require.Equal(t, wsr.EventType(), model.WEBSOCKET_EVENT_HELLO, "missing hello") 121 122 stop := make(chan bool) 123 count := 0 124 125 go func() { 126 for { 127 select { 128 case wsr := <-WebSocketClient.EventChannel: 129 if wsr != nil && wsr.EventType() == model.WEBSOCKET_EVENT_DIRECT_ADDED { 130 count = count + 1 131 } 132 133 case <-stop: 134 return 135 } 136 } 137 }() 138 139 for _, user := range users { 140 time.Sleep(100 * time.Millisecond) 141 _, resp := Client.CreateDirectChannel(th.BasicUser.Id, user.Id) 142 require.Nil(t, resp.Error, "failed to create DM channel") 143 } 144 145 time.Sleep(5000 * time.Millisecond) 146 147 stop <- true 148 149 require.Equal(t, count, len(users), "We didn't get the proper amount of direct_added messages") 150 } 151 152 func TestWebsocketOriginSecurity(t *testing.T) { 153 th := Setup(t) 154 defer th.TearDown() 155 156 url := fmt.Sprintf("ws://localhost:%v", th.App.Srv().ListenAddr.Port) 157 158 // Should fail because origin doesn't match 159 _, _, err := websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 160 "Origin": []string{"http://www.evil.com"}, 161 }) 162 163 require.Error(t, err, "Should have errored because Origin does not match host! SECURITY ISSUE!") 164 165 // We are not a browser so we can spoof this just fine 166 _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 167 "Origin": []string{fmt.Sprintf("http://localhost:%v", th.App.Srv().ListenAddr.Port)}, 168 }) 169 require.NoError(t, err, err) 170 171 // Should succeed now because open CORS 172 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "*" }) 173 _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 174 "Origin": []string{"http://www.evil.com"}, 175 }) 176 require.NoError(t, err, err) 177 178 // Should succeed now because matching CORS 179 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.evil.com" }) 180 _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 181 "Origin": []string{"http://www.evil.com"}, 182 }) 183 require.NoError(t, err, err) 184 185 // Should fail because non-matching CORS 186 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com" }) 187 _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 188 "Origin": []string{"http://www.evil.com"}, 189 }) 190 require.Error(t, err, "Should have errored because Origin contain AllowCorsFrom") 191 192 // Should fail because non-matching CORS 193 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com" }) 194 _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{ 195 "Origin": []string{"http://www.good.co"}, 196 }) 197 require.Error(t, err, "Should have errored because Origin does not match host! SECURITY ISSUE!") 198 199 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "" }) 200 } 201 202 func TestWebSocketStatuses(t *testing.T) { 203 th := Setup(t).InitBasic() 204 defer th.TearDown() 205 206 Client := th.Client 207 WebSocketClient, err := th.CreateWebSocketClient() 208 require.Nil(t, err, err) 209 defer WebSocketClient.Close() 210 WebSocketClient.Listen() 211 212 resp := <-WebSocketClient.ResponseChannel 213 require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge") 214 215 team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewRandomTeamName() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} 216 rteam, _ := Client.CreateTeam(&team) 217 218 user := model.User{Email: strings.ToLower(model.NewId()) + "success+test@simulator.amazonses.com", Nickname: "Corey Hulen", Password: "passwd1"} 219 ruser := Client.Must(Client.CreateUser(&user)).(*model.User) 220 th.LinkUserToTeam(ruser, rteam) 221 _, nErr := th.App.Srv().Store.User().VerifyEmail(ruser.Id, ruser.Email) 222 require.NoError(t, nErr) 223 224 user2 := model.User{Email: strings.ToLower(model.NewId()) + "success+test@simulator.amazonses.com", Nickname: "Corey Hulen", Password: "passwd1"} 225 ruser2 := Client.Must(Client.CreateUser(&user2)).(*model.User) 226 th.LinkUserToTeam(ruser2, rteam) 227 _, nErr = th.App.Srv().Store.User().VerifyEmail(ruser2.Id, ruser2.Email) 228 require.NoError(t, nErr) 229 230 Client.Login(user.Email, user.Password) 231 232 th.LoginBasic2() 233 234 WebSocketClient2, err2 := th.CreateWebSocketClient() 235 require.Nil(t, err2, err2) 236 237 time.Sleep(1000 * time.Millisecond) 238 239 WebSocketClient.GetStatuses() 240 resp = <-WebSocketClient.ResponseChannel 241 require.Nil(t, resp.Error, resp.Error) 242 243 require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number") 244 245 allowedValues := [4]string{model.STATUS_OFFLINE, model.STATUS_AWAY, model.STATUS_ONLINE, model.STATUS_DND} 246 for _, status := range resp.Data { 247 require.Containsf(t, allowedValues, status, "one of the statuses had an invalid value status=%v", status) 248 } 249 250 status, ok := resp.Data[th.BasicUser2.Id] 251 require.True(t, ok, "should have had user status") 252 253 require.Equal(t, status, model.STATUS_ONLINE, "status should have been online status=%v", status) 254 255 WebSocketClient.GetStatusesByIds([]string{th.BasicUser2.Id}) 256 resp = <-WebSocketClient.ResponseChannel 257 require.Nil(t, resp.Error, resp.Error) 258 259 require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number") 260 261 allowedValues = [4]string{model.STATUS_OFFLINE, model.STATUS_AWAY, model.STATUS_ONLINE} 262 for _, status := range resp.Data { 263 require.Containsf(t, allowedValues, status, "one of the statuses had an invalid value status") 264 } 265 266 status, ok = resp.Data[th.BasicUser2.Id] 267 require.True(t, ok, "should have had user status") 268 269 require.Equal(t, status, model.STATUS_ONLINE, "status should have been online status=%v", status) 270 require.Equal(t, len(resp.Data), 1, "only 1 status should be returned") 271 272 WebSocketClient.GetStatusesByIds([]string{ruser2.Id, "junk"}) 273 resp = <-WebSocketClient.ResponseChannel 274 require.Nil(t, resp.Error, resp.Error) 275 require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number") 276 require.Equal(t, len(resp.Data), 2, "2 statuses should be returned") 277 278 WebSocketClient.GetStatusesByIds([]string{}) 279 if resp2 := <-WebSocketClient.ResponseChannel; resp2.Error == nil { 280 require.Equal(t, resp2.SeqReply, WebSocketClient.Sequence-1, "bad sequence number") 281 require.NotNil(t, resp2.Error, "should have errored - empty user ids") 282 } 283 284 WebSocketClient2.Close() 285 286 th.App.SetStatusAwayIfNeeded(th.BasicUser.Id, false) 287 288 awayTimeout := *th.App.Config().TeamSettings.UserStatusAwayTimeout 289 defer func() { 290 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.UserStatusAwayTimeout = awayTimeout }) 291 }() 292 th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.UserStatusAwayTimeout = 1 }) 293 294 time.Sleep(1500 * time.Millisecond) 295 296 th.App.SetStatusAwayIfNeeded(th.BasicUser.Id, false) 297 th.App.SetStatusOnline(th.BasicUser.Id, false) 298 299 time.Sleep(1500 * time.Millisecond) 300 301 WebSocketClient.GetStatuses() 302 resp = <-WebSocketClient.ResponseChannel 303 require.Nil(t, resp.Error) 304 305 require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number") 306 _, ok = resp.Data[th.BasicUser2.Id] 307 require.False(t, ok, "should not have had user status") 308 309 stop := make(chan bool) 310 onlineHit := false 311 awayHit := false 312 313 go func() { 314 for { 315 select { 316 case resp := <-WebSocketClient.EventChannel: 317 if resp.EventType() == model.WEBSOCKET_EVENT_STATUS_CHANGE && resp.GetData()["user_id"].(string) == th.BasicUser.Id { 318 status := resp.GetData()["status"].(string) 319 if status == model.STATUS_ONLINE { 320 onlineHit = true 321 } else if status == model.STATUS_AWAY { 322 awayHit = true 323 } 324 } 325 case <-stop: 326 return 327 } 328 } 329 }() 330 331 time.Sleep(500 * time.Millisecond) 332 333 stop <- true 334 335 require.True(t, onlineHit, "didn't get online event") 336 require.True(t, awayHit, "didn't get away event") 337 338 time.Sleep(500 * time.Millisecond) 339 340 WebSocketClient.Close() 341 }