github.com/Uptycs/basequery-go@v0.8.0/server_test.go (about) 1 package osquery 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "net" 9 "os" 10 "sync" 11 "syscall" 12 "testing" 13 "time" 14 15 "github.com/apache/thrift/lib/go/thrift" 16 17 "github.com/Uptycs/basequery-go/gen/osquery" 18 "github.com/Uptycs/basequery-go/plugin/logger" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 ) 22 23 // Verify that an error in server.Start will return an error instead of deadlock. 24 func TestNoDeadlockOnError(t *testing.T) { 25 registry := make(map[string](map[string]Plugin)) 26 for reg := range validRegistryNames { 27 registry[reg] = make(map[string]Plugin) 28 } 29 mut := sync.Mutex{} 30 mock := &MockExtensionManager{ 31 RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { 32 mut.Lock() 33 defer mut.Unlock() 34 return nil, errors.New("Boom") 35 }, 36 PingFunc: func() (*osquery.ExtensionStatus, error) { 37 return &osquery.ExtensionStatus{}, nil 38 }, 39 } 40 server := &ExtensionManagerServer{ 41 serverClient: mock, 42 registry: registry, 43 } 44 45 log := func(ctx context.Context, typ logger.LogType, logText string) error { 46 fmt.Printf("%s: %s\n", typ, logText) 47 return nil 48 } 49 server.RegisterPlugin(logger.NewPlugin("testLogger", log)) 50 51 err := server.Run() 52 assert.Error(t, err) 53 mut.Lock() 54 defer mut.Unlock() 55 assert.True(t, mock.RegisterExtensionFuncInvoked) 56 } 57 58 // Ensure that the extension server will shutdown and return if the osquery 59 // instance it is talking to stops responding to pings. 60 func TestShutdownWhenPingFails(t *testing.T) { 61 registry := make(map[string](map[string]Plugin)) 62 for reg := range validRegistryNames { 63 registry[reg] = make(map[string]Plugin) 64 } 65 mock := &MockExtensionManager{ 66 RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { 67 return &osquery.ExtensionStatus{}, nil 68 }, 69 PingFunc: func() (*osquery.ExtensionStatus, error) { 70 // As if the socket was closed 71 return nil, syscall.EPIPE 72 }, 73 } 74 server := &ExtensionManagerServer{ 75 serverClient: mock, 76 registry: registry, 77 } 78 79 err := server.Run() 80 assert.Error(t, err) 81 assert.Contains(t, err.Error(), "broken pipe") 82 } 83 84 // How many parallel tests to run (because sync issues do not occur on every 85 // run, this maximizes our chances of seeing any issue by quickly executing 86 // many runs of the test). 87 const parallelTestShutdownDeadlock = 20 88 89 func TestShutdownDeadlock(t *testing.T) { 90 for i := 0; i < parallelTestShutdownDeadlock; i++ { 91 t.Run("", func(t *testing.T) { 92 t.Parallel() 93 testShutdownDeadlock(t) 94 }) 95 } 96 } 97 func testShutdownDeadlock(t *testing.T) { 98 tempPath, err := ioutil.TempFile("", "") 99 require.Nil(t, err) 100 defer os.Remove(tempPath.Name()) 101 102 retUUID := osquery.ExtensionRouteUUID(0) 103 mock := &MockExtensionManager{ 104 RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { 105 return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil 106 }, 107 } 108 server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} 109 110 wait := sync.WaitGroup{} 111 112 wait.Add(1) 113 go func() { 114 err := server.Start() 115 require.Nil(t, err) 116 wait.Done() 117 }() 118 // Wait for server to be set up 119 server.waitStarted() 120 121 // Create a raw client to access the shutdown method that is not 122 // usually exposed. 123 listenPath := fmt.Sprintf("%s.%d", tempPath.Name(), retUUID) 124 addr, err := net.ResolveUnixAddr("unix", listenPath) 125 require.Nil(t, err) 126 timeout := 500 * time.Millisecond 127 trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout) 128 err = trans.Open() 129 require.Nil(t, err) 130 client := osquery.NewExtensionManagerClientFactory(trans, 131 thrift.NewTBinaryProtocolFactoryDefault()) 132 133 // Simultaneously call shutdown through a request from the client and 134 // directly on the server object. 135 wait.Add(1) 136 go func() { 137 defer wait.Done() 138 client.Shutdown(context.Background()) 139 }() 140 141 wait.Add(1) 142 go func() { 143 defer wait.Done() 144 err = server.Shutdown(context.Background()) 145 require.Nil(t, err) 146 }() 147 148 // Track whether shutdown completed 149 completed := make(chan struct{}) 150 go func() { 151 wait.Wait() 152 close(completed) 153 }() 154 155 // either indicate successful shutdown, or fatal the test because it 156 // hung 157 select { 158 case <-completed: 159 // Success. Do nothing. 160 case <-time.After(5 * time.Second): 161 t.Fatal("hung on shutdown") 162 } 163 } 164 165 func TestShutdownBasic(t *testing.T) { 166 tempPath, err := ioutil.TempFile("", "") 167 require.Nil(t, err) 168 defer os.Remove(tempPath.Name()) 169 170 retUUID := osquery.ExtensionRouteUUID(0) 171 mock := &MockExtensionManager{ 172 RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { 173 return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil 174 }, 175 } 176 server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} 177 178 completed := make(chan struct{}) 179 go func() { 180 err := server.Start() 181 require.NoError(t, err) 182 close(completed) 183 }() 184 185 server.waitStarted() 186 err = server.Shutdown(context.Background()) 187 require.NoError(t, err) 188 189 // Either indicate successful shutdown, or fatal the test because it 190 // hung 191 select { 192 case <-completed: 193 // Success. Do nothing. 194 case <-time.After(5 * time.Second): 195 t.Fatal("hung on shutdown") 196 } 197 }