github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli-plugins/socket/socket_test.go (about) 1 package socket 2 3 import ( 4 "errors" 5 "io" 6 "io/fs" 7 "net" 8 "os" 9 "runtime" 10 "strings" 11 "sync/atomic" 12 "testing" 13 "time" 14 15 "gotest.tools/v3/assert" 16 "gotest.tools/v3/poll" 17 ) 18 19 func TestPluginServer(t *testing.T) { 20 t.Run("connection closes with EOF when server closes", func(t *testing.T) { 21 called := make(chan struct{}) 22 srv, err := NewPluginServer(func(_ net.Conn) { close(called) }) 23 assert.NilError(t, err) 24 assert.Assert(t, srv != nil, "returned nil server but no error") 25 26 addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) 27 assert.NilError(t, err, "failed to resolve server address") 28 29 conn, err := net.DialUnix("unix", nil, addr) 30 assert.NilError(t, err, "failed to dial returned server") 31 defer conn.Close() 32 33 done := make(chan error, 1) 34 go func() { 35 _, err := conn.Read(make([]byte, 1)) 36 done <- err 37 }() 38 39 select { 40 case <-called: 41 case <-time.After(10 * time.Millisecond): 42 t.Fatal("handler not called") 43 } 44 45 srv.Close() 46 47 select { 48 case err := <-done: 49 if !errors.Is(err, io.EOF) { 50 t.Fatalf("exepcted EOF error, got: %v", err) 51 } 52 case <-time.After(10 * time.Millisecond): 53 } 54 }) 55 56 t.Run("allows reconnects", func(t *testing.T) { 57 var calls int32 58 h := func(_ net.Conn) { 59 atomic.AddInt32(&calls, 1) 60 } 61 62 srv, err := NewPluginServer(h) 63 assert.NilError(t, err) 64 defer srv.Close() 65 66 assert.Check(t, srv.Addr() != nil, "returned nil addr but no error") 67 68 addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) 69 assert.NilError(t, err, "failed to resolve server address") 70 71 waitForCalls := func(n int) { 72 poll.WaitOn(t, func(t poll.LogT) poll.Result { 73 if atomic.LoadInt32(&calls) == int32(n) { 74 return poll.Success() 75 } 76 return poll.Continue("waiting for handler to be called") 77 }) 78 } 79 80 otherConn, err := net.DialUnix("unix", nil, addr) 81 assert.NilError(t, err, "failed to dial returned server") 82 otherConn.Close() 83 waitForCalls(1) 84 85 conn, err := net.DialUnix("unix", nil, addr) 86 assert.NilError(t, err, "failed to redial server") 87 defer conn.Close() 88 waitForCalls(2) 89 90 // and again but don't close the existing connection 91 conn2, err := net.DialUnix("unix", nil, addr) 92 assert.NilError(t, err, "failed to redial server") 93 defer conn2.Close() 94 waitForCalls(3) 95 96 srv.Close() 97 98 // now make sure we get EOF on the existing connections 99 buf := make([]byte, 1) 100 _, err = conn.Read(buf) 101 assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) 102 103 _, err = conn2.Read(buf) 104 assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) 105 }) 106 107 t.Run("does not leak sockets to local directory", func(t *testing.T) { 108 srv, err := NewPluginServer(nil) 109 assert.NilError(t, err) 110 assert.Check(t, srv != nil, "returned nil server but no error") 111 checkDirNoNewPluginServer(t) 112 113 addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) 114 assert.NilError(t, err, "failed to resolve server address") 115 116 _, err = net.DialUnix("unix", nil, addr) 117 assert.NilError(t, err, "failed to dial returned server") 118 checkDirNoNewPluginServer(t) 119 }) 120 } 121 122 func checkDirNoNewPluginServer(t *testing.T) { 123 t.Helper() 124 125 files, err := os.ReadDir(".") 126 assert.NilError(t, err, "failed to list files in dir to check for leaked sockets") 127 128 for _, f := range files { 129 info, err := f.Info() 130 assert.NilError(t, err, "failed to check file info") 131 // check for a socket with `docker_cli_` in the name (from `SetupConn()`) 132 if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket { 133 t.Fatal("found socket in a local directory") 134 } 135 } 136 } 137 138 func TestConnectAndWait(t *testing.T) { 139 t.Run("calls cancel func on EOF", func(t *testing.T) { 140 srv, err := NewPluginServer(nil) 141 assert.NilError(t, err, "failed to setup server") 142 defer srv.Close() 143 144 done := make(chan struct{}) 145 t.Setenv(EnvKey, srv.Addr().String()) 146 cancelFunc := func() { 147 done <- struct{}{} 148 } 149 ConnectAndWait(cancelFunc) 150 151 select { 152 case <-done: 153 t.Fatal("unexpectedly done") 154 default: 155 } 156 157 srv.Close() 158 159 select { 160 case <-done: 161 case <-time.After(10 * time.Millisecond): 162 t.Fatal("cancel function not closed after 10ms") 163 } 164 }) 165 166 // TODO: this test cannot be executed with `t.Parallel()`, due to 167 // relying on goroutine numbers to ensure correct behaviour 168 t.Run("connect goroutine exits after EOF", func(t *testing.T) { 169 srv, err := NewPluginServer(nil) 170 assert.NilError(t, err, "failed to setup server") 171 172 defer srv.Close() 173 174 t.Setenv(EnvKey, srv.Addr().String()) 175 numGoroutines := runtime.NumGoroutine() 176 177 ConnectAndWait(func() {}) 178 assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) 179 180 srv.Close() 181 182 poll.WaitOn(t, func(t poll.LogT) poll.Result { 183 if runtime.NumGoroutine() > numGoroutines+1 { 184 return poll.Continue("waiting for connect goroutine to exit") 185 } 186 return poll.Success() 187 }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) 188 }) 189 }