github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/port_forwarder_test.go (about)

     1  package liveshare
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"testing"
    12  	"time"
    13  
    14  	livesharetest "github.com/ungtb10d/cli/v2/pkg/liveshare/test"
    15  	"github.com/sourcegraph/jsonrpc2"
    16  )
    17  
    18  func TestNewPortForwarder(t *testing.T) {
    19  	testServer, session, err := makeMockSession()
    20  	if err != nil {
    21  		t.Errorf("create mock client: %v", err)
    22  	}
    23  	defer testServer.Close()
    24  	pf := NewPortForwarder(session, "ssh", 80, false)
    25  	if pf == nil {
    26  		t.Error("port forwarder is nil")
    27  	}
    28  }
    29  
// portUpdateNotification is what the mock "serverSharing.startSharing"
// handler hands to the test, so that the test can send the port notification
// back over the same JSON-RPC connection.
//
// The whole value is passed as the params of the
// "serverSharing.sharingSucceeded" call. The embedded PortNotification fields
// are the payload. NOTE(review): this assumes jsonrpc2 encodes params with
// encoding/json, which ignores the unexported conn field — confirm.
type portUpdateNotification struct {
	PortNotification
	// conn is the server-side connection on which the startSharing request
	// arrived; the notification is dispatched back through it.
	conn *jsonrpc2.Conn
}
    34  
    35  func TestPortForwarderStart(t *testing.T) {
    36  	if os.Getenv("GITHUB_ACTIONS") == "true" {
    37  		t.Skip("fails intermittently in CI: https://github.com/ungtb10d/cli/issues/5338")
    38  	}
    39  
    40  	streamName, streamCondition := "stream-name", "stream-condition"
    41  	const port = 8000
    42  	sendNotification := make(chan portUpdateNotification)
    43  	serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
    44  		// Send the PortNotification that will be awaited on in session.StartSharing
    45  		sendNotification <- portUpdateNotification{
    46  			PortNotification: PortNotification{
    47  				Port:       port,
    48  				ChangeKind: PortChangeKindStart,
    49  			},
    50  			conn: conn,
    51  		}
    52  		return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
    53  	}
    54  	getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
    55  		return "stream-id", nil
    56  	}
    57  
    58  	stream := bytes.NewBufferString("stream-data")
    59  	testServer, session, err := makeMockSession(
    60  		livesharetest.WithService("serverSharing.startSharing", serverSharing),
    61  		livesharetest.WithService("streamManager.getStream", getStream),
    62  		livesharetest.WithStream("stream-id", stream),
    63  	)
    64  	if err != nil {
    65  		t.Errorf("create mock session: %v", err)
    66  	}
    67  	defer testServer.Close()
    68  
    69  	listen, err := net.Listen("tcp", "127.0.0.1:8000")
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	defer listen.Close()
    74  
    75  	ctx, cancel := context.WithCancel(context.Background())
    76  	defer cancel()
    77  
    78  	go func() {
    79  		notif := <-sendNotification
    80  		_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
    81  	}()
    82  
    83  	done := make(chan error, 2)
    84  	go func() {
    85  		done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen)
    86  	}()
    87  
    88  	go func() {
    89  		var conn net.Conn
    90  
    91  		// We retry DialTimeout in a loop to deal with a race in PortForwarder startup.
    92  		for tries := 0; conn == nil && tries < 2; tries++ {
    93  			conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
    94  			if conn == nil {
    95  				time.Sleep(1 * time.Second)
    96  			}
    97  		}
    98  		if conn == nil {
    99  			done <- errors.New("failed to connect to forwarded port")
   100  			return
   101  		}
   102  		b := make([]byte, len("stream-data"))
   103  		if _, err := conn.Read(b); err != nil && err != io.EOF {
   104  			done <- fmt.Errorf("reading stream: %w", err)
   105  			return
   106  		}
   107  		if string(b) != "stream-data" {
   108  			done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
   109  			return
   110  		}
   111  		if _, err := conn.Write([]byte("new-data")); err != nil {
   112  			done <- fmt.Errorf("writing to stream: %w", err)
   113  			return
   114  		}
   115  		done <- nil
   116  	}()
   117  
   118  	select {
   119  	case err := <-testServer.Err():
   120  		t.Errorf("error from server: %v", err)
   121  	case err := <-done:
   122  		if err != nil {
   123  			t.Errorf("error from client: %v", err)
   124  		}
   125  	}
   126  }
   127  
   128  func TestPortForwarderTrafficMonitor(t *testing.T) {
   129  	buf := bytes.NewBufferString("some-input")
   130  	session := &Session{keepAliveReason: make(chan string, 1)}
   131  	trafficType := "io"
   132  
   133  	tm := newTrafficMonitor(buf, session, trafficType)
   134  	l := len(buf.Bytes())
   135  
   136  	bb := make([]byte, l)
   137  	n, err := tm.Read(bb)
   138  	if err != nil {
   139  		t.Errorf("failed to read from traffic monitor: %v", err)
   140  	}
   141  	if n != l {
   142  		t.Errorf("expected to read %d bytes, got %d", l, n)
   143  	}
   144  
   145  	keepAliveReason := <-session.keepAliveReason
   146  	if keepAliveReason != trafficType {
   147  		t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
   148  	}
   149  }