github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli-plugins/socket/socket_test.go (about)

     1  package socket
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/fs"
     7  	"net"
     8  	"os"
     9  	"runtime"
    10  	"strings"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  
    15  	"gotest.tools/v3/assert"
    16  	"gotest.tools/v3/poll"
    17  )
    18  
    19  func TestPluginServer(t *testing.T) {
    20  	t.Run("connection closes with EOF when server closes", func(t *testing.T) {
    21  		called := make(chan struct{})
    22  		srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
    23  		assert.NilError(t, err)
    24  		assert.Assert(t, srv != nil, "returned nil server but no error")
    25  
    26  		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
    27  		assert.NilError(t, err, "failed to resolve server address")
    28  
    29  		conn, err := net.DialUnix("unix", nil, addr)
    30  		assert.NilError(t, err, "failed to dial returned server")
    31  		defer conn.Close()
    32  
    33  		done := make(chan error, 1)
    34  		go func() {
    35  			_, err := conn.Read(make([]byte, 1))
    36  			done <- err
    37  		}()
    38  
    39  		select {
    40  		case <-called:
    41  		case <-time.After(10 * time.Millisecond):
    42  			t.Fatal("handler not called")
    43  		}
    44  
    45  		srv.Close()
    46  
    47  		select {
    48  		case err := <-done:
    49  			if !errors.Is(err, io.EOF) {
    50  				t.Fatalf("exepcted EOF error, got: %v", err)
    51  			}
    52  		case <-time.After(10 * time.Millisecond):
    53  		}
    54  	})
    55  
    56  	t.Run("allows reconnects", func(t *testing.T) {
    57  		var calls int32
    58  		h := func(_ net.Conn) {
    59  			atomic.AddInt32(&calls, 1)
    60  		}
    61  
    62  		srv, err := NewPluginServer(h)
    63  		assert.NilError(t, err)
    64  		defer srv.Close()
    65  
    66  		assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
    67  
    68  		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
    69  		assert.NilError(t, err, "failed to resolve server address")
    70  
    71  		waitForCalls := func(n int) {
    72  			poll.WaitOn(t, func(t poll.LogT) poll.Result {
    73  				if atomic.LoadInt32(&calls) == int32(n) {
    74  					return poll.Success()
    75  				}
    76  				return poll.Continue("waiting for handler to be called")
    77  			})
    78  		}
    79  
    80  		otherConn, err := net.DialUnix("unix", nil, addr)
    81  		assert.NilError(t, err, "failed to dial returned server")
    82  		otherConn.Close()
    83  		waitForCalls(1)
    84  
    85  		conn, err := net.DialUnix("unix", nil, addr)
    86  		assert.NilError(t, err, "failed to redial server")
    87  		defer conn.Close()
    88  		waitForCalls(2)
    89  
    90  		// and again but don't close the existing connection
    91  		conn2, err := net.DialUnix("unix", nil, addr)
    92  		assert.NilError(t, err, "failed to redial server")
    93  		defer conn2.Close()
    94  		waitForCalls(3)
    95  
    96  		srv.Close()
    97  
    98  		// now make sure we get EOF on the existing connections
    99  		buf := make([]byte, 1)
   100  		_, err = conn.Read(buf)
   101  		assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
   102  
   103  		_, err = conn2.Read(buf)
   104  		assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
   105  	})
   106  
   107  	t.Run("does not leak sockets to local directory", func(t *testing.T) {
   108  		srv, err := NewPluginServer(nil)
   109  		assert.NilError(t, err)
   110  		assert.Check(t, srv != nil, "returned nil server but no error")
   111  		checkDirNoNewPluginServer(t)
   112  
   113  		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
   114  		assert.NilError(t, err, "failed to resolve server address")
   115  
   116  		_, err = net.DialUnix("unix", nil, addr)
   117  		assert.NilError(t, err, "failed to dial returned server")
   118  		checkDirNoNewPluginServer(t)
   119  	})
   120  }
   121  
   122  func checkDirNoNewPluginServer(t *testing.T) {
   123  	t.Helper()
   124  
   125  	files, err := os.ReadDir(".")
   126  	assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
   127  
   128  	for _, f := range files {
   129  		info, err := f.Info()
   130  		assert.NilError(t, err, "failed to check file info")
   131  		// check for a socket with `docker_cli_` in the name (from `SetupConn()`)
   132  		if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
   133  			t.Fatal("found socket in a local directory")
   134  		}
   135  	}
   136  }
   137  
   138  func TestConnectAndWait(t *testing.T) {
   139  	t.Run("calls cancel func on EOF", func(t *testing.T) {
   140  		srv, err := NewPluginServer(nil)
   141  		assert.NilError(t, err, "failed to setup server")
   142  		defer srv.Close()
   143  
   144  		done := make(chan struct{})
   145  		t.Setenv(EnvKey, srv.Addr().String())
   146  		cancelFunc := func() {
   147  			done <- struct{}{}
   148  		}
   149  		ConnectAndWait(cancelFunc)
   150  
   151  		select {
   152  		case <-done:
   153  			t.Fatal("unexpectedly done")
   154  		default:
   155  		}
   156  
   157  		srv.Close()
   158  
   159  		select {
   160  		case <-done:
   161  		case <-time.After(10 * time.Millisecond):
   162  			t.Fatal("cancel function not closed after 10ms")
   163  		}
   164  	})
   165  
   166  	// TODO: this test cannot be executed with `t.Parallel()`, due to
   167  	// relying on goroutine numbers to ensure correct behaviour
   168  	t.Run("connect goroutine exits after EOF", func(t *testing.T) {
   169  		srv, err := NewPluginServer(nil)
   170  		assert.NilError(t, err, "failed to setup server")
   171  
   172  		defer srv.Close()
   173  
   174  		t.Setenv(EnvKey, srv.Addr().String())
   175  		numGoroutines := runtime.NumGoroutine()
   176  
   177  		ConnectAndWait(func() {})
   178  		assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
   179  
   180  		srv.Close()
   181  
   182  		poll.WaitOn(t, func(t poll.LogT) poll.Result {
   183  			if runtime.NumGoroutine() > numGoroutines+1 {
   184  				return poll.Continue("waiting for connect goroutine to exit")
   185  			}
   186  			return poll.Success()
   187  		}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
   188  	})
   189  }