github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/session_test.go (about) 1 package liveshare 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "strings" 11 "sync" 12 "testing" 13 "time" 14 15 livesharetest "github.com/ungtb10d/cli/v2/pkg/liveshare/test" 16 "github.com/sourcegraph/jsonrpc2" 17 ) 18 19 const mockClientName = "liveshare-client" 20 21 func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { 22 joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { 23 return joinWorkspaceResult{1}, nil 24 } 25 const sessionToken = "session-token" 26 opts = append( 27 opts, 28 livesharetest.WithPassword(sessionToken), 29 livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), 30 ) 31 testServer, err := livesharetest.NewServer(opts...) 32 if err != nil { 33 return nil, nil, fmt.Errorf("error creating server: %w", err) 34 } 35 36 session, err := Connect(context.Background(), Options{ 37 ClientName: mockClientName, 38 SessionID: "session-id", 39 SessionToken: sessionToken, 40 RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), 41 RelaySAS: "relay-sas", 42 HostPublicKeys: []string{livesharetest.SSHPublicKey}, 43 TLSConfig: &tls.Config{InsecureSkipVerify: true}, 44 Logger: newMockLogger(), 45 }) 46 if err != nil { 47 return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) 48 } 49 return testServer, session, nil 50 } 51 52 func TestServerStartSharing(t *testing.T) { 53 serverPort, serverProtocol := 2222, "sshd" 54 sendNotification := make(chan portUpdateNotification) 55 startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { 56 var args []interface{} 57 if err := json.Unmarshal(*req.Params, &args); err != nil { 58 return nil, fmt.Errorf("error unmarshaling request: %w", err) 59 } 60 if len(args) < 3 { 61 return nil, errors.New("not enough arguments to start 
sharing") 62 } 63 port, ok := args[0].(float64) 64 if !ok { 65 return nil, errors.New("port argument is not an int") 66 } 67 if port != float64(serverPort) { 68 return nil, errors.New("port does not match serverPort") 69 } 70 if protocol, ok := args[1].(string); !ok { 71 return nil, errors.New("protocol argument is not a string") 72 } else if protocol != serverProtocol { 73 return nil, errors.New("protocol does not match serverProtocol") 74 } 75 if browseURL, ok := args[2].(string); !ok { 76 return nil, errors.New("browse url is not a string") 77 } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { 78 return nil, errors.New("browseURL does not match expected") 79 } 80 sendNotification <- portUpdateNotification{ 81 PortNotification: PortNotification{ 82 Port: int(port), 83 ChangeKind: PortChangeKindStart, 84 }, 85 conn: conn, 86 } 87 return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil 88 } 89 testServer, session, err := makeMockSession( 90 livesharetest.WithService("serverSharing.startSharing", startSharing), 91 ) 92 defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() 93 94 if err != nil { 95 t.Errorf("error creating mock session: %v", err) 96 } 97 ctx := context.Background() 98 99 go func() { 100 notif := <-sendNotification 101 _, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif) 102 }() 103 104 done := make(chan error) 105 go func() { 106 streamID, err := session.StartSharing(ctx, serverProtocol, serverPort) 107 if err != nil { 108 done <- fmt.Errorf("error sharing server: %w", err) 109 } 110 if streamID.name == "" || streamID.condition == "" { 111 done <- errors.New("stream name or condition is blank") 112 } 113 done <- nil 114 }() 115 116 select { 117 case err := <-testServer.Err(): 118 t.Errorf("error from server: %v", err) 119 case err := <-done: 120 if err != nil { 121 t.Errorf("error from client: %v", err) 122 } 123 } 124 } 
// TestServerGetSharedServers checks that Session.GetSharedServers decodes the
// port list returned by "serverSharing.getSharedServers".
func TestServerGetSharedServers(t *testing.T) {
	sharedServer := Port{
		SourcePort:      2222,
		StreamName:      "stream-name",
		StreamCondition: "stream-condition",
	}
	getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
		return []*Port{&sharedServer}, nil
	}
	testServer, session, err := makeMockSession(
		livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
	)
	// Stop here on failure: testServer and session are nil when err != nil.
	if err != nil {
		t.Fatalf("error creating mock session: %v", err)
	}
	defer testServer.Close()
	ctx := context.Background()

	// Buffered, with a single send, so the client goroutine never blocks
	// (and leaks) if the test returns through the server-error branch.
	done := make(chan error, 1)
	go func() {
		done <- func() error {
			ports, err := session.GetSharedServers(ctx)
			if err != nil {
				return fmt.Errorf("error getting shared servers: %w", err)
			}
			// Return before indexing: ports[0] on an empty slice would panic
			// inside this goroutine and crash the whole test binary.
			if len(ports) < 1 {
				return errors.New("not enough ports returned")
			}
			got := ports[0]
			if got.SourcePort != sharedServer.SourcePort {
				return errors.New("source port does not match")
			}
			if got.StreamName != sharedServer.StreamName {
				return errors.New("stream name does not match")
			}
			if got.StreamCondition != sharedServer.StreamCondition {
				return errors.New("stream condition does not match")
			}
			return nil
		}()
	}()

	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}

// TestServerUpdateSharedServerPrivacy checks that
// Session.UpdateSharedServerPrivacy sends [port, privacy] to
// "serverSharing.updateSharedServerPrivacy".
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
	updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
		var req []interface{}
		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
			return nil, fmt.Errorf("unmarshal req: %w", err)
		}
		if len(req) < 2 {
			return nil, errors.New("request arguments is less than 2")
		}
		// Unmarshaling into interface{} yields float64 for every JSON number.
		port, ok := req[0].(float64)
		if !ok {
			return nil, errors.New("port param is not a float64")
		}
		if port != 80.0 {
			return nil, errors.New("port param is not expected value")
		}
		privacy, ok := req[1].(string)
		if !ok {
			return nil, fmt.Errorf("expected privacy param to be a string but got %T", req[1])
		}
		if privacy != "public" {
			return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy)
		}
		return nil, nil
	}
	testServer, session, err := makeMockSession(
		livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
	)
	if err != nil {
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()
	ctx := context.Background()
	done := make(chan error, 1)
	go func() {
		done <- session.UpdateSharedServerPrivacy(ctx, 80, "public")
	}()
	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}

// TestInvalidHostKey expects Connect to fail when HostPublicKeys is empty, so
// the test server's SSH host key cannot match any trusted key.
func TestInvalidHostKey(t *testing.T) {
	joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
		return joinWorkspaceResult{1}, nil
	}
	const sessionToken = "session-token"
	opts := []livesharetest.ServerOption{
		livesharetest.WithPassword(sessionToken),
		livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
	}
	testServer, err := livesharetest.NewServer(opts...)
	// Stop here on failure: testServer is nil when err != nil.
	if err != nil {
		t.Fatalf("error creating server: %v", err)
	}
	defer testServer.Close()

	_, err = Connect(context.Background(), Options{
		SessionID:      "session-id",
		SessionToken:   sessionToken,
		RelayEndpoint:  "sb" + strings.TrimPrefix(testServer.URL(), "https"),
		RelaySAS:       "relay-sas",
		HostPublicKeys: []string{},
		TLSConfig:      &tls.Config{InsecureSkipVerify: true},
	})
	if err == nil {
		t.Error("expected invalid host key error, got: nil")
	}
}

// TestKeepAliveNonBlocking calls KeepAlive more times than keepAliveReason can
// buffer (capacity 1). The test would hang if KeepAlive blocked on a full
// channel.
func TestKeepAliveNonBlocking(t *testing.T) {
	session := &Session{keepAliveReason: make(chan string, 1)}
	for i := 0; i < 2; i++ {
		session.KeepAlive("io")
	}

	// If KeepAlive blocks, execution never reaches this point and the test
	// times out.
}

// checkNotifyActivityParams validates the positional parameters of an
// "ICodespaceHostService.notifyCodespaceOfClientActivity" request. It expects
// [clientName, activities], where clientName equals mockClientName and
// activities formats as the single-element list [wantActivity].
func checkNotifyActivityParams(rpcReq *jsonrpc2.Request, wantActivity string) error {
	if rpcReq.Params == nil {
		return errors.New("request has no params")
	}
	var req []interface{}
	if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
		return fmt.Errorf("unmarshal req: %w", err)
	}
	if len(req) < 2 {
		return errors.New("request arguments is less than 2")
	}

	clientName, ok := req[0].(string)
	if !ok {
		return errors.New("clientName param is not a string")
	}
	if clientName != mockClientName {
		return fmt.Errorf(
			"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
		)
	}

	acs, ok := req[1].([]interface{})
	if !ok {
		return errors.New("activities param is not a slice")
	}
	want := "[" + wantActivity + "]"
	if fmt.Sprintf("%s", acs) != want {
		return fmt.Errorf("unexpected activities param, expected: %s, got: %s", want, acs)
	}
	return nil
}

// TestNotifyHostOfActivity checks that Session.notifyHostOfActivity sends
// [mockClientName, ["input"]] to the host's client-activity RPC.
func TestNotifyHostOfActivity(t *testing.T) {
	notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
		return nil, checkNotifyActivityParams(rpcReq, "input")
	}
	svc := livesharetest.WithService(
		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
	)
	testServer, session, err := makeMockSession(svc)
	if err != nil {
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()
	ctx := context.Background()
	done := make(chan error, 1)
	go func() {
		done <- session.notifyHostOfActivity(ctx, "input")
	}()
	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}

// TestSessionHeartbeat runs Session.heartbeat with a 50ms interval. It issues
// two KeepAlive calls and checks that the host received two activity requests
// and that "input" appears in the session log once or twice.
func TestSessionHeartbeat(t *testing.T) {
	var (
		requestsMu sync.Mutex
		requests   int
		wg         sync.WaitGroup
	)
	// One pending request is expected at a time; the handler marks it done.
	wg.Add(1)
	notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
		defer wg.Done()
		requestsMu.Lock()
		requests++
		requestsMu.Unlock()

		return nil, checkNotifyActivityParams(rpcReq, "input")
	}
	svc := livesharetest.WithService(
		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
	)
	testServer, session, err := makeMockSession(svc)
	if err != nil {
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the driver goroutine can finish even if the test returns
	// early through the server-error branch.
	done := make(chan struct{}, 1)

	logger := newMockLogger()
	session.logger = logger

	go session.heartbeat(ctx, 50*time.Millisecond)
	go func() {
		// Wait for the handler to see each request before issuing the next,
		// so the second KeepAlive cannot be merged into the first.
		session.KeepAlive("input")
		wg.Wait()
		wg.Add(1)
		session.KeepAlive("input")
		wg.Wait()
		done <- struct{}{}
	}()

	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case <-done:
		activityCount := strings.Count(logger.String(), "input")
		// By design KeepAlive can drop requests, so there is no guarantee of
		// two logged activities if the network happened to be slow (rarely)
		// during testing.
		if activityCount != 1 && activityCount != 2 {
			t.Errorf("unexpected number of activities, expected: 1-2, got: %d", activityCount)
		}

		requestsMu.Lock()
		rc := requests
		requestsMu.Unlock()
		// Requests could also be dropped, but the sync.WaitGroup above
		// guarantees the handler ran a second time.
		if rc != 2 {
			t.Errorf("unexpected number of requests, expected: 2, got: %d", rc)
		}
	}
}

// TestRebuild checks that Session.RebuildContainer calls the incremental or
// full rebuild RPC depending on fullRebuild.
func TestRebuild(t *testing.T) {
	tests := []struct {
		fullRebuild bool
		rpcService  string
	}{
		{
			fullRebuild: false,
			rpcService:  "IEnvironmentConfigurationService.incrementalRebuildContainer",
		},
		{
			fullRebuild: true,
			rpcService:  "IEnvironmentConfigurationService.rebuildContainer",
		},
	}

	for _, tt := range tests {
		// Subtests give each case its own scope, so the deferred Close runs at
		// the end of that case instead of accumulating until TestRebuild returns.
		t.Run(tt.rpcService, func(t *testing.T) {
			t.Logf("full rebuild: %t", tt.fullRebuild)

			requestCount := 0
			rebuildContainer := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
				requestCount++
				return true, nil
			}
			testServer, session, err := makeMockSession(
				livesharetest.WithService(tt.rpcService, rebuildContainer),
			)
			if err != nil {
				t.Fatalf("creating mock session: %v", err)
			}
			defer testServer.Close()

			if err := session.RebuildContainer(context.Background(), tt.fullRebuild); err != nil {
				t.Errorf("rebuilding codespace via mock session: %v", err)
			}

			if requestCount == 0 {
				t.Errorf("no requests were made")
			}
		})
	}
}

// mockLogger is a mutex-guarded in-memory logger. Tests install it as the
// session logger and inspect what was written via String.
type mockLogger struct {
	sync.Mutex
	buf *bytes.Buffer
}

// newMockLogger returns a mockLogger with an empty buffer.
func newMockLogger() *mockLogger {
	return &mockLogger{buf: new(bytes.Buffer)}
}

// Printf appends the formatted message to the buffer.
func (m *mockLogger) Printf(format string, v ...interface{}) {
	m.Lock()
	defer m.Unlock()
	fmt.Fprintf(m.buf, format, v...)
}

// Println appends the operands, space-separated and newline-terminated.
func (m *mockLogger) Println(v ...interface{}) {
	m.Lock()
	defer m.Unlock()
	fmt.Fprintln(m.buf, v...)
}

// String returns everything logged so far.
func (m *mockLogger) String() string {
	m.Lock()
	defer m.Unlock()
	return m.buf.String()
}