github.com/Uptycs/basequery-go@v0.8.0/server.go (about)

     1  package osquery
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"strconv"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/apache/thrift/lib/go/thrift"
    12  
    13  	"github.com/Uptycs/basequery-go/gen/osquery"
    14  	"github.com/Uptycs/basequery-go/transport"
    15  	"github.com/pkg/errors"
    16  	"github.com/prometheus/client_golang/prometheus"
    17  	"github.com/prometheus/client_golang/prometheus/promauto"
    18  	"github.com/prometheus/client_golang/prometheus/promhttp"
    19  )
    20  
    21  // Plugin exposes the basequery Plugin interface.
    22  type Plugin interface {
    23  	// Name is the name used to refer to the plugin (eg. the name of the
    24  	// table the plugin implements).
    25  	Name() string
    26  	// RegistryName is which "registry" the plugin should be added to.
    27  	// Valid names are ["config", "logger", "table"].
    28  	RegistryName() string
    29  	// Routes returns the detailed information about the interface exposed
    30  	// by the plugin. See the example plugins for samples.
    31  	Routes() osquery.ExtensionPluginResponse
    32  	// Ping implements a health check for the plugin. If the plugin is in a
    33  	// healthy state, StatusOK should be returned.
    34  	Ping() osquery.ExtensionStatus
    35  	// Call requests the plugin to perform its defined behavior, returning
    36  	// a response containing the result.
    37  	Call(context.Context, osquery.ExtensionPluginRequest) osquery.ExtensionResponse
    38  	// Shutdown alerts the plugin to stop.
    39  	Shutdown()
    40  }
    41  
    42  const defaultTimeout = 1 * time.Second
    43  const defaultPingInterval = 5 * time.Second
    44  
    45  // ExtensionManagerServer is an implementation of the full ExtensionManager
    46  // API. Plugins can register with an extension manager, which handles the
    47  // communication with the osquery process.
    48  type ExtensionManagerServer struct {
    49  	name           string
    50  	version        string
    51  	sockPath       string
    52  	serverClient   ExtensionManager
    53  	registry       map[string](map[string]Plugin)
    54  	promServer     *http.Server
    55  	pluginCounter  *prometheus.CounterVec
    56  	pluginGauge    *prometheus.GaugeVec
    57  	pluginTime     *prometheus.HistogramVec
    58  	server         thrift.TServer
    59  	transport      thrift.TServerTransport
    60  	timeout        time.Duration
    61  	pingInterval   time.Duration // How often to ping osquery server
    62  	prometheusPort uint16        // Expose prometheus metrics, if > 0
    63  	mutex          sync.Mutex
    64  	started        bool // Used to ensure tests wait until the server is actually started
    65  }
    66  
    67  // validRegistryNames contains the allowable RegistryName() values. If a plugin
    68  // attempts to register with another value, the program will panic.
    69  var validRegistryNames = map[string]bool{
    70  	"table":       true,
    71  	"logger":      true,
    72  	"config":      true,
    73  	"distributed": true,
    74  }
    75  
    76  // ServerOption is function for setting extension manager server options.
    77  type ServerOption func(*ExtensionManagerServer)
    78  
    79  // ServerVersion can be used to specify the basequery SDK version.
    80  func ServerVersion(version string) ServerOption {
    81  	return func(s *ExtensionManagerServer) {
    82  		s.version = version
    83  	}
    84  }
    85  
    86  // ServerTimeout sets timeout duration for thrift socket.
    87  func ServerTimeout(timeout time.Duration) ServerOption {
    88  	return func(s *ExtensionManagerServer) {
    89  		s.timeout = timeout
    90  	}
    91  }
    92  
    93  // ServerPingInterval can be used to configure health check ping interval/frequency.
    94  func ServerPingInterval(interval time.Duration) ServerOption {
    95  	return func(s *ExtensionManagerServer) {
    96  		s.pingInterval = interval
    97  	}
    98  }
    99  
   100  // ServerPrometheusPort is used to specify the port on which prometheus metrics will be exposed.
   101  // By default this is disabled (0). A positive integer port value should be specified to enable it.
   102  func ServerPrometheusPort(port uint16) ServerOption {
   103  	return func(s *ExtensionManagerServer) {
   104  		s.prometheusPort = port
   105  	}
   106  }
   107  
   108  // NewExtensionManagerServer creates a new extension management server
   109  // communicating with osquery over the socket at the provided path. If
   110  // resolving the address or connecting to the socket fails, this function will
   111  // error.
   112  func NewExtensionManagerServer(name string, sockPath string, opts ...ServerOption) (*ExtensionManagerServer, error) {
   113  	// Initialize nested registry maps
   114  	registry := make(map[string](map[string]Plugin))
   115  	for reg := range validRegistryNames {
   116  		registry[reg] = make(map[string]Plugin)
   117  	}
   118  
   119  	manager := &ExtensionManagerServer{
   120  		name:           name,
   121  		sockPath:       sockPath,
   122  		registry:       registry,
   123  		timeout:        defaultTimeout,
   124  		pingInterval:   defaultPingInterval,
   125  		prometheusPort: 0,
   126  	}
   127  
   128  	for _, opt := range opts {
   129  		opt(manager)
   130  	}
   131  
   132  	serverClient, err := NewClient(sockPath, manager.timeout)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	manager.serverClient = serverClient
   137  
   138  	return manager, nil
   139  }
   140  
   141  // GetClient returns the extension manager client.
   142  func (s *ExtensionManagerServer) GetClient() ExtensionManager {
   143  	return s.serverClient
   144  }
   145  
   146  // RegisterPlugin adds one or more OsqueryPlugins to this extension manager.
   147  func (s *ExtensionManagerServer) RegisterPlugin(plugins ...Plugin) {
   148  	s.mutex.Lock()
   149  	defer s.mutex.Unlock()
   150  	for _, plugin := range plugins {
   151  		if !validRegistryNames[plugin.RegistryName()] {
   152  			panic("invalid registry name: " + plugin.RegistryName())
   153  		}
   154  		s.registry[plugin.RegistryName()][plugin.Name()] = plugin
   155  	}
   156  }
   157  
   158  func (s *ExtensionManagerServer) genRegistry() osquery.ExtensionRegistry {
   159  	registry := osquery.ExtensionRegistry{}
   160  	for regName := range s.registry {
   161  		registry[regName] = osquery.ExtensionRouteTable{}
   162  		for _, plugin := range s.registry[regName] {
   163  			registry[regName][plugin.Name()] = plugin.Routes()
   164  		}
   165  	}
   166  	return registry
   167  }
   168  
   169  // Start registers the extension plugins and begins listening on a unix socket
   170  // for requests from the osquery process. All plugins should be registered with
   171  // RegisterPlugin() before calling Start().
   172  func (s *ExtensionManagerServer) Start() error {
   173  	var server thrift.TServer
   174  	err := func() error {
   175  		s.mutex.Lock()
   176  		defer s.mutex.Unlock()
   177  		registry := s.genRegistry()
   178  
   179  		stat, err := s.serverClient.RegisterExtension(
   180  			&osquery.InternalExtensionInfo{
   181  				Name:    s.name,
   182  				Version: s.version,
   183  			},
   184  			registry,
   185  		)
   186  
   187  		if err != nil {
   188  			return errors.Wrap(err, "registering extension")
   189  		}
   190  		if stat.Code != 0 {
   191  			return errors.Errorf("status %d registering extension: %s", stat.Code, stat.Message)
   192  		}
   193  
   194  		listenPath := fmt.Sprintf("%s.%d", s.sockPath, stat.UUID)
   195  
   196  		processor := osquery.NewExtensionProcessor(s)
   197  
   198  		s.transport, err = transport.OpenServer(listenPath, s.timeout)
   199  		if err != nil {
   200  			return errors.Wrapf(err, "opening server socket (%s)", listenPath)
   201  		}
   202  
   203  		s.server = thrift.NewTSimpleServer2(processor, s.transport)
   204  		server = s.server
   205  
   206  		if s.prometheusPort > 0 {
   207  			mux := http.NewServeMux()
   208  			mux.Handle("/metrics", promhttp.Handler())
   209  
   210  			s.promServer = &http.Server{
   211  				Addr:    ":" + strconv.Itoa(int(s.prometheusPort)),
   212  				Handler: mux,
   213  			}
   214  
   215  			s.pluginCounter = promauto.NewCounterVec(prometheus.CounterOpts{
   216  				Name: "plugin_calls",
   217  				Help: "Number of calls to a plugin action",
   218  			}, []string{"plugin_name", "plugin_action"})
   219  			s.pluginGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
   220  				Name: "plugin_results",
   221  				Help: "Number of results returns by plugin action",
   222  			}, []string{"plugin_name", "plugin_action"})
   223  			s.pluginTime = promauto.NewHistogramVec(prometheus.HistogramOpts{
   224  				Name: "plugin_duration_seconds",
   225  				Help: "Histogram for plugin action duration in seconds",
   226  			}, []string{"plugin_name", "plugin_action"})
   227  		}
   228  
   229  		s.started = true
   230  
   231  		return nil
   232  	}()
   233  
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	if s.promServer != nil {
   239  		go func() {
   240  			s.promServer.ListenAndServe()
   241  		}()
   242  	}
   243  
   244  	return server.Serve()
   245  }
   246  
   247  // Run starts the extension manager and runs until osquery calls for a shutdown
   248  // or the osquery instance goes away.
   249  func (s *ExtensionManagerServer) Run() error {
   250  	errc := make(chan error)
   251  	go func() {
   252  		errc <- s.Start()
   253  	}()
   254  
   255  	// Watch for the osquery process going away. If so, initiate shutdown.
   256  	go func() {
   257  		for {
   258  			time.Sleep(s.pingInterval)
   259  
   260  			status, err := s.serverClient.Ping()
   261  			if err != nil {
   262  				errc <- errors.Wrap(err, "extension ping failed")
   263  				break
   264  			}
   265  			if status.Code != 0 {
   266  				errc <- errors.Errorf("ping returned status %d", status.Code)
   267  				break
   268  			}
   269  		}
   270  	}()
   271  
   272  	err := <-errc
   273  	if s.promServer != nil {
   274  		// Ignore promtheus shutdown errors
   275  		s.promServer.Shutdown(context.Background())
   276  	}
   277  	if err := s.Shutdown(context.Background()); err != nil {
   278  		return err
   279  	}
   280  	return err
   281  }
   282  
   283  // Ping implements the basic health check.
   284  func (s *ExtensionManagerServer) Ping(ctx context.Context) (*osquery.ExtensionStatus, error) {
   285  	return &osquery.ExtensionStatus{Code: 0, Message: "OK"}, nil
   286  }
   287  
   288  // Call routes a call from the osquery process to the appropriate registered
   289  // plugin.
   290  func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
   291  	subreg, ok := s.registry[registry]
   292  	if !ok {
   293  		return &osquery.ExtensionResponse{
   294  			Status: &osquery.ExtensionStatus{
   295  				Code:    1,
   296  				Message: "Unknown registry: " + registry,
   297  			},
   298  		}, nil
   299  	}
   300  
   301  	plugin, ok := subreg[item]
   302  	if !ok {
   303  		return &osquery.ExtensionResponse{
   304  			Status: &osquery.ExtensionStatus{
   305  				Code:    1,
   306  				Message: "Unknown registry item: " + item,
   307  			},
   308  		}, nil
   309  	}
   310  
   311  	if s.pluginCounter != nil {
   312  		s.pluginCounter.WithLabelValues(item, request["action"]).Inc()
   313  	}
   314  	if s.pluginTime != nil {
   315  		timer := prometheus.NewTimer(s.pluginTime.WithLabelValues(item, request["action"]))
   316  		defer timer.ObserveDuration()
   317  	}
   318  	response := plugin.Call(context.Background(), request)
   319  	if s.pluginGauge != nil {
   320  		s.pluginGauge.WithLabelValues(item, request["action"]).Set(float64(len(response.Response)))
   321  	}
   322  
   323  	return &response, nil
   324  }
   325  
   326  // Shutdown stops the server and closes the listening socket.
   327  func (s *ExtensionManagerServer) Shutdown(ctx context.Context) error {
   328  	s.mutex.Lock()
   329  	defer s.mutex.Unlock()
   330  	if s.server != nil {
   331  		server := s.server
   332  		s.server = nil
   333  		// Stop the server asynchronously so that the current request
   334  		// can complete. Otherwise, this is vulnerable to deadlock if a
   335  		// shutdown request is being processed when shutdown is
   336  		// explicitly called.
   337  		go func() {
   338  			server.Stop()
   339  		}()
   340  	}
   341  
   342  	return nil
   343  }
   344  
   345  // Useful for testing
   346  func (s *ExtensionManagerServer) waitStarted() {
   347  	for {
   348  		s.mutex.Lock()
   349  		started := s.started
   350  		s.mutex.Unlock()
   351  		if started {
   352  			time.Sleep(10 * time.Millisecond)
   353  			break
   354  		}
   355  	}
   356  }