github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/session_test.go (about)

     1  package liveshare
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	livesharetest "github.com/ungtb10d/cli/v2/pkg/liveshare/test"
    16  	"github.com/sourcegraph/jsonrpc2"
    17  )
    18  
    19  const mockClientName = "liveshare-client"
    20  
    21  func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
    22  	joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
    23  		return joinWorkspaceResult{1}, nil
    24  	}
    25  	const sessionToken = "session-token"
    26  	opts = append(
    27  		opts,
    28  		livesharetest.WithPassword(sessionToken),
    29  		livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
    30  	)
    31  	testServer, err := livesharetest.NewServer(opts...)
    32  	if err != nil {
    33  		return nil, nil, fmt.Errorf("error creating server: %w", err)
    34  	}
    35  
    36  	session, err := Connect(context.Background(), Options{
    37  		ClientName:     mockClientName,
    38  		SessionID:      "session-id",
    39  		SessionToken:   sessionToken,
    40  		RelayEndpoint:  "sb" + strings.TrimPrefix(testServer.URL(), "https"),
    41  		RelaySAS:       "relay-sas",
    42  		HostPublicKeys: []string{livesharetest.SSHPublicKey},
    43  		TLSConfig:      &tls.Config{InsecureSkipVerify: true},
    44  		Logger:         newMockLogger(),
    45  	})
    46  	if err != nil {
    47  		return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
    48  	}
    49  	return testServer, session, nil
    50  }
    51  
    52  func TestServerStartSharing(t *testing.T) {
    53  	serverPort, serverProtocol := 2222, "sshd"
    54  	sendNotification := make(chan portUpdateNotification)
    55  	startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
    56  		var args []interface{}
    57  		if err := json.Unmarshal(*req.Params, &args); err != nil {
    58  			return nil, fmt.Errorf("error unmarshaling request: %w", err)
    59  		}
    60  		if len(args) < 3 {
    61  			return nil, errors.New("not enough arguments to start sharing")
    62  		}
    63  		port, ok := args[0].(float64)
    64  		if !ok {
    65  			return nil, errors.New("port argument is not an int")
    66  		}
    67  		if port != float64(serverPort) {
    68  			return nil, errors.New("port does not match serverPort")
    69  		}
    70  		if protocol, ok := args[1].(string); !ok {
    71  			return nil, errors.New("protocol argument is not a string")
    72  		} else if protocol != serverProtocol {
    73  			return nil, errors.New("protocol does not match serverProtocol")
    74  		}
    75  		if browseURL, ok := args[2].(string); !ok {
    76  			return nil, errors.New("browse url is not a string")
    77  		} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
    78  			return nil, errors.New("browseURL does not match expected")
    79  		}
    80  		sendNotification <- portUpdateNotification{
    81  			PortNotification: PortNotification{
    82  				Port:       int(port),
    83  				ChangeKind: PortChangeKindStart,
    84  			},
    85  			conn: conn,
    86  		}
    87  		return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
    88  	}
    89  	testServer, session, err := makeMockSession(
    90  		livesharetest.WithService("serverSharing.startSharing", startSharing),
    91  	)
    92  	defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
    93  
    94  	if err != nil {
    95  		t.Errorf("error creating mock session: %v", err)
    96  	}
    97  	ctx := context.Background()
    98  
    99  	go func() {
   100  		notif := <-sendNotification
   101  		_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
   102  	}()
   103  
   104  	done := make(chan error)
   105  	go func() {
   106  		streamID, err := session.StartSharing(ctx, serverProtocol, serverPort)
   107  		if err != nil {
   108  			done <- fmt.Errorf("error sharing server: %w", err)
   109  		}
   110  		if streamID.name == "" || streamID.condition == "" {
   111  			done <- errors.New("stream name or condition is blank")
   112  		}
   113  		done <- nil
   114  	}()
   115  
   116  	select {
   117  	case err := <-testServer.Err():
   118  		t.Errorf("error from server: %v", err)
   119  	case err := <-done:
   120  		if err != nil {
   121  			t.Errorf("error from client: %v", err)
   122  		}
   123  	}
   124  }
   125  
   126  func TestServerGetSharedServers(t *testing.T) {
   127  	sharedServer := Port{
   128  		SourcePort:      2222,
   129  		StreamName:      "stream-name",
   130  		StreamCondition: "stream-condition",
   131  	}
   132  	getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
   133  		return []*Port{&sharedServer}, nil
   134  	}
   135  	testServer, session, err := makeMockSession(
   136  		livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
   137  	)
   138  	if err != nil {
   139  		t.Errorf("error creating mock session: %v", err)
   140  	}
   141  	defer testServer.Close()
   142  	ctx := context.Background()
   143  	done := make(chan error)
   144  	go func() {
   145  		ports, err := session.GetSharedServers(ctx)
   146  		if err != nil {
   147  			done <- fmt.Errorf("error getting shared servers: %w", err)
   148  		}
   149  		if len(ports) < 1 {
   150  			done <- errors.New("not enough ports returned")
   151  		}
   152  		if ports[0].SourcePort != sharedServer.SourcePort {
   153  			done <- errors.New("source port does not match")
   154  		}
   155  		if ports[0].StreamName != sharedServer.StreamName {
   156  			done <- errors.New("stream name does not match")
   157  		}
   158  		if ports[0].StreamCondition != sharedServer.StreamCondition {
   159  			done <- errors.New("stream condiion does not match")
   160  		}
   161  		done <- nil
   162  	}()
   163  
   164  	select {
   165  	case err := <-testServer.Err():
   166  		t.Errorf("error from server: %v", err)
   167  	case err := <-done:
   168  		if err != nil {
   169  			t.Errorf("error from client: %v", err)
   170  		}
   171  	}
   172  }
   173  
   174  func TestServerUpdateSharedServerPrivacy(t *testing.T) {
   175  	updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
   176  		var req []interface{}
   177  		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
   178  			return nil, fmt.Errorf("unmarshal req: %w", err)
   179  		}
   180  		if len(req) < 2 {
   181  			return nil, errors.New("request arguments is less than 2")
   182  		}
   183  		if port, ok := req[0].(float64); ok {
   184  			if port != 80.0 {
   185  				return nil, errors.New("port param is not expected value")
   186  			}
   187  		} else {
   188  			return nil, errors.New("port param is not a float64")
   189  		}
   190  		if privacy, ok := req[1].(string); ok {
   191  			if privacy != "public" {
   192  				return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy)
   193  			}
   194  		} else {
   195  			return nil, fmt.Errorf("expected privacy param to be a bool but go %T", req[1])
   196  		}
   197  		return nil, nil
   198  	}
   199  	testServer, session, err := makeMockSession(
   200  		livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
   201  	)
   202  	if err != nil {
   203  		t.Errorf("creating mock session: %v", err)
   204  	}
   205  	defer testServer.Close()
   206  	ctx := context.Background()
   207  	done := make(chan error)
   208  	go func() {
   209  		done <- session.UpdateSharedServerPrivacy(ctx, 80, "public")
   210  	}()
   211  	select {
   212  	case err := <-testServer.Err():
   213  		t.Errorf("error from server: %v", err)
   214  	case err := <-done:
   215  		if err != nil {
   216  			t.Errorf("error from client: %v", err)
   217  		}
   218  	}
   219  }
   220  
   221  func TestInvalidHostKey(t *testing.T) {
   222  	joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
   223  		return joinWorkspaceResult{1}, nil
   224  	}
   225  	const sessionToken = "session-token"
   226  	opts := []livesharetest.ServerOption{
   227  		livesharetest.WithPassword(sessionToken),
   228  		livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
   229  	}
   230  	testServer, err := livesharetest.NewServer(opts...)
   231  	if err != nil {
   232  		t.Errorf("error creating server: %v", err)
   233  	}
   234  	_, err = Connect(context.Background(), Options{
   235  		SessionID:      "session-id",
   236  		SessionToken:   sessionToken,
   237  		RelayEndpoint:  "sb" + strings.TrimPrefix(testServer.URL(), "https"),
   238  		RelaySAS:       "relay-sas",
   239  		HostPublicKeys: []string{},
   240  		TLSConfig:      &tls.Config{InsecureSkipVerify: true},
   241  	})
   242  	if err == nil {
   243  		t.Error("expected invalid host key error, got: nil")
   244  	}
   245  }
   246  
   247  func TestKeepAliveNonBlocking(t *testing.T) {
   248  	session := &Session{keepAliveReason: make(chan string, 1)}
   249  	for i := 0; i < 2; i++ {
   250  		session.KeepAlive("io")
   251  	}
   252  
   253  	// if KeepAlive blocks, we'll never reach this and timeout the test
   254  	// timing out
   255  }
   256  
   257  func TestNotifyHostOfActivity(t *testing.T) {
   258  	notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
   259  		var req []interface{}
   260  		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
   261  			return nil, fmt.Errorf("unmarshal req: %w", err)
   262  		}
   263  		if len(req) < 2 {
   264  			return nil, errors.New("request arguments is less than 2")
   265  		}
   266  
   267  		if clientName, ok := req[0].(string); ok {
   268  			if clientName != mockClientName {
   269  				return nil, fmt.Errorf(
   270  					"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
   271  				)
   272  			}
   273  		} else {
   274  			return nil, errors.New("clientName param is not a string")
   275  		}
   276  
   277  		if acs, ok := req[1].([]interface{}); ok {
   278  			if fmt.Sprintf("%s", acs) != "[input]" {
   279  				return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
   280  			}
   281  		} else {
   282  			return nil, errors.New("activities param is not a slice")
   283  		}
   284  
   285  		return nil, nil
   286  	}
   287  	svc := livesharetest.WithService(
   288  		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
   289  	)
   290  	testServer, session, err := makeMockSession(svc)
   291  	if err != nil {
   292  		t.Fatalf("creating mock session: %v", err)
   293  	}
   294  	defer testServer.Close()
   295  	ctx := context.Background()
   296  	done := make(chan error)
   297  	go func() {
   298  		done <- session.notifyHostOfActivity(ctx, "input")
   299  	}()
   300  	select {
   301  	case err := <-testServer.Err():
   302  		t.Errorf("error from server: %v", err)
   303  	case err := <-done:
   304  		if err != nil {
   305  			t.Errorf("error from client: %v", err)
   306  		}
   307  	}
   308  }
   309  
   310  func TestSessionHeartbeat(t *testing.T) {
   311  	var (
   312  		requestsMu sync.Mutex
   313  		requests   int
   314  		wg         sync.WaitGroup
   315  	)
   316  	wg.Add(1)
   317  	notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
   318  		defer wg.Done()
   319  		requestsMu.Lock()
   320  		requests++
   321  		requestsMu.Unlock()
   322  
   323  		var req []interface{}
   324  		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
   325  			return nil, fmt.Errorf("unmarshal req: %w", err)
   326  		}
   327  		if len(req) < 2 {
   328  			return nil, errors.New("request arguments is less than 2")
   329  		}
   330  
   331  		if clientName, ok := req[0].(string); ok {
   332  			if clientName != mockClientName {
   333  				return nil, fmt.Errorf(
   334  					"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
   335  				)
   336  			}
   337  		} else {
   338  			return nil, errors.New("clientName param is not a string")
   339  		}
   340  
   341  		if acs, ok := req[1].([]interface{}); ok {
   342  			if fmt.Sprintf("%s", acs) != "[input]" {
   343  				return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
   344  			}
   345  		} else {
   346  			return nil, errors.New("activities param is not a slice")
   347  		}
   348  
   349  		return nil, nil
   350  	}
   351  	svc := livesharetest.WithService(
   352  		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
   353  	)
   354  	testServer, session, err := makeMockSession(svc)
   355  	if err != nil {
   356  		t.Fatalf("creating mock session: %v", err)
   357  	}
   358  	defer testServer.Close()
   359  
   360  	ctx, cancel := context.WithCancel(context.Background())
   361  	defer cancel()
   362  
   363  	done := make(chan struct{})
   364  
   365  	logger := newMockLogger()
   366  	session.logger = logger
   367  
   368  	go session.heartbeat(ctx, 50*time.Millisecond)
   369  	go func() {
   370  		session.KeepAlive("input")
   371  		wg.Wait()
   372  		wg.Add(1)
   373  		session.KeepAlive("input")
   374  		wg.Wait()
   375  		done <- struct{}{}
   376  	}()
   377  
   378  	select {
   379  	case err := <-testServer.Err():
   380  		t.Errorf("error from server: %v", err)
   381  	case <-done:
   382  		activityCount := strings.Count(logger.String(), "input")
   383  		// by design KeepAlive can drop requests, and therefore there is zero guarantee
   384  		// that we actually get two requests if the network happened to be slow (rarely)
   385  		// during testing.
   386  		if activityCount != 1 && activityCount != 2 {
   387  			t.Errorf("unexpected number of activities, expected: 1-2, got: %d", activityCount)
   388  		}
   389  
   390  		requestsMu.Lock()
   391  		rc := requests
   392  		requestsMu.Unlock()
   393  		// though this could be also dropped, the sync.WaitGroup above guarantees
   394  		// that it gets called a second time.
   395  		if rc != 2 {
   396  			t.Errorf("unexpected number of requests, expected: 2, got: %d", requests)
   397  		}
   398  		return
   399  	}
   400  }
   401  
   402  func TestRebuild(t *testing.T) {
   403  	tests := []struct {
   404  		fullRebuild bool
   405  		rpcService  string
   406  	}{
   407  		{
   408  			fullRebuild: false,
   409  			rpcService:  "IEnvironmentConfigurationService.incrementalRebuildContainer",
   410  		},
   411  		{
   412  			fullRebuild: true,
   413  			rpcService:  "IEnvironmentConfigurationService.rebuildContainer",
   414  		},
   415  	}
   416  
   417  	for _, tt := range tests {
   418  		t.Logf("RPC service: %s", tt.rpcService)
   419  		t.Logf("full rebuild: %t", tt.fullRebuild)
   420  
   421  		requestCount := 0
   422  		rebuildContainer := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
   423  			requestCount++
   424  			return true, nil
   425  		}
   426  		testServer, session, err := makeMockSession(
   427  			livesharetest.WithService(tt.rpcService, rebuildContainer),
   428  		)
   429  		if err != nil {
   430  			t.Errorf("creating mock session: %v", err)
   431  		}
   432  		defer testServer.Close()
   433  
   434  		err = session.RebuildContainer(context.Background(), tt.fullRebuild)
   435  		if err != nil {
   436  			t.Errorf("rebuilding codespace via mock session: %v", err)
   437  		}
   438  
   439  		if requestCount == 0 {
   440  			t.Errorf("no requests were made")
   441  		}
   442  	}
   443  }
   444  
   445  type mockLogger struct {
   446  	sync.Mutex
   447  	buf *bytes.Buffer
   448  }
   449  
   450  func newMockLogger() *mockLogger {
   451  	return &mockLogger{buf: new(bytes.Buffer)}
   452  }
   453  
   454  func (m *mockLogger) Printf(format string, v ...interface{}) {
   455  	m.Lock()
   456  	defer m.Unlock()
   457  	m.buf.WriteString(fmt.Sprintf(format, v...))
   458  }
   459  
   460  func (m *mockLogger) Println(v ...interface{}) {
   461  	m.Lock()
   462  	defer m.Unlock()
   463  	m.buf.WriteString(fmt.Sprintln(v...))
   464  }
   465  
   466  func (m *mockLogger) String() string {
   467  	m.Lock()
   468  	defer m.Unlock()
   469  	return m.buf.String()
   470  }