github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/dnet/kpfconn_test.go (about)

     1  package dnet_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"path/filepath"
    10  	"runtime"
    11  	"strconv"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/sirupsen/logrus"
    16  	"golang.org/x/net/nettest"
    17  	"k8s.io/cli-runtime/pkg/genericclioptions"
    18  	"k8s.io/client-go/kubernetes"
    19  
    20  	"github.com/datawire/dlib/dcontext"
    21  	"github.com/datawire/dlib/dexec"
    22  	"github.com/datawire/dlib/dgroup"
    23  	"github.com/datawire/dlib/dlog"
    24  	"github.com/datawire/dlib/dtime"
    25  	"github.com/telepresenceio/telepresence/v2/pkg/dnet"
    26  )
    27  
// mockServerBinary is the path of the mock server executable. It is assigned by TestMain, which
// builds the binary from testdata/mockserver, and is launched by the "apiserver" goroutine in
// TestKubectlPortForward.
var mockServerBinary string
    29  
    30  func TestMain(m *testing.M) {
    31  	mbf, err := os.CreateTemp("", "mockServer")
    32  	if err != nil {
    33  		fmt.Fprintln(os.Stderr, err)
    34  		os.Exit(1)
    35  	}
    36  	mockServerBinary = mbf.Name()
    37  	if runtime.GOOS == "windows" {
    38  		mockServerBinary += ".exe"
    39  	}
    40  	mbf.Close()
    41  	ctx := dlog.WithLogger(context.Background(), dlog.WrapLogrus(logrus.StandardLogger()))
    42  	cmd := dexec.CommandContext(ctx, "go", "build", "-o", mockServerBinary, ".")
    43  	cmd.Dir = filepath.Join("testdata", "mockserver")
    44  	if err := cmd.Run(); err != nil {
    45  		fmt.Fprintln(os.Stderr, err)
    46  		os.Exit(1)
    47  	}
    48  	defer os.Remove(mockServerBinary)
    49  	m.Run()
    50  }
    51  
    52  func TestKubectlPortForward(t *testing.T) {
    53  	if runtime.GOOS == "windows" {
    54  		t.SkipNow()
    55  	}
    56  	if _, err := dexec.LookPath("socat"); err != nil {
    57  		if runtime.GOOS == "linux" && os.Getenv("CI") != "" {
    58  			t.Fatal("would skip this test in CI, which isn't OK")
    59  		}
    60  		t.SkipNow()
    61  	}
    62  	strPtr := func(s string) *string {
    63  		return &s
    64  	}
    65  
    66  	makePipe := func() (_, _ net.Conn, _ func(), _err error) {
    67  		ctx, cancel := context.WithCancel(dcontext.WithSoftness(dlog.NewTestContext(t, true)))
    68  		grp := dgroup.NewGroup(ctx, dgroup.GroupConfig{})
    69  		var cliConn, srvConn net.Conn
    70  		stop := func() {
    71  			cancel()
    72  			if err := grp.Wait(); err != nil {
    73  				t.Error(err)
    74  			}
    75  			// This is 10% just to do cleanup, and is 90% to prevent the GC from calling
    76  			// srvConn's finalizaer and closing the connection while the test is still
    77  			// running.
    78  			if cliConn != nil {
    79  				_ = cliConn.Close()
    80  			}
    81  			if srvConn != nil {
    82  				_ = srvConn.Close()
    83  			}
    84  		}
    85  		defer func() {
    86  			if _err != nil {
    87  				stop()
    88  			}
    89  		}()
    90  
    91  		podListener, err := net.Listen("tcp", "127.0.0.1:0")
    92  		if err != nil {
    93  			return nil, nil, nil, err
    94  		}
    95  		defer func() {
    96  			if _err != nil {
    97  				_ = podListener.Close()
    98  			}
    99  		}()
   100  
   101  		apiserverListener, err := net.Listen("tcp", "127.0.0.1:0")
   102  		if err != nil {
   103  			return nil, nil, nil, err
   104  		}
   105  		apiserverAddr := apiserverListener.Addr().(*net.TCPAddr)
   106  		_ = apiserverListener.Close()
   107  
   108  		srvConnCh := make(chan net.Conn)
   109  		apiReady := make(chan struct{})
   110  		grp.Go("pod", func(_ context.Context) error {
   111  			conn, err := podListener.Accept()
   112  			t.Log("accepted")
   113  			_ = podListener.Close()
   114  			if err != nil {
   115  				return err
   116  			}
   117  			srvConnCh <- conn
   118  			return nil
   119  		})
   120  		grp.Go("apiserver", func(ctx context.Context) error {
   121  			cmd := dexec.CommandContext(
   122  				ctx, mockServerBinary, "-p", strconv.Itoa(apiserverAddr.Port))
   123  			cmd.DisableLogging = true
   124  			cmd.Stdout = dlog.StdLogger(ctx, dlog.LogLevelInfo).Writer()
   125  			cmd.Stderr = dlog.StdLogger(ctx, dlog.LogLevelError).Writer()
   126  			err := cmd.Start()
   127  			if err != nil {
   128  				close(apiReady)
   129  				return err
   130  			}
   131  			for i := 0; i < 100; i++ {
   132  				dtime.SleepWithContext(ctx, 10*time.Millisecond)
   133  				var rsp *http.Response
   134  				if rsp, err = http.DefaultClient.Get(fmt.Sprintf("http://localhost:%d/api", apiserverAddr.Port)); err == nil {
   135  					rsp.Body.Close()
   136  					close(apiReady)
   137  					_ = cmd.Wait()
   138  					return nil
   139  				}
   140  			}
   141  			close(apiReady)
   142  			return err
   143  		})
   144  		<-apiReady
   145  
   146  		kubeFlags := &genericclioptions.ConfigFlags{
   147  			KubeConfig: strPtr("/dev/null"),
   148  			APIServer:  strPtr(fmt.Sprintf("http://localhost:%d", apiserverAddr.Port)),
   149  		}
   150  		kubeConfig, err := kubeFlags.ToRESTConfig()
   151  		if err != nil {
   152  			return nil, nil, nil, err
   153  		}
   154  		ki, err := kubernetes.NewForConfig(kubeConfig)
   155  		if err != nil {
   156  			return nil, nil, nil, err
   157  		}
   158  		dialer, err := dnet.NewK8sPortForwardDialer(ctx, kubeConfig, ki)
   159  		if err != nil {
   160  			return nil, nil, nil, err
   161  		}
   162  
   163  		cliConn, err = dialer.Dial(ctx, fmt.Sprintf("pods/SOMEPODNAME.SOMENAMESPACE:%d", podListener.Addr().(*net.TCPAddr).Port))
   164  		t.Log("dialed")
   165  		if err != nil {
   166  			return nil, nil, nil, err
   167  		}
   168  
   169  		srvConn = <-srvConnCh
   170  		return cliConn, srvConn, stop, nil
   171  	}
   172  	// Can't test Client side using nettest.TestConn, because the net.Conn exposed by the spdystream.Stream doesn't return the
   173  	// expected net.Error (it returns io.EOF).
   174  	// t.Run("Client", func(t *testing.T) { nettest.TestConn(t, makePipe) })
   175  	t.Run("Server", func(t *testing.T) { nettest.TestConn(t, flipMakePipe(makePipe)) })
   176  }