github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/port_forwarder_test.go (about) 1 package liveshare 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "testing" 12 "time" 13 14 livesharetest "github.com/ungtb10d/cli/v2/pkg/liveshare/test" 15 "github.com/sourcegraph/jsonrpc2" 16 ) 17 18 func TestNewPortForwarder(t *testing.T) { 19 testServer, session, err := makeMockSession() 20 if err != nil { 21 t.Errorf("create mock client: %v", err) 22 } 23 defer testServer.Close() 24 pf := NewPortForwarder(session, "ssh", 80, false) 25 if pf == nil { 26 t.Error("port forwarder is nil") 27 } 28 } 29 30 type portUpdateNotification struct { 31 PortNotification 32 conn *jsonrpc2.Conn 33 } 34 35 func TestPortForwarderStart(t *testing.T) { 36 if os.Getenv("GITHUB_ACTIONS") == "true" { 37 t.Skip("fails intermittently in CI: https://github.com/ungtb10d/cli/issues/5338") 38 } 39 40 streamName, streamCondition := "stream-name", "stream-condition" 41 const port = 8000 42 sendNotification := make(chan portUpdateNotification) 43 serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { 44 // Send the PortNotification that will be awaited on in session.StartSharing 45 sendNotification <- portUpdateNotification{ 46 PortNotification: PortNotification{ 47 Port: port, 48 ChangeKind: PortChangeKindStart, 49 }, 50 conn: conn, 51 } 52 return Port{StreamName: streamName, StreamCondition: streamCondition}, nil 53 } 54 getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { 55 return "stream-id", nil 56 } 57 58 stream := bytes.NewBufferString("stream-data") 59 testServer, session, err := makeMockSession( 60 livesharetest.WithService("serverSharing.startSharing", serverSharing), 61 livesharetest.WithService("streamManager.getStream", getStream), 62 livesharetest.WithStream("stream-id", stream), 63 ) 64 if err != nil { 65 t.Errorf("create mock session: %v", err) 66 } 67 defer testServer.Close() 68 69 listen, err := net.Listen("tcp", "127.0.0.1:8000") 70 if err != nil { 71 t.Fatal(err) 72 } 73 defer listen.Close() 74 75 ctx, cancel := context.WithCancel(context.Background()) 76 defer cancel() 77 78 go func() { 79 notif := <-sendNotification 80 _, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif) 81 }() 82 83 done := make(chan error, 2) 84 go func() { 85 done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen) 86 }() 87 88 go func() { 89 var conn net.Conn 90 91 // We retry DialTimeout in a loop to deal with a race in PortForwarder startup. 92 for tries := 0; conn == nil && tries < 2; tries++ { 93 conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) 94 if conn == nil { 95 time.Sleep(1 * time.Second) 96 } 97 } 98 if conn == nil { 99 done <- errors.New("failed to connect to forwarded port") 100 return 101 } 102 b := make([]byte, len("stream-data")) 103 if _, err := conn.Read(b); err != nil && err != io.EOF { 104 done <- fmt.Errorf("reading stream: %w", err) 105 return 106 } 107 if string(b) != "stream-data" { 108 done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) 109 return 110 } 111 if _, err := conn.Write([]byte("new-data")); err != nil { 112 done <- fmt.Errorf("writing to stream: %w", err) 113 return 114 } 115 done <- nil 116 }() 117 118 select { 119 case err := <-testServer.Err(): 120 t.Errorf("error from server: %v", err) 121 case err := <-done: 122 if err != nil { 123 t.Errorf("error from client: %v", err) 124 } 125 } 126 } 127 128 func TestPortForwarderTrafficMonitor(t *testing.T) { 129 buf := bytes.NewBufferString("some-input") 130 session := &Session{keepAliveReason: make(chan string, 1)} 131 trafficType := "io" 132 133 tm := newTrafficMonitor(buf, session, trafficType) 134 l := len(buf.Bytes()) 135 136 bb := make([]byte, l) 137 n, err := tm.Read(bb) 138 if err != nil { 139 t.Errorf("failed to read from traffic monitor: %v", err) 140 } 141 if n != l { 142 t.Errorf("expected to read %d bytes, got %d", l, n) 143 } 144 145 keepAliveReason := <-session.keepAliveReason 146 if keepAliveReason != trafficType { 147 t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason) 148 } 149 }