github.com/Uptycs/basequery-go@v0.8.0/server_test.go (about)

     1  package osquery
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net"
     9  	"os"
    10  	"sync"
    11  	"syscall"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/apache/thrift/lib/go/thrift"
    16  
    17  	"github.com/Uptycs/basequery-go/gen/osquery"
    18  	"github.com/Uptycs/basequery-go/plugin/logger"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  // Verify that an error in server.Start will return an error instead of deadlock.
    24  func TestNoDeadlockOnError(t *testing.T) {
    25  	registry := make(map[string](map[string]Plugin))
    26  	for reg := range validRegistryNames {
    27  		registry[reg] = make(map[string]Plugin)
    28  	}
    29  	mut := sync.Mutex{}
    30  	mock := &MockExtensionManager{
    31  		RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
    32  			mut.Lock()
    33  			defer mut.Unlock()
    34  			return nil, errors.New("Boom")
    35  		},
    36  		PingFunc: func() (*osquery.ExtensionStatus, error) {
    37  			return &osquery.ExtensionStatus{}, nil
    38  		},
    39  	}
    40  	server := &ExtensionManagerServer{
    41  		serverClient: mock,
    42  		registry:     registry,
    43  	}
    44  
    45  	log := func(ctx context.Context, typ logger.LogType, logText string) error {
    46  		fmt.Printf("%s: %s\n", typ, logText)
    47  		return nil
    48  	}
    49  	server.RegisterPlugin(logger.NewPlugin("testLogger", log))
    50  
    51  	err := server.Run()
    52  	assert.Error(t, err)
    53  	mut.Lock()
    54  	defer mut.Unlock()
    55  	assert.True(t, mock.RegisterExtensionFuncInvoked)
    56  }
    57  
    58  // Ensure that the extension server will shutdown and return if the osquery
    59  // instance it is talking to stops responding to pings.
    60  func TestShutdownWhenPingFails(t *testing.T) {
    61  	registry := make(map[string](map[string]Plugin))
    62  	for reg := range validRegistryNames {
    63  		registry[reg] = make(map[string]Plugin)
    64  	}
    65  	mock := &MockExtensionManager{
    66  		RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
    67  			return &osquery.ExtensionStatus{}, nil
    68  		},
    69  		PingFunc: func() (*osquery.ExtensionStatus, error) {
    70  			// As if the socket was closed
    71  			return nil, syscall.EPIPE
    72  		},
    73  	}
    74  	server := &ExtensionManagerServer{
    75  		serverClient: mock,
    76  		registry:     registry,
    77  	}
    78  
    79  	err := server.Run()
    80  	assert.Error(t, err)
    81  	assert.Contains(t, err.Error(), "broken pipe")
    82  }
    83  
    84  // How many parallel tests to run (because sync issues do not occur on every
    85  // run, this maximizes our chances of seeing any issue by quickly executing
    86  // many runs of the test).
    87  const parallelTestShutdownDeadlock = 20
    88  
    89  func TestShutdownDeadlock(t *testing.T) {
    90  	for i := 0; i < parallelTestShutdownDeadlock; i++ {
    91  		t.Run("", func(t *testing.T) {
    92  			t.Parallel()
    93  			testShutdownDeadlock(t)
    94  		})
    95  	}
    96  }
    97  func testShutdownDeadlock(t *testing.T) {
    98  	tempPath, err := ioutil.TempFile("", "")
    99  	require.Nil(t, err)
   100  	defer os.Remove(tempPath.Name())
   101  
   102  	retUUID := osquery.ExtensionRouteUUID(0)
   103  	mock := &MockExtensionManager{
   104  		RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
   105  			return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
   106  		},
   107  	}
   108  	server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
   109  
   110  	wait := sync.WaitGroup{}
   111  
   112  	wait.Add(1)
   113  	go func() {
   114  		err := server.Start()
   115  		require.Nil(t, err)
   116  		wait.Done()
   117  	}()
   118  	// Wait for server to be set up
   119  	server.waitStarted()
   120  
   121  	// Create a raw client to access the shutdown method that is not
   122  	// usually exposed.
   123  	listenPath := fmt.Sprintf("%s.%d", tempPath.Name(), retUUID)
   124  	addr, err := net.ResolveUnixAddr("unix", listenPath)
   125  	require.Nil(t, err)
   126  	timeout := 500 * time.Millisecond
   127  	trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
   128  	err = trans.Open()
   129  	require.Nil(t, err)
   130  	client := osquery.NewExtensionManagerClientFactory(trans,
   131  		thrift.NewTBinaryProtocolFactoryDefault())
   132  
   133  	// Simultaneously call shutdown through a request from the client and
   134  	// directly on the server object.
   135  	wait.Add(1)
   136  	go func() {
   137  		defer wait.Done()
   138  		client.Shutdown(context.Background())
   139  	}()
   140  
   141  	wait.Add(1)
   142  	go func() {
   143  		defer wait.Done()
   144  		err = server.Shutdown(context.Background())
   145  		require.Nil(t, err)
   146  	}()
   147  
   148  	// Track whether shutdown completed
   149  	completed := make(chan struct{})
   150  	go func() {
   151  		wait.Wait()
   152  		close(completed)
   153  	}()
   154  
   155  	// either indicate successful shutdown, or fatal the test because it
   156  	// hung
   157  	select {
   158  	case <-completed:
   159  		// Success. Do nothing.
   160  	case <-time.After(5 * time.Second):
   161  		t.Fatal("hung on shutdown")
   162  	}
   163  }
   164  
   165  func TestShutdownBasic(t *testing.T) {
   166  	tempPath, err := ioutil.TempFile("", "")
   167  	require.Nil(t, err)
   168  	defer os.Remove(tempPath.Name())
   169  
   170  	retUUID := osquery.ExtensionRouteUUID(0)
   171  	mock := &MockExtensionManager{
   172  		RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
   173  			return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
   174  		},
   175  	}
   176  	server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
   177  
   178  	completed := make(chan struct{})
   179  	go func() {
   180  		err := server.Start()
   181  		require.NoError(t, err)
   182  		close(completed)
   183  	}()
   184  
   185  	server.waitStarted()
   186  	err = server.Shutdown(context.Background())
   187  	require.NoError(t, err)
   188  
   189  	// Either indicate successful shutdown, or fatal the test because it
   190  	// hung
   191  	select {
   192  	case <-completed:
   193  		// Success. Do nothing.
   194  	case <-time.After(5 * time.Second):
   195  		t.Fatal("hung on shutdown")
   196  	}
   197  }