github.com/thajeztah/cli@v0.0.0-20240223162942-dc6bfac81a8b/cli-plugins/socket/socket_test.go (about)

     1  package socket
     2  
     3  import (
     4  	"io/fs"
     5  	"net"
     6  	"os"
     7  	"runtime"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"gotest.tools/v3/assert"
    13  	"gotest.tools/v3/poll"
    14  )
    15  
    16  func TestSetupConn(t *testing.T) {
    17  	t.Run("updates conn when connected", func(t *testing.T) {
    18  		var conn *net.UnixConn
    19  		listener, err := SetupConn(&conn)
    20  		assert.NilError(t, err)
    21  		assert.Check(t, listener != nil, "returned nil listener but no error")
    22  		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
    23  		assert.NilError(t, err, "failed to resolve listener address")
    24  
    25  		_, err = net.DialUnix("unix", nil, addr)
    26  		assert.NilError(t, err, "failed to dial returned listener")
    27  
    28  		pollConnNotNil(t, &conn)
    29  	})
    30  
    31  	t.Run("allows reconnects", func(t *testing.T) {
    32  		var conn *net.UnixConn
    33  		listener, err := SetupConn(&conn)
    34  		assert.NilError(t, err)
    35  		assert.Check(t, listener != nil, "returned nil listener but no error")
    36  		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
    37  		assert.NilError(t, err, "failed to resolve listener address")
    38  
    39  		otherConn, err := net.DialUnix("unix", nil, addr)
    40  		assert.NilError(t, err, "failed to dial returned listener")
    41  
    42  		otherConn.Close()
    43  
    44  		_, err = net.DialUnix("unix", nil, addr)
    45  		assert.NilError(t, err, "failed to redial listener")
    46  	})
    47  
    48  	t.Run("does not leak sockets to local directory", func(t *testing.T) {
    49  		var conn *net.UnixConn
    50  		listener, err := SetupConn(&conn)
    51  		assert.NilError(t, err)
    52  		assert.Check(t, listener != nil, "returned nil listener but no error")
    53  		checkDirNoPluginSocket(t)
    54  
    55  		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
    56  		assert.NilError(t, err, "failed to resolve listener address")
    57  		_, err = net.DialUnix("unix", nil, addr)
    58  		assert.NilError(t, err, "failed to dial returned listener")
    59  		checkDirNoPluginSocket(t)
    60  	})
    61  }
    62  
    63  func checkDirNoPluginSocket(t *testing.T) {
    64  	t.Helper()
    65  
    66  	files, err := os.ReadDir(".")
    67  	assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
    68  
    69  	for _, f := range files {
    70  		info, err := f.Info()
    71  		assert.NilError(t, err, "failed to check file info")
    72  		// check for a socket with `docker_cli_` in the name (from `SetupConn()`)
    73  		if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
    74  			t.Fatal("found socket in a local directory")
    75  		}
    76  	}
    77  }
    78  
    79  func TestConnectAndWait(t *testing.T) {
    80  	t.Run("calls cancel func on EOF", func(t *testing.T) {
    81  		var conn *net.UnixConn
    82  		listener, err := SetupConn(&conn)
    83  		assert.NilError(t, err, "failed to setup listener")
    84  
    85  		done := make(chan struct{})
    86  		t.Setenv(EnvKey, listener.Addr().String())
    87  		cancelFunc := func() {
    88  			done <- struct{}{}
    89  		}
    90  		ConnectAndWait(cancelFunc)
    91  		pollConnNotNil(t, &conn)
    92  		conn.Close()
    93  
    94  		select {
    95  		case <-done:
    96  		case <-time.After(10 * time.Millisecond):
    97  			t.Fatal("cancel function not closed after 10ms")
    98  		}
    99  	})
   100  
   101  	// TODO: this test cannot be executed with `t.Parallel()`, due to
   102  	// relying on goroutine numbers to ensure correct behaviour
   103  	t.Run("connect goroutine exits after EOF", func(t *testing.T) {
   104  		var conn *net.UnixConn
   105  		listener, err := SetupConn(&conn)
   106  		assert.NilError(t, err, "failed to setup listener")
   107  		t.Setenv(EnvKey, listener.Addr().String())
   108  		numGoroutines := runtime.NumGoroutine()
   109  
   110  		ConnectAndWait(func() {})
   111  		assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
   112  
   113  		pollConnNotNil(t, &conn)
   114  		conn.Close()
   115  		poll.WaitOn(t, func(t poll.LogT) poll.Result {
   116  			if runtime.NumGoroutine() > numGoroutines+1 {
   117  				return poll.Continue("waiting for connect goroutine to exit")
   118  			}
   119  			return poll.Success()
   120  		}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
   121  	})
   122  }
   123  
   124  func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
   125  	t.Helper()
   126  
   127  	poll.WaitOn(t, func(t poll.LogT) poll.Result {
   128  		if *conn == nil {
   129  			return poll.Continue("waiting for conn to not be nil")
   130  		}
   131  		return poll.Success()
   132  	}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
   133  }