github.com/thajeztah/cli@v0.0.0-20240223162942-dc6bfac81a8b/cli-plugins/socket/socket_test.go (about) 1 package socket 2 3 import ( 4 "io/fs" 5 "net" 6 "os" 7 "runtime" 8 "strings" 9 "testing" 10 "time" 11 12 "gotest.tools/v3/assert" 13 "gotest.tools/v3/poll" 14 ) 15 16 func TestSetupConn(t *testing.T) { 17 t.Run("updates conn when connected", func(t *testing.T) { 18 var conn *net.UnixConn 19 listener, err := SetupConn(&conn) 20 assert.NilError(t, err) 21 assert.Check(t, listener != nil, "returned nil listener but no error") 22 addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) 23 assert.NilError(t, err, "failed to resolve listener address") 24 25 _, err = net.DialUnix("unix", nil, addr) 26 assert.NilError(t, err, "failed to dial returned listener") 27 28 pollConnNotNil(t, &conn) 29 }) 30 31 t.Run("allows reconnects", func(t *testing.T) { 32 var conn *net.UnixConn 33 listener, err := SetupConn(&conn) 34 assert.NilError(t, err) 35 assert.Check(t, listener != nil, "returned nil listener but no error") 36 addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) 37 assert.NilError(t, err, "failed to resolve listener address") 38 39 otherConn, err := net.DialUnix("unix", nil, addr) 40 assert.NilError(t, err, "failed to dial returned listener") 41 42 otherConn.Close() 43 44 _, err = net.DialUnix("unix", nil, addr) 45 assert.NilError(t, err, "failed to redial listener") 46 }) 47 48 t.Run("does not leak sockets to local directory", func(t *testing.T) { 49 var conn *net.UnixConn 50 listener, err := SetupConn(&conn) 51 assert.NilError(t, err) 52 assert.Check(t, listener != nil, "returned nil listener but no error") 53 checkDirNoPluginSocket(t) 54 55 addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) 56 assert.NilError(t, err, "failed to resolve listener address") 57 _, err = net.DialUnix("unix", nil, addr) 58 assert.NilError(t, err, "failed to dial returned listener") 59 checkDirNoPluginSocket(t) 60 }) 61 } 62 63 func checkDirNoPluginSocket(t *testing.T) { 64 t.Helper() 65 66 files, err := os.ReadDir(".") 67 assert.NilError(t, err, "failed to list files in dir to check for leaked sockets") 68 69 for _, f := range files { 70 info, err := f.Info() 71 assert.NilError(t, err, "failed to check file info") 72 // check for a socket with `docker_cli_` in the name (from `SetupConn()`) 73 if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket { 74 t.Fatal("found socket in a local directory") 75 } 76 } 77 } 78 79 func TestConnectAndWait(t *testing.T) { 80 t.Run("calls cancel func on EOF", func(t *testing.T) { 81 var conn *net.UnixConn 82 listener, err := SetupConn(&conn) 83 assert.NilError(t, err, "failed to setup listener") 84 85 done := make(chan struct{}) 86 t.Setenv(EnvKey, listener.Addr().String()) 87 cancelFunc := func() { 88 done <- struct{}{} 89 } 90 ConnectAndWait(cancelFunc) 91 pollConnNotNil(t, &conn) 92 conn.Close() 93 94 select { 95 case <-done: 96 case <-time.After(10 * time.Millisecond): 97 t.Fatal("cancel function not closed after 10ms") 98 } 99 }) 100 101 // TODO: this test cannot be executed with `t.Parallel()`, due to 102 // relying on goroutine numbers to ensure correct behaviour 103 t.Run("connect goroutine exits after EOF", func(t *testing.T) { 104 var conn *net.UnixConn 105 listener, err := SetupConn(&conn) 106 assert.NilError(t, err, "failed to setup listener") 107 t.Setenv(EnvKey, listener.Addr().String()) 108 numGoroutines := runtime.NumGoroutine() 109 110 ConnectAndWait(func() {}) 111 assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) 112 113 pollConnNotNil(t, &conn) 114 conn.Close() 115 poll.WaitOn(t, func(t poll.LogT) poll.Result { 116 if runtime.NumGoroutine() > numGoroutines+1 { 117 return poll.Continue("waiting for connect goroutine to exit") 118 } 119 return poll.Success() 120 }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) 121 }) 122 } 123 124 func pollConnNotNil(t *testing.T, conn **net.UnixConn) { 125 t.Helper() 126 127 poll.WaitOn(t, func(t poll.LogT) poll.Result { 128 if *conn == nil { 129 return poll.Continue("waiting for conn to not be nil") 130 } 131 return poll.Success() 132 }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) 133 }