github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/stream_test.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/google/uuid"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/datawire/dlib/dlog"
    18  	"github.com/telepresenceio/telepresence/rpc/v2/manager"
    19  	"github.com/telepresenceio/telepresence/v2/pkg/ipproto"
    20  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    21  	"github.com/telepresenceio/telepresence/v2/pkg/log"
    22  )
    23  
    24  type uni struct {
    25  	done <-chan struct{}
    26  	ch   chan *manager.TunnelMessage
    27  }
    28  
    29  type bidi struct {
    30  	cToS *uni
    31  	sToC *uni
    32  }
    33  
    34  func newUni(bufSize int, done <-chan struct{}) *uni {
    35  	return &uni{ch: make(chan *manager.TunnelMessage, bufSize), done: done}
    36  }
    37  
    38  func newBidi(bufSize int, done <-chan struct{}) *bidi {
    39  	return &bidi{cToS: newUni(bufSize, done), sToC: newUni(bufSize, done)}
    40  }
    41  
    42  func (t *uni) recv() (*manager.TunnelMessage, error) {
    43  	select {
    44  	case <-t.done:
    45  		return nil, context.Canceled
    46  	case m := <-t.ch:
    47  		if m == nil {
    48  			return nil, net.ErrClosed
    49  		}
    50  		// Simulate a network latency of one microsecond per byte
    51  		time.Sleep(time.Duration(len(m.Payload)) * time.Microsecond)
    52  		return m, nil
    53  	}
    54  }
    55  
    56  func (t *uni) send(msg *manager.TunnelMessage) error {
    57  	select {
    58  	case <-t.done:
    59  		return context.Canceled
    60  	case t.ch <- msg:
    61  		return nil
    62  	}
    63  }
    64  
    65  func (t *uni) close() error {
    66  	close(t.ch)
    67  	return nil
    68  }
    69  
    70  func (t *bidi) clientSide() GRPCClientStream {
    71  	return &clientSide{t}
    72  }
    73  
    74  func (t *bidi) serverSide() GRPCStream {
    75  	return &serverSide{t}
    76  }
    77  
    78  type clientSide struct {
    79  	*bidi
    80  }
    81  
    82  func (c *clientSide) Recv() (*manager.TunnelMessage, error) {
    83  	return c.sToC.recv()
    84  }
    85  
    86  func (c *clientSide) Send(msg *manager.TunnelMessage) error {
    87  	return c.cToS.send(msg)
    88  }
    89  
    90  func (c *clientSide) CloseSend() error {
    91  	return c.cToS.close()
    92  }
    93  
    94  type serverSide struct {
    95  	*bidi
    96  }
    97  
    98  func (c *serverSide) Recv() (*manager.TunnelMessage, error) {
    99  	return c.cToS.recv()
   100  }
   101  
   102  func (c *serverSide) Send(msg *manager.TunnelMessage) error {
   103  	return c.sToC.send(msg)
   104  }
   105  
   106  func testContext(t *testing.T, timeout time.Duration) (context.Context, context.CancelFunc) {
   107  	return context.WithTimeout(dlog.WithLogger(context.Background(), log.NewTestLogger(t, dlog.LogLevelDebug)), timeout)
   108  }
   109  
   110  func TestStream_Connect(t *testing.T) {
   111  	ctx, cancel := testContext(t, time.Second)
   112  	defer cancel()
   113  
   114  	tunnel := newBidi(10, ctx.Done())
   115  	id := NewConnID(ipproto.TCP, iputil.Parse("127.0.0.1"), iputil.Parse("192.168.0.1"), 1001, 8080)
   116  	si := uuid.New().String()
   117  
   118  	wg := sync.WaitGroup{}
   119  	wg.Add(2)
   120  	go func() {
   121  		defer wg.Done()
   122  		client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0)
   123  		require.NoError(t, err)
   124  		assert.Equal(t, Version, client.PeerVersion())
   125  		assert.NoError(t, client.CloseSend(ctx))
   126  	}()
   127  
   128  	go func() {
   129  		defer wg.Done()
   130  		server, err := NewServerStream(ctx, tunnel.serverSide())
   131  		require.NoError(t, err)
   132  		assert.Equal(t, id, server.ID())
   133  		assert.Equal(t, Version, server.PeerVersion())
   134  		assert.Equal(t, si, server.SessionID())
   135  	}()
   136  	wg.Wait()
   137  }
   138  
   139  func produce(ctx context.Context, s Stream, msg Message, errs chan<- error) {
   140  	wrCh := make(chan Message)
   141  	wg := sync.WaitGroup{}
   142  	wg.Add(1)
   143  	WriteLoop(ctx, s, wrCh, &wg, nil)
   144  	go func() {
   145  		for i := 0; i < 100; i++ {
   146  			wrCh <- msg
   147  		}
   148  		close(wrCh)
   149  		wg.Wait()
   150  	}()
   151  
   152  	rdCh, errCh := ReadLoop(ctx, s, nil)
   153  	select {
   154  	case <-ctx.Done():
   155  		errs <- ctx.Err()
   156  	case err, ok := <-errCh:
   157  		if ok {
   158  			errs <- err
   159  		}
   160  	case m, ok := <-rdCh:
   161  		if ok {
   162  			errs <- fmt.Errorf("unexpected message: %s", m)
   163  		}
   164  	}
   165  }
   166  
   167  func consume(ctx context.Context, s Stream, expectedPayload []byte, errs chan<- error) {
   168  	count := 0
   169  	wrCh := make(chan Message)
   170  	wg := sync.WaitGroup{}
   171  	wg.Add(1)
   172  	WriteLoop(ctx, s, wrCh, &wg, nil)
   173  	defer close(wrCh)
   174  	rdCh, errCh := ReadLoop(ctx, s, nil)
   175  	for {
   176  		select {
   177  		case <-ctx.Done():
   178  			errs <- ctx.Err()
   179  		case err, ok := <-errCh:
   180  			if ok {
   181  				errs <- err
   182  			}
   183  		case m, ok := <-rdCh:
   184  			if !ok {
   185  				return
   186  			}
   187  			if m.Code() != Normal {
   188  				errs <- fmt.Errorf("unexpected message code %s", m.Code())
   189  				return
   190  			}
   191  			if !bytes.Equal(expectedPayload, m.Payload()) {
   192  				errs <- errors.New("unexpected message content")
   193  				return
   194  			}
   195  			count++
   196  		}
   197  	}
   198  }
   199  
   200  func requireNoErrs(t *testing.T, errs chan error) chan error {
   201  	t.Helper()
   202  	close(errs)
   203  	for err := range errs {
   204  		assert.NoError(t, err)
   205  	}
   206  	if t.Failed() {
   207  		t.FailNow()
   208  	}
   209  	return make(chan error, 10)
   210  }
   211  
   212  func TestStream_Xfer(t *testing.T) {
   213  	ctx, cancel := testContext(t, 30*time.Second)
   214  	defer cancel()
   215  
   216  	id := NewConnID(ipproto.TCP, iputil.Parse("127.0.0.1"), iputil.Parse("192.168.0.1"), 1001, 8080)
   217  	si := uuid.New().String()
   218  	b := make([]byte, 0x1000)
   219  	for i := range b {
   220  		b[i] = byte(i & 0xff)
   221  	}
   222  	large := NewMessage(Normal, b)
   223  	errs := make(chan error, 10)
   224  
   225  	// Send data from client to server
   226  	t.Run("client to server", func(t *testing.T) {
   227  		tunnel := newBidi(10, ctx.Done())
   228  		wg := sync.WaitGroup{}
   229  		wg.Add(2)
   230  		go func() {
   231  			defer wg.Done()
   232  			if client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0); err != nil {
   233  				errs <- err
   234  			} else {
   235  				produce(ctx, client, large, errs)
   236  			}
   237  		}()
   238  		go func() {
   239  			defer wg.Done()
   240  			if server, err := NewServerStream(ctx, tunnel.serverSide()); err != nil {
   241  				errs <- err
   242  			} else {
   243  				consume(ctx, server, b, errs)
   244  			}
   245  		}()
   246  		wg.Wait()
   247  		errs = requireNoErrs(t, errs)
   248  	})
   249  
   250  	t.Run("server to client", func(t *testing.T) {
   251  		tunnel := newBidi(10, ctx.Done())
   252  		wg := sync.WaitGroup{}
   253  		wg.Add(2)
   254  		go func() {
   255  			defer wg.Done()
   256  			if server, err := NewServerStream(ctx, tunnel.serverSide()); err != nil {
   257  				errs <- err
   258  			} else {
   259  				produce(ctx, server, large, errs)
   260  			}
   261  		}()
   262  		go func() {
   263  			defer wg.Done()
   264  			if client, err := NewClientStream(ctx, tunnel.clientSide(), id, si, 0, 0); err != nil {
   265  				errs <- err
   266  			} else {
   267  				consume(ctx, client, b, errs)
   268  			}
   269  		}()
   270  		wg.Wait()
   271  		errs = requireNoErrs(t, errs)
   272  	})
   273  
   274  	t.Run("client to client over BidiPipe", func(t *testing.T) {
   275  		ta := newBidi(10, ctx.Done())
   276  		tb := newBidi(10, ctx.Done())
   277  
   278  		var counter int32
   279  		aCh := make(chan Stream)
   280  		bCh := make(chan Stream)
   281  		wg := sync.WaitGroup{}
   282  		wg.Add(5)
   283  		go func() {
   284  			defer wg.Done()
   285  			if s, err := NewServerStream(ctx, ta.serverSide()); err != nil {
   286  				errs <- err
   287  				close(aCh)
   288  			} else {
   289  				aCh <- s
   290  			}
   291  		}()
   292  		go func() {
   293  			defer wg.Done()
   294  			if s, err := NewServerStream(ctx, tb.serverSide()); err != nil {
   295  				errs <- err
   296  				close(bCh)
   297  			} else {
   298  				bCh <- s
   299  			}
   300  		}()
   301  		go func() {
   302  			defer wg.Done()
   303  			if server, err := NewClientStream(ctx, ta.clientSide(), id, si, 0, 0); err != nil {
   304  				errs <- err
   305  			} else {
   306  				produce(ctx, server, large, errs)
   307  			}
   308  		}()
   309  		go func() {
   310  			defer wg.Done()
   311  			if client, err := NewClientStream(ctx, tb.clientSide(), id, si, 0, 0); err != nil {
   312  				errs <- err
   313  			} else {
   314  				consume(ctx, client, b, errs)
   315  			}
   316  		}()
   317  		go func() {
   318  			defer wg.Done()
   319  			var a, b Stream
   320  			for a == nil || b == nil {
   321  				select {
   322  				case <-ctx.Done():
   323  					errs <- ctx.Err()
   324  					return
   325  				case a = <-aCh:
   326  				case b = <-bCh:
   327  				}
   328  			}
   329  			fwd := NewBidiPipe(a, b, "pipe", &counter, nil)
   330  			fwd.Start(ctx)
   331  			select {
   332  			case <-ctx.Done():
   333  				errs <- ctx.Err()
   334  			case <-fwd.Done():
   335  			}
   336  		}()
   337  		wg.Wait()
   338  		errs = requireNoErrs(t, errs)
   339  	})
   340  }