github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/stream_test.go (about) 1 package tunnel 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "net" 9 "sync" 10 "testing" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 17 "github.com/datawire/dlib/dlog" 18 "github.com/telepresenceio/telepresence/rpc/v2/manager" 19 "github.com/telepresenceio/telepresence/v2/pkg/ipproto" 20 "github.com/telepresenceio/telepresence/v2/pkg/iputil" 21 "github.com/telepresenceio/telepresence/v2/pkg/log" 22 ) 23 24 type uni struct { 25 done <-chan struct{} 26 ch chan *manager.TunnelMessage 27 } 28 29 type bidi struct { 30 cToS *uni 31 sToC *uni 32 } 33 34 func newUni(bufSize int, done <-chan struct{}) *uni { 35 return &uni{ch: make(chan *manager.TunnelMessage, bufSize), done: done} 36 } 37 38 func newBidi(bufSize int, done <-chan struct{}) *bidi { 39 return &bidi{cToS: newUni(bufSize, done), sToC: newUni(bufSize, done)} 40 } 41 42 func (t *uni) recv() (*manager.TunnelMessage, error) { 43 select { 44 case <-t.done: 45 return nil, context.Canceled 46 case m := <-t.ch: 47 if m == nil { 48 return nil, net.ErrClosed 49 } 50 // Simulate a network latency of one microsecond per byte 51 time.Sleep(time.Duration(len(m.Payload)) * time.Microsecond) 52 return m, nil 53 } 54 } 55 56 func (t *uni) send(msg *manager.TunnelMessage) error { 57 select { 58 case <-t.done: 59 return context.Canceled 60 case t.ch <- msg: 61 return nil 62 } 63 } 64 65 func (t *uni) close() error { 66 close(t.ch) 67 return nil 68 } 69 70 func (t *bidi) clientSide() GRPCClientStream { 71 return &clientSide{t} 72 } 73 74 func (t *bidi) serverSide() GRPCStream { 75 return &serverSide{t} 76 } 77 78 type clientSide struct { 79 *bidi 80 } 81 82 func (c *clientSide) Recv() (*manager.TunnelMessage, error) { 83 return c.sToC.recv() 84 } 85 86 func (c *clientSide) Send(msg *manager.TunnelMessage) error { 87 return c.cToS.send(msg) 88 } 89 90 func (c *clientSide) CloseSend() error { 91 return c.cToS.close() 92 } 93 94 type serverSide struct { 95 *bidi 96 } 97 98 func (c *serverSide) Recv() (*manager.TunnelMessage, error) { 99 return c.cToS.recv() 100 } 101 102 func (c *serverSide) Send(msg *manager.TunnelMessage) error { 103 return c.sToC.send(msg) 104 } 105 106 func testContext(t *testing.T, timeout time.Duration) (context.Context, context.CancelFunc) { 107 return context.WithTimeout(dlog.WithLogger(context.Background(), log.NewTestLogger(t, dlog.LogLevelDebug)), timeout) 108 } 109 110 func TestStream_Connect(t *testing.T) { 111 ctx, cancel := testContext(t, time.Second) 112 defer cancel() 113 114 tunnel := newBidi(10, ctx.Done()) 115 id := NewConnID(ipproto.TCP, iputil.Parse("127.0.0.1"), iputil.Parse("192.168.0.1"), 1001, 8080) 116 si := uuid.New().String() 117 118 wg := sync.WaitGroup{} 119 wg.Add(2) 120 go func() { 121 defer wg.Done() 122 client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0) 123 require.NoError(t, err) 124 assert.Equal(t, Version, client.PeerVersion()) 125 assert.NoError(t, client.CloseSend(ctx)) 126 }() 127 128 go func() { 129 defer wg.Done() 130 server, err := NewServerStream(ctx, tunnel.serverSide()) 131 require.NoError(t, err) 132 assert.Equal(t, id, server.ID()) 133 assert.Equal(t, Version, server.PeerVersion()) 134 assert.Equal(t, si, server.SessionID()) 135 }() 136 wg.Wait() 137 } 138 139 func produce(ctx context.Context, s Stream, msg Message, errs chan<- error) { 140 wrCh := make(chan Message) 141 wg := sync.WaitGroup{} 142 wg.Add(1) 143 WriteLoop(ctx, s, wrCh, &wg, nil) 144 go func() { 145 for i := 0; i < 100; i++ { 146 wrCh <- msg 147 } 148 close(wrCh) 149 wg.Wait() 150 }() 151 152 rdCh, errCh := ReadLoop(ctx, s, nil) 153 select { 154 case <-ctx.Done(): 155 errs <- ctx.Err() 156 case err, ok := <-errCh: 157 if ok { 158 errs <- err 159 } 160 case m, ok := <-rdCh: 161 if ok { 162 errs <- fmt.Errorf("unexpected message: %s", m) 163 } 164 } 165 } 166 167 func consume(ctx context.Context, s Stream, expectedPayload []byte, errs chan<- error) { 168 count := 0 169 wrCh := make(chan Message) 170 wg := sync.WaitGroup{} 171 wg.Add(1) 172 WriteLoop(ctx, s, wrCh, &wg, nil) 173 defer close(wrCh) 174 rdCh, errCh := ReadLoop(ctx, s, nil) 175 for { 176 select { 177 case <-ctx.Done(): 178 errs <- ctx.Err() 179 case err, ok := <-errCh: 180 if ok { 181 errs <- err 182 } 183 case m, ok := <-rdCh: 184 if !ok { 185 return 186 } 187 if m.Code() != Normal { 188 errs <- fmt.Errorf("unexpected message code %s", m.Code()) 189 return 190 } 191 if !bytes.Equal(expectedPayload, m.Payload()) { 192 errs <- errors.New("unexpected message content") 193 return 194 } 195 count++ 196 } 197 } 198 } 199 200 func requireNoErrs(t *testing.T, errs chan error) chan error { 201 t.Helper() 202 close(errs) 203 for err := range errs { 204 assert.NoError(t, err) 205 } 206 if t.Failed() { 207 t.FailNow() 208 } 209 return make(chan error, 10) 210 } 211 212 func TestStream_Xfer(t *testing.T) { 213 ctx, cancel := testContext(t, 30*time.Second) 214 defer cancel() 215 216 id := NewConnID(ipproto.TCP, iputil.Parse("127.0.0.1"), iputil.Parse("192.168.0.1"), 1001, 8080) 217 si := uuid.New().String() 218 b := make([]byte, 0x1000) 219 for i := range b { 220 b[i] = byte(i & 0xff) 221 } 222 large := NewMessage(Normal, b) 223 errs := make(chan error, 10) 224 225 // Send data from client to server 226 t.Run("client to server", func(t *testing.T) { 227 tunnel := newBidi(10, ctx.Done()) 228 wg := sync.WaitGroup{} 229 wg.Add(2) 230 go func() { 231 defer wg.Done() 232 if client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0); err != nil { 233 errs <- err 234 } else { 235 produce(ctx, client, large, errs) 236 } 237 }() 238 go func() { 239 defer wg.Done() 240 if server, err := NewServerStream(ctx, tunnel.serverSide()); err != nil { 241 errs <- err 242 } else { 243 consume(ctx, server, b, errs) 244 } 245 }() 246 wg.Wait() 247 errs = requireNoErrs(t, errs) 248 }) 249 250 t.Run("server to client", func(t *testing.T) { 251 tunnel := newBidi(10, ctx.Done()) 252 wg := sync.WaitGroup{} 253 wg.Add(2) 254 go func() { 255 defer wg.Done() 256 if server, err := NewServerStream(ctx, tunnel.serverSide()); err != nil { 257 errs <- err 258 } else { 259 produce(ctx, server, large, errs) 260 } 261 }() 262 go func() { 263 defer wg.Done() 264 if client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0); err != nil { 265 errs <- err 266 } else { 267 consume(ctx, client, b, errs) 268 } 269 }() 270 wg.Wait() 271 errs = requireNoErrs(t, errs) 272 }) 273 274 t.Run("client to client over BidiPipe", func(t *testing.T) { 275 ta := newBidi(10, ctx.Done()) 276 tb := newBidi(10, ctx.Done()) 277 278 var counter int32 279 aCh := make(chan Stream) 280 bCh := make(chan Stream) 281 wg := sync.WaitGroup{} 282 wg.Add(5) 283 go func() { 284 defer wg.Done() 285 if s, err := NewServerStream(ctx, ta.serverSide()); err != nil { 286 errs <- err 287 close(aCh) 288 } else { 289 aCh <- s 290 } 291 }() 292 go func() { 293 defer wg.Done() 294 if s, err := NewServerStream(ctx, tb.serverSide()); err != nil { 295 errs <- err 296 close(bCh) 297 } else { 298 bCh <- s 299 } 300 }() 301 go func() { 302 defer wg.Done() 303 if server, err := NewClientStream(ctx, ta.clientSide(), id, si, 0, 0); err != nil { 304 errs <- err 305 } else { 306 produce(ctx, server, large, errs) 307 } 308 }() 309 go func() { 310 defer wg.Done() 311 if client, err := NewClientStream(ctx, tb.clientSide(), id, si, 0, 0); err != nil { 312 errs <- err 313 } else { 314 consume(ctx, client, b, errs) 315 } 316 }() 317 go func() { 318 defer wg.Done() 319 var a, b Stream 320 for a == nil || b == nil { 321 select { 322 case <-ctx.Done(): 323 errs <- ctx.Err() 324 return 325 case a = <-aCh: 326 case b = <-bCh: 327 } 328 } 329 fwd := NewBidiPipe(a, b, "pipe", &counter, nil) 330 fwd.Start(ctx) 331 select { 332 case <-ctx.Done(): 333 errs <- ctx.Err() 334 case <-fwd.Done(): 335 } 336 }() 337 wg.Wait() 338 errs = requireNoErrs(t, errs) 339 }) 340 }