github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/userd/daemon/mgrproxy.go (about)

     1  package daemon
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  
    10  	"google.golang.org/grpc"
    11  	"google.golang.org/grpc/codes"
    12  	"google.golang.org/grpc/status"
    13  	"google.golang.org/protobuf/types/known/emptypb"
    14  
    15  	"github.com/datawire/dlib/dlog"
    16  	"github.com/telepresenceio/telepresence/rpc/v2/connector"
    17  	"github.com/telepresenceio/telepresence/rpc/v2/manager"
    18  )
    19  
    20  // mgrProxy implements connector.ManagerProxyServer, but just proxies all requests through a manager.ManagerClient.
    21  type mgrProxy struct {
    22  	sync.RWMutex
    23  	clientX      manager.ManagerClient
    24  	callOptionsX []grpc.CallOption
    25  
    26  	connector.UnsafeManagerProxyServer
    27  }
    28  
    29  var _ connector.ManagerProxyServer = &mgrProxy{}
    30  
    31  func (p *mgrProxy) setClient(client manager.ManagerClient, callOptions ...grpc.CallOption) {
    32  	p.Lock()
    33  	p.clientX = client
    34  	p.callOptionsX = callOptions
    35  	p.Unlock()
    36  }
    37  
    38  func (p *mgrProxy) get() (manager.ManagerClient, []grpc.CallOption, error) {
    39  	p.RLock()
    40  	defer p.RUnlock()
    41  	if p.clientX == nil {
    42  		return nil, nil, status.Error(codes.Unavailable, "telepresence: the userd is not connected to the manager")
    43  	}
    44  	return p.clientX, p.callOptionsX, nil
    45  }
    46  
    47  func (p *mgrProxy) Version(ctx context.Context, arg *emptypb.Empty) (*manager.VersionInfo2, error) {
    48  	client, callOptions, err := p.get()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	return client.Version(ctx, arg, callOptions...)
    53  }
    54  
    55  func (p *mgrProxy) GetClientConfig(ctx context.Context, arg *emptypb.Empty) (*manager.CLIConfig, error) {
    56  	client, callOptions, err := p.get()
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	return client.GetClientConfig(ctx, arg, callOptions...)
    61  }
    62  
    63  type tmReceiver interface {
    64  	Recv() (*manager.TunnelMessage, error)
    65  }
    66  
    67  type tmSender interface {
    68  	Send(*manager.TunnelMessage) error
    69  }
    70  
    71  func recvLoop(ctx context.Context, who string, in tmReceiver, out chan<- *manager.TunnelMessage, wg *sync.WaitGroup) {
    72  	defer func() {
    73  		dlog.Tracef(ctx, "%s Recv loop ended", who)
    74  		close(out)
    75  		wg.Done()
    76  	}()
    77  	dlog.Tracef(ctx, "%s Recv loop started", who)
    78  	for {
    79  		payload, err := in.Recv()
    80  		if err != nil {
    81  			if ctx.Err() == nil && !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
    82  				dlog.Errorf(ctx, "Tunnel %s.Recv() failed: %v", who, err)
    83  			}
    84  			return
    85  		}
    86  		dlog.Tracef(ctx, "<- %s %d", who, len(payload.Payload))
    87  		select {
    88  		case <-ctx.Done():
    89  			return
    90  		case out <- payload:
    91  		}
    92  	}
    93  }
    94  
    95  func sendLoop(ctx context.Context, who string, out tmSender, in <-chan *manager.TunnelMessage, wg *sync.WaitGroup) {
    96  	defer func() {
    97  		dlog.Tracef(ctx, "%s Send loop ended", who)
    98  		wg.Done()
    99  	}()
   100  	dlog.Tracef(ctx, "%s Send loop started", who)
   101  	if outC, ok := out.(interface{ CloseSend() error }); ok {
   102  		defer func() {
   103  			if err := outC.CloseSend(); err != nil {
   104  				dlog.Errorf(ctx, "CloseSend() failed: %v", err)
   105  			}
   106  		}()
   107  	}
   108  	for {
   109  		select {
   110  		case <-ctx.Done():
   111  			return
   112  		case payload, ok := <-in:
   113  			if !ok {
   114  				return
   115  			}
   116  			if err := out.Send(payload); err != nil {
   117  				if !errors.Is(err, net.ErrClosed) {
   118  					dlog.Errorf(ctx, "Tunnel %s.Send() failed: %v", who, err)
   119  				}
   120  				return
   121  			}
   122  			dlog.Tracef(ctx, "-> %s %d", who, len(payload.Payload))
   123  		}
   124  	}
   125  }
   126  
   127  func (p *mgrProxy) Tunnel(fhClient connector.ManagerProxy_TunnelServer) error {
   128  	client, callOptions, err := p.get()
   129  	if err != nil {
   130  		return err
   131  	}
   132  	ctx := fhClient.Context()
   133  	fhManager, err := client.Tunnel(ctx, callOptions...)
   134  	if err != nil {
   135  		return err
   136  	}
   137  	mgrToClient := make(chan *manager.TunnelMessage)
   138  	clientToMgr := make(chan *manager.TunnelMessage)
   139  
   140  	wg := sync.WaitGroup{}
   141  	wg.Add(4)
   142  	go recvLoop(ctx, "manager", fhManager, mgrToClient, &wg)
   143  	go sendLoop(ctx, "manager", fhManager, clientToMgr, &wg)
   144  	go recvLoop(ctx, "client", fhClient, clientToMgr, &wg)
   145  	go sendLoop(ctx, "client", fhClient, mgrToClient, &wg)
   146  	wg.Wait()
   147  	return nil
   148  }
   149  
   150  func (p *mgrProxy) EnsureAgent(ctx context.Context, arg *manager.EnsureAgentRequest) (*emptypb.Empty, error) {
   151  	client, callOptions, err := p.get()
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	return client.EnsureAgent(ctx, arg, callOptions...)
   156  }
   157  
   158  func (p *mgrProxy) LookupDNS(ctx context.Context, arg *manager.DNSRequest) (*manager.DNSResponse, error) {
   159  	client, callOptions, err := p.get()
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	return client.LookupDNS(ctx, arg, callOptions...)
   164  }
   165  
   166  func (p *mgrProxy) WatchClusterInfo(arg *manager.SessionInfo, srv connector.ManagerProxy_WatchClusterInfoServer) error {
   167  	client, callOptions, err := p.get()
   168  	if err != nil {
   169  		return err
   170  	}
   171  	cli, err := client.WatchClusterInfo(srv.Context(), arg, callOptions...)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	for {
   176  		info, err := cli.Recv()
   177  		if err != nil {
   178  			if err == io.EOF || srv.Context().Err() != nil {
   179  				return nil
   180  			}
   181  			return err
   182  		}
   183  		if err = srv.Send(info); err != nil {
   184  			return err
   185  		}
   186  	}
   187  }