github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/stream.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  	"time"
    11  
    12  	"go.opentelemetry.io/otel"
    13  	"google.golang.org/grpc/codes"
    14  	"google.golang.org/grpc/status"
    15  
    16  	"github.com/datawire/dlib/dlog"
    17  	rpc "github.com/telepresenceio/telepresence/rpc/v2/manager"
    18  )
    19  
    20  // Version
    21  //
    22  //	0 which didn't report versions and didn't do synchronization
    23  //	1 used MuxTunnel instead of one tunnel per connection.
    24  const Version = uint16(2)
    25  
    26  // Endpoint is an endpoint for a Stream such as a Dialer or a bidirectional pipe.
    27  type Endpoint interface {
    28  	Start(ctx context.Context)
    29  	Done() <-chan struct{}
    30  }
    31  
    32  // GRPCStream is the bare minimum needed for reading and writing TunnelMessages
    33  // on a Manager_TunnelServer or Manager_TunnelClient.
    34  type GRPCStream interface {
    35  	Recv() (*rpc.TunnelMessage, error)
    36  	Send(*rpc.TunnelMessage) error
    37  }
    38  
    39  // The Stream interface represents a bidirectional, synchronized connection Tunnel
    40  // that sends TCP or UDP traffic over gRPC using manager.TunnelMessage messages.
    41  //
    42  // A Stream is closed by one of six things happening at either end (or at both ends).
    43  //
    44  //  1. Read from local connection fails (typically EOF)
    45  //  2. Write to local connection fails (connection peer closed)
    46  //  3. Idle timer timed out.
    47  //  4. Context is cancelled.
    48  //  5. closeSend request received from Tunnel peer.
    49  //  6. Disconnect received from Tunnel peer.
    50  //
    51  // When #1 or #2 happens, the Stream will either call CloseSend() (if it's a client Stream)
    52  // or send a closeSend request (if it's a StreamServer) to its Stream peer, shorten the
    53  // Idle timer, and then continue to serve incoming data from the Stream peer until it's
    54  // closed or a Disconnect is received. Once that happens, it's guaranteed that the Tunnel
    55  // peer will send no more messages and the Stream is closed.
    56  //
    57  // When #3, #4, or #5 happens, the Tunnel will send a Disconnect to its Stream peer and close.
    58  //
    59  // When #6 happens, the Stream will simply close.
    60  type Stream interface {
    61  	Tag() string
    62  	ID() ConnID
    63  	Receive(context.Context) (Message, error)
    64  	Send(context.Context, Message) error
    65  	CloseSend(ctx context.Context) error
    66  	PeerVersion() uint16
    67  	SessionID() string
    68  	DialTimeout() time.Duration
    69  	RoundtripLatency() time.Duration
    70  }
    71  
    72  // StreamCreator is a function that creats a Stream.
    73  type StreamCreator func(context.Context, ConnID) (Stream, error)
    74  
    75  // ReadLoop reads from the Stream and dispatches messages and error to the give channels. There
    76  // will be max one error since the error also terminates the loop.
    77  func ReadLoop(ctx context.Context, s Stream, p *CounterProbe) (<-chan Message, <-chan error) {
    78  	msgCh := make(chan Message, 50)
    79  	errCh := make(chan error, 1) // Max one message will be sent on this channel
    80  	dlog.Tracef(ctx, "   %s %s, ReadLoop starting", s.Tag(), s.ID())
    81  	go func() {
    82  		ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "ReadLoop")
    83  		defer span.End()
    84  		s.ID().SpanRecord(span)
    85  		var endReason string
    86  		defer func() {
    87  			close(errCh)
    88  			close(msgCh)
    89  			dlog.Tracef(ctx, "   %s %s, ReadLoop ended: %s", s.Tag(), s.ID(), endReason)
    90  		}()
    91  
    92  		for {
    93  			m, err := s.Receive(ctx)
    94  			if m != nil && p != nil {
    95  				p.Increment(uint64(len(m.Payload())))
    96  			}
    97  
    98  			switch {
    99  			case err == nil:
   100  				select {
   101  				case <-ctx.Done():
   102  					endReason = ctx.Err().Error()
   103  				case msgCh <- m:
   104  					continue
   105  				}
   106  			case ctx.Err() != nil:
   107  				endReason = ctx.Err().Error()
   108  			case errors.Is(err, io.EOF):
   109  				endReason = "EOF on input"
   110  			case errors.Is(err, net.ErrClosed):
   111  				endReason = "stream closed"
   112  			case errors.Is(err, context.Canceled), status.Code(err) == codes.Canceled:
   113  				endReason = err.Error()
   114  			default:
   115  				endReason = err.Error()
   116  				select {
   117  				case errCh <- fmt.Errorf("!! %s %s, read from grpc.ClientStream failed: %w", s.Tag(), s.ID(), err):
   118  				default:
   119  				}
   120  			}
   121  			break
   122  		}
   123  	}()
   124  	return msgCh, errCh
   125  }
   126  
   127  // WriteLoop reads messages from the channel and writes them to the Stream. It will call CloseSend() on the
   128  // stream when the channel is closed.
   129  func WriteLoop(
   130  	ctx context.Context,
   131  	s Stream, msgCh <-chan Message,
   132  	wg *sync.WaitGroup,
   133  	p *CounterProbe,
   134  ) {
   135  	dlog.Tracef(ctx, "   %s %s, WriteLoop starting", s.Tag(), s.ID())
   136  	go func() {
   137  		ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "WriteLoop")
   138  		defer span.End()
   139  		s.ID().SpanRecord(span)
   140  		var endReason string
   141  		defer func() {
   142  			dlog.Tracef(ctx, "   %s %s, WriteLoop ended: %s", s.Tag(), s.ID(), endReason)
   143  			if err := s.CloseSend(ctx); err != nil {
   144  				dlog.Errorf(ctx, "!! %s %s, Send of closeSend failed: %v", s.Tag(), s.ID(), err)
   145  			}
   146  			wg.Done()
   147  		}()
   148  		for {
   149  			select {
   150  			case <-ctx.Done():
   151  				endReason = ctx.Err().Error()
   152  			case m, ok := <-msgCh:
   153  				if !ok {
   154  					endReason = "input channel is closed"
   155  					break
   156  				}
   157  
   158  				err := s.Send(ctx, m)
   159  				if m != nil && p != nil {
   160  					p.Increment(uint64(len(m.Payload())))
   161  				}
   162  
   163  				switch {
   164  				case err == nil:
   165  					continue
   166  				case errors.Is(err, net.ErrClosed):
   167  					endReason = "output stream is closed"
   168  				default:
   169  					endReason = err.Error()
   170  					dlog.Errorf(ctx, "!! %s %s, Send failed: %v", s.Tag(), s.ID(), err)
   171  				}
   172  			}
   173  			break
   174  		}
   175  	}()
   176  }
   177  
   178  type stream struct {
   179  	grpcStream       GRPCStream
   180  	id               ConnID
   181  	dialTimeout      time.Duration
   182  	roundtripLatency time.Duration
   183  	sessionID        string
   184  	tag              string
   185  	syncRatio        uint32 // send and check sync after each syncRatio message
   186  	ackWindow        uint32 // maximum permitted difference between sent and received ack
   187  	peerVersion      uint16
   188  }
   189  
   190  func newStream(tag string, grpcStream GRPCStream) stream {
   191  	return stream{tag: tag, grpcStream: grpcStream, syncRatio: 8, ackWindow: 1}
   192  }
   193  
   194  func (s *stream) Tag() string {
   195  	return s.tag
   196  }
   197  
   198  func (s *stream) ID() ConnID {
   199  	return s.id
   200  }
   201  
   202  func (s *stream) PeerVersion() uint16 {
   203  	return s.peerVersion
   204  }
   205  
   206  func (s *stream) DialTimeout() time.Duration {
   207  	return s.dialTimeout
   208  }
   209  
   210  func (s *stream) RoundtripLatency() time.Duration {
   211  	return s.roundtripLatency
   212  }
   213  
   214  func (s *stream) SessionID() string {
   215  	return s.sessionID
   216  }
   217  
   218  func (s *stream) Receive(ctx context.Context) (Message, error) {
   219  	cm, err := s.grpcStream.Recv()
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	m := msg(cm.Payload)
   224  	switch m.Code() {
   225  	case closeSend:
   226  		dlog.Tracef(ctx, "<- %s %s, close send", s.tag, s.id)
   227  		return nil, net.ErrClosed
   228  	case streamInfo:
   229  		dlog.Tracef(ctx, "<- %s, %s", s.tag, m)
   230  	default:
   231  		dlog.Tracef(ctx, "<- %s %s, %s", s.tag, s.id, m)
   232  	}
   233  	return m, nil
   234  }
   235  
   236  func (s *stream) Send(ctx context.Context, m Message) error {
   237  	if err := s.grpcStream.Send(m.TunnelMessage()); err != nil {
   238  		if ctx.Err() == nil && !errors.Is(err, net.ErrClosed) {
   239  			dlog.Errorf(ctx, "!! %s %s, Send failed: %v", s.tag, s.id, err)
   240  		}
   241  		return err
   242  	}
   243  	dlog.Tracef(ctx, "-> %s %s, %s", s.tag, s.id, m)
   244  	return nil
   245  }
   246  
   247  func (s *stream) CloseSend(ctx context.Context) error {
   248  	if err := s.Send(ctx, NewMessage(closeSend, nil)); err != nil {
   249  		if ctx.Err() == nil && !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
   250  			return fmt.Errorf("send of closeSend message failed: %w", err)
   251  		}
   252  	}
   253  	return nil
   254  }