github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/dialer.go

package tunnel

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/propagation"

	"github.com/datawire/dlib/dlog"
	rpc "github.com/telepresenceio/telepresence/rpc/v2/manager"
	"github.com/telepresenceio/telepresence/v2/pkg/ipproto"
)

// The connection TTL controls how long a dialer for a specific proto+from-to address combination remains
// alive without reading or writing any messages. The dialer is normally closed by one of the peers.
const (
	tcpConnTTL = 2 * time.Hour // Default tcp_keepalive_time on Linux
	udpConnTTL = 1 * time.Minute
)

const (
	notConnected = int32(iota)
	connecting
	connected
)

// streamReader is implemented by the dialer and udpListener so that they can share the
// readLoop function.
type streamReader interface {
	Idle() <-chan time.Time
	ResetIdle() bool
	Stop(context.Context)
	getStream() Stream
	reply([]byte) (int, error)
	startDisconnect(context.Context, string)
}

// The dialer takes care of dispatching messages between a gRPC stream and a TCP or UDP connection.
type dialer struct {
	TimedHandler
	stream    Stream
	cancel    context.CancelFunc
	conn      net.Conn
	connected int32
	done      chan struct{}

	ingressBytesProbe *CounterProbe
	egressBytesProbe  *CounterProbe
}

// NewDialer creates a new handler that dispatches messages in both directions between the given gRPC stream
// and the connection that it dials when started.
func NewDialer(
	stream Stream,
	cancel context.CancelFunc,
	ingressBytesProbe, egressBytesProbe *CounterProbe,
) Endpoint {
	return NewConnEndpoint(stream, nil, cancel, ingressBytesProbe, egressBytesProbe)
}

// NewDialerTTL creates a new handler that dispatches messages in both directions between the given gRPC stream
// and the connection that it dials when started. The TTL decides how long the connection can be idle before it's closed.
//
// The handler remains active until it's been idle for the ttl duration, at which time it will automatically close
// and call the release function it got from the tunnel.Pool to ensure that it gets properly released.
func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint {
	return NewConnEndpointTTL(stream, nil, cancel, ttl, ingressBytesProbe, egressBytesProbe)
}
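
// dialerTTLExample is a hypothetical sketch, not part of the original file: it shows how a
// caller holding an established Stream might attach an endpoint with an explicit idle TTL
// instead of the protocol-derived default. The stream is assumed to exist already.
func dialerTTLExample(ctx context.Context, stream Stream) {
	ctx, cancel := context.WithCancel(ctx)
	// Use the one-minute idle TTL that this file applies to UDP connections.
	ep := NewDialerTTL(stream, cancel, udpConnTTL, nil, nil)
	ep.Start(ctx)
	<-ep.Done() // blocks until a peer closes or the endpoint has been idle for the TTL
}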

func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint {
	ttl := tcpConnTTL
	if stream.ID().Protocol() == ipproto.UDP {
		ttl = udpConnTTL
	}
	return NewConnEndpointTTL(stream, conn, cancel, ttl, ingressBytesProbe, egressBytesProbe)
}

func NewConnEndpointTTL(
	stream Stream,
	conn net.Conn,
	cancel context.CancelFunc,
	ttl time.Duration,
	ingressBytesProbe, egressBytesProbe *CounterProbe,
) Endpoint {
	state := notConnected
	if conn != nil {
		state = connecting
	}
	return &dialer{
		TimedHandler: NewTimedHandler(stream.ID(), ttl, nil),
		stream:       stream,
		cancel:       cancel,
		conn:         conn,
		connected:    state,
		done:         make(chan struct{}),

		ingressBytesProbe: ingressBytesProbe,
		egressBytesProbe:  egressBytesProbe,
	}
}

func (h *dialer) Start(ctx context.Context) {
	go func() {
		ctx, span := otel.Tracer("").Start(ctx, "dialer")
		defer span.End()
		defer close(h.done)

		id := h.stream.ID()
		id.SpanRecord(span)

		switch h.connected {
		case notConnected:
			// No connection was provided to the constructor, so mark the handler as
			// connecting and dial the destination.
			h.connected = connecting

			dlog.Tracef(ctx, "   CONN %s, dialing", id)
			d := net.Dialer{Timeout: h.stream.DialTimeout()}
			conn, err := d.DialContext(ctx, id.DestinationProtocolString(), id.DestinationAddr().String())
			if err != nil {
				dlog.Errorf(ctx, "!! CONN %s, failed to establish connection: %v", id, err)
				span.SetStatus(codes.Error, err.Error())
				if err = h.stream.Send(ctx, NewMessage(DialReject, nil)); err != nil {
					dlog.Errorf(ctx, "!! CONN %s, failed to send DialReject: %v", id, err)
				}
				if err = h.stream.CloseSend(ctx); err != nil {
					dlog.Errorf(ctx, "!! CONN %s, stream.CloseSend failed: %v", id, err)
				}
				h.connected = notConnected
				return
			}
			if err = h.stream.Send(ctx, NewMessage(DialOK, nil)); err != nil {
				_ = conn.Close()
				dlog.Errorf(ctx, "!! CONN %s, failed to send DialOK: %v", id, err)
				span.SetStatus(codes.Error, err.Error())
				return
			}
			dlog.Tracef(ctx, "   CONN %s, dial answered", id)
			h.conn = conn

		case connecting:
			// A connection was provided to the constructor; there is nothing to dial.
		default:
			dlog.Errorf(ctx, "!! CONN %s, start called in invalid state", id)
			return
		}

		// Set up the idle timer to close and release this endpoint when it's been idle for a while.
		h.TimedHandler.Start(ctx)
		h.connected = connected

		wg := sync.WaitGroup{}
		wg.Add(2)
		go h.connToStreamLoop(ctx, &wg)
		go h.streamToConnLoop(ctx, &wg)
		wg.Wait()
		h.Stop(ctx)
	}()
}
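
// pumpExample is an illustrative, self-contained sketch, not part of the original file, of the
// bidirectional pump pattern that Start sets up above: one goroutine per direction, with a
// WaitGroup gating the final cleanup until both directions have drained.
func pumpExample(a, b net.Conn) {
	wg := sync.WaitGroup{}
	wg.Add(2)
	go func() { defer wg.Done(); _, _ = io.Copy(a, b) }() // pump b -> a
	go func() { defer wg.Done(); _, _ = io.Copy(b, a) }() // pump a -> b
	wg.Wait() // both directions are done; it is now safe to close
	_ = a.Close()
	_ = b.Close()
}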

func (h *dialer) Done() <-chan struct{} {
	return h.done
}

// Stop will close the underlying TCP/UDP connection.
func (h *dialer) Stop(ctx context.Context) {
	h.startDisconnect(ctx, "explicit close")
	h.cancel()
}

func (h *dialer) startDisconnect(ctx context.Context, reason string) {
	if !atomic.CompareAndSwapInt32(&h.connected, connected, notConnected) {
		return
	}
	id := h.stream.ID()
	dlog.Tracef(ctx, "   CONN %s closing connection: %s", id, reason)
	if err := h.conn.Close(); err != nil {
		dlog.Tracef(ctx, "!! CONN %s, Close failed: %v", id, err)
	}
}
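
// casOnceExample is a minimal sketch, not part of the original file, of the compare-and-swap
// idiom that startDisconnect uses above: both pump loops may race to close the connection, but
// only the caller that flips connected -> notConnected performs the actual close.
func casOnceExample() {
	state := connected
	closeOnce := func() bool {
		return atomic.CompareAndSwapInt32(&state, connected, notConnected)
	}
	_ = closeOnce() // true: this caller owns the close and may release resources
	_ = closeOnce() // false: the state was already notConnected, so this is a no-op
}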

func (h *dialer) connToStreamLoop(ctx context.Context, wg *sync.WaitGroup) {
	var endReason string
	endLevel := dlog.LogLevelTrace
	id := h.stream.ID()

	outgoing := make(chan Message, 50)
	defer func() {
		if !h.ResetIdle() {
			// Hard close of the peer. We don't want any more data.
			select {
			case outgoing <- NewMessage(Disconnect, nil):
			default:
			}
		}
		close(outgoing)
		dlog.Logf(ctx, endLevel, "   CONN %s conn-to-stream loop ended because %s", id, endReason)
		wg.Done()
	}()

	wg.Add(1)
	WriteLoop(ctx, h.stream, outgoing, wg, h.egressBytesProbe)

	buf := make([]byte, 0x100000)
	dlog.Tracef(ctx, "   CONN %s conn-to-stream loop started", id)
	for {
		n, err := h.conn.Read(buf)
		if n > 0 {
			dlog.Tracef(ctx, "<- CONN %s, len %d", id, n)
			select {
			case <-ctx.Done():
				endReason = ctx.Err().Error()
				return
			case outgoing <- NewMessage(Normal, buf[:n]):
			}
		}

		if err != nil {
			switch {
			case errors.Is(err, io.EOF):
				endReason = "EOF was encountered"
			case errors.Is(err, net.ErrClosed):
				endReason = "the connection was closed"
				h.startDisconnect(ctx, endReason)
			default:
				endReason = fmt.Sprintf("a read error occurred: %v", err)
				endLevel = dlog.LogLevelError
			}
			return
		}

		if !h.ResetIdle() {
			endReason = "it was idle for too long"
			return
		}
	}
}
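
// nonBlockingSendExample is an illustrative sketch, not part of the original file, of the
// select/default idiom used in the deferred cleanup above: try to enqueue a final Disconnect
// message, but never block on a full channel during shutdown.
func nonBlockingSendExample(outgoing chan<- Message) {
	select {
	case outgoing <- NewMessage(Disconnect, nil):
		// The Disconnect was enqueued and will be flushed by the write loop.
	default:
		// The channel is full; drop the message rather than stall the cleanup.
	}
}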

func (h *dialer) getStream() Stream {
	return h.stream
}

// reply writes data to the underlying TCP/UDP connection.
func (h *dialer) reply(data []byte) (int, error) {
	return h.conn.Write(data)
}

func (h *dialer) streamToConnLoop(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()
	readLoop(ctx, h, h.ingressBytesProbe)
}

func handleControl(ctx context.Context, h streamReader, cm Message) {
	switch cm.Code() {
	case DialReject, Disconnect: // Peer wants to hard-close. No more messages will arrive
		h.Stop(ctx)
	case KeepAlive:
		h.ResetIdle()
	case DialOK:
		// So how can a dialer get a DialOK from a peer? Surely, there cannot be a dialer at both ends?
		// Well, the story goes like this:
		// 1. A request to the service is made on the workstation.
		// 2. This agent's listener receives a connection.
		// 3. Since an intercept is active, the agent creates a tunnel to the workstation.
		// 4. A new dialer is attached to that tunnel (reused as a tunnel endpoint).
		// 5. The dialer at the workstation dials and responds with DialOK, and here we are.
	default:
		dlog.Errorf(ctx, "!! CONN %s: unhandled connection control message: %s", h.getStream().ID(), cm)
	}
}

func readLoop(ctx context.Context, h streamReader, trafficProbe *CounterProbe) {
	var endReason string
	endLevel := dlog.LogLevelTrace
	id := h.getStream().ID()
	defer func() {
		h.startDisconnect(ctx, endReason)
		dlog.Logf(ctx, endLevel, "   CONN %s stream-to-conn loop ended because %s", id, endReason)
	}()

	incoming, errCh := ReadLoop(ctx, h.getStream(), trafficProbe)
	dlog.Tracef(ctx, "   CONN %s stream-to-conn loop started", id)
	for {
		select {
		case <-ctx.Done():
			endReason = ctx.Err().Error()
			return
		case <-h.Idle():
			endReason = "it was idle for too long"
			return
		case err, ok := <-errCh:
			if ok {
				dlog.Error(ctx, err)
			} else {
				// The error channel is closed. Stop selecting on it so that this loop
				// doesn't spin until the incoming channel is drained.
				errCh = nil
			}
		case dg, ok := <-incoming:
			if !ok {
				// The incoming channel was closed by the reader and is now drained.
				endReason = "there was no more input"
				return
			}
			if !h.ResetIdle() {
				endReason = "it was idle for too long"
				return
			}
			if dg.Code() != Normal {
				handleControl(ctx, h, dg)
				continue
			}
			payload := dg.Payload()
			pn := len(payload)
			for n := 0; n < pn; {
				wn, err := h.reply(payload[n:])
				if err != nil {
					endReason = fmt.Sprintf("a write error occurred: %v", err)
					endLevel = dlog.LogLevelError
					return
				}
				dlog.Tracef(ctx, "-> CONN %s, len %d", id, wn)
				n += wn
			}
		}
	}
}

// DialWaitLoop reads from the given dialStream. For each request that arrives, it spawns a new
// goroutine that creates a Tunnel to the manager and attaches a dialer Endpoint to that tunnel.
// The function blocks until the dialStream is closed.
func DialWaitLoop(
	ctx context.Context,
	tunnelProvider Provider,
	dialStream rpc.Manager_WatchDialClient,
	sessionID string,
) error {
	// Create a cancellable context so that leftover dialRespond goroutines are cleaned up if
	// the wait loop dies.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	for ctx.Err() == nil {
		dr, err := dialStream.Recv()
		if err != nil {
			// An io.EOF or net.ErrClosed just means that the stream ended normally.
			if ctx.Err() == nil && !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
				return fmt.Errorf("dial request stream recv: %w", err)
			}
			return nil
		}
		go dialRespond(ctx, tunnelProvider, dr, sessionID)
	}
	return nil
}
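
// watchDialExample is a hypothetical sketch, not part of the original file, showing how a
// client might obtain the dial stream and drive DialWaitLoop. It assumes the generated
// rpc.ManagerClient interface and that its WatchDial method accepts the session info.
func watchDialExample(ctx context.Context, client rpc.ManagerClient, provider Provider, sessionID string) error {
	dialStream, err := client.WatchDial(ctx, &rpc.SessionInfo{SessionId: sessionID})
	if err != nil {
		return err
	}
	// Blocks until the dial stream closes; each received DialRequest is served by dialRespond.
	return DialWaitLoop(ctx, provider, dialStream, sessionID)
}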

func dialRespond(ctx context.Context, tunnelProvider Provider, dr *rpc.DialRequest, sessionID string) {
	if tc := dr.GetTraceContext(); tc != nil {
		carrier := propagation.MapCarrier(tc)
		propagator := otel.GetTextMapPropagator()
		ctx = propagator.Extract(ctx, carrier)
	}
	ctx, span := otel.Tracer("").Start(ctx, "dialRespond")
	defer span.End()
	id := ConnID(dr.ConnId)
	id.SpanRecord(span)
	mt, err := tunnelProvider.Tunnel(ctx)
	if err != nil {
		dlog.Errorf(ctx, "!! CONN %s, call to manager Tunnel failed: %v", id, err)
		return
	}
	ctx, cancel := context.WithCancel(ctx)
	s, err := NewClientStream(ctx, mt, id, sessionID, time.Duration(dr.RoundtripLatency), time.Duration(dr.DialTimeout))
	if err != nil {
		dlog.Error(ctx, err)
		cancel()
		return
	}
	d := NewDialer(s, cancel, nil, nil)
	d.Start(ctx)
	<-d.Done()
}