open-cluster-management.io/governance-policy-propagator@v0.13.0/controllers/complianceeventsapi/server.go (about)

     1  package complianceeventsapi
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"database/sql"
     7  	"encoding/csv"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	stdlog "log"
    13  	"math"
    14  	"net"
    15  	"net/http"
    16  	"net/url"
    17  	"reflect"
    18  	"slices"
    19  	"sort"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/lib/pq"
    26  	"k8s.io/client-go/rest"
    27  )
    28  
    29  // init dynamically parses the database columns of each struct type to create a mapping of user provided sort/filter
    30  // options to the equivalent SQL column. ErrInvalidSortOption, ErrInvalidQueryArg, and validQueryArgs are also defined
    31  // with the available sort/query options to choose from.
    32  func init() {
    33  	tableToStruct := map[string]any{
    34  		"clusters":          Cluster{},
    35  		"compliance_events": EventDetails{},
    36  		"parent_policies":   ParentPolicy{},
    37  		"policies":          Policy{},
    38  	}
    39  
    40  	tableNameToJSONName := map[string]string{
    41  		"clusters":          "cluster",
    42  		"compliance_events": "event",
    43  		"parent_policies":   "parent_policy",
    44  		"policies":          "policy",
    45  	}
    46  
    47  	// ID is a special case since it's displayed at the top-level in the JSON but is actually in the compliance_events
    48  	// table.
    49  	queryOptionsToSQL = map[string]string{"id": "compliance_events.id"}
    50  	sortOptionsKeys := []string{"id"}
    51  
    52  	for tableName, tableStruct := range tableToStruct {
    53  		structType := reflect.TypeOf(tableStruct)
    54  		for i := 0; i < structType.NumField(); i++ {
    55  			structField := structType.Field(i)
    56  
    57  			jsonField := structField.Tag.Get("json")
    58  			if jsonField == "" || jsonField == "-" {
    59  				continue
    60  			}
    61  
    62  			// This removes additional text in tag such as capturing `spec` from `json:"spec,omitempty"`.
    63  			jsonField = strings.SplitN(jsonField, ",", 2)[0]
    64  
    65  			dbColumn := structField.Tag.Get("db")
    66  			if dbColumn == "" {
    67  				continue
    68  			}
    69  
    70  			// Skip JSONB columns as sortable options
    71  			if tableName == "policies" && dbColumn == "spec" {
    72  				continue
    73  			}
    74  
    75  			if tableName == "compliance_events" && dbColumn == "metadata" {
    76  				continue
    77  			}
    78  
    79  			queryOption := fmt.Sprintf("%s.%s", tableNameToJSONName[tableName], jsonField)
    80  
    81  			sortOptionsKeys = append(sortOptionsKeys, queryOption)
    82  			validQueryArgs = append(validQueryArgs, queryOption)
    83  
    84  			queryOptionsToSQL[queryOption] = fmt.Sprintf("%s.%s", tableName, dbColumn)
    85  		}
    86  	}
    87  
    88  	sort.Strings(sortOptionsKeys)
    89  
    90  	ErrInvalidSortOption = fmt.Errorf(
    91  		"an invalid sort option was provided, choose from: %s", strings.Join(sortOptionsKeys, ", "),
    92  	)
    93  
    94  	validQueryArgs = []string{
    95  		"direction",
    96  		"event.message_includes",
    97  		"event.message_like",
    98  		"event.timestamp_after",
    99  		"event.timestamp_before",
   100  		"include_spec",
   101  		"page",
   102  		"per_page",
   103  		"sort",
   104  	}
   105  
   106  	validQueryArgs = append(
   107  		validQueryArgs,
   108  		// sortOptionsKeys are all filterable columns.
   109  		sortOptionsKeys...,
   110  	)
   111  
   112  	sort.Strings(validQueryArgs)
   113  
   114  	ErrInvalidQueryArg = fmt.Errorf(
   115  		"an invalid query argument was provided, choose from: %s", strings.Join(validQueryArgs, ", "),
   116  	)
   117  }
   118  
   119  const (
   120  	postgresForeignKeyViolationCode = "23503"
   121  	postgresUniqueViolationCode     = "23505"
   122  )
   123  
   124  var (
   125  	clusterKeyCache         sync.Map
   126  	queryOptionsToSQL       map[string]string
   127  	validQueryArgs          []string
   128  	ErrInvalidSortOption    error
   129  	ErrInvalidQueryArgValue = errors.New("invalid query argument")
   130  	ErrInvalidQueryArg      error
   131  	ErrUnauthorized         = errors.New("not authorized")
   132  	ErrForbidden            = errors.New("the request is not allowed")
   133  	// The user has no access to any managed cluster
   134  	ErrNoAccess = errors.New("the user has no access")
   135  )
   136  
// ComplianceAPIServer serves the compliance events HTTP API. Create it with NewComplianceAPIServer and run it with
// Start.
type ComplianceAPIServer struct {
	server *http.Server     // created by Start
	addr   string           // address passed to net.Listen by Start
	cert   *tls.Certificate // when non-nil, Start serves TLS with this certificate
	cfg    *rest.Config     // base config passed to getUserKubeConfig for each authorized request
}
   143  
   144  func NewComplianceAPIServer(listenAddress string, cfg *rest.Config, cert *tls.Certificate) *ComplianceAPIServer {
   145  	return &ComplianceAPIServer{
   146  		addr: listenAddress,
   147  		cert: cert,
   148  		cfg:  cfg,
   149  	}
   150  }
   151  
   152  type serverErrorLogWriter struct{}
   153  
   154  func (*serverErrorLogWriter) Write(p []byte) (int, error) {
   155  	m := string(p)
   156  
   157  	// The OpenShift router (haproxy) seems to perform TCP checks to see if the connection is available. When it does
   158  	// this, it resets the connection when done, which causes a log message every 5 seconds, so this will filter it out.
   159  	if strings.HasPrefix(m, "http: TLS handshake error") && strings.HasSuffix(m, ": connection reset by peer\n") {
   160  		log.V(2).Info(m)
   161  	} else {
   162  		log.Info(m)
   163  	}
   164  
   165  	return len(p), nil
   166  }
   167  
   168  func newServerErrorLog() *stdlog.Logger {
   169  	return stdlog.New(&serverErrorLogWriter{}, "", 0)
   170  }
   171  
   172  // Start starts the HTTP server and blocks until ctx is closed or there was an error starting the
   173  // HTTP server.
   174  func (s *ComplianceAPIServer) Start(ctx context.Context, serverContext *ComplianceServerCtx) error {
   175  	mux := http.NewServeMux()
   176  
   177  	s.server = &http.Server{
   178  		Addr:    s.addr,
   179  		Handler: mux,
   180  
   181  		// need to investigate ideal values for these
   182  		ReadTimeout:  15 * time.Second,
   183  		WriteTimeout: 15 * time.Second,
   184  		IdleTimeout:  15 * time.Second,
   185  		ErrorLog:     newServerErrorLog(),
   186  	}
   187  
   188  	listener, err := net.Listen("tcp", s.addr)
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	if s.cert != nil {
   194  		s.server.TLSConfig = &tls.Config{
   195  			MinVersion:   tls.VersionTLS12,
   196  			Certificates: []tls.Certificate{*s.cert},
   197  		}
   198  
   199  		listener = tls.NewListener(listener, s.server.TLSConfig)
   200  	}
   201  
   202  	// register handlers here
   203  	mux.HandleFunc("/api/v1/compliance-events", func(w http.ResponseWriter, r *http.Request) {
   204  		w.Header().Set("Content-Type", "application/json")
   205  
   206  		serverContext.Lock.RLock()
   207  		defer serverContext.Lock.RUnlock()
   208  
   209  		if serverContext.DB == nil || serverContext.DB.PingContext(r.Context()) != nil {
   210  			writeErrMsgJSON(w, "The database is unavailable", http.StatusInternalServerError)
   211  
   212  			return
   213  		}
   214  
   215  		switch r.Method {
   216  		case http.MethodGet:
   217  			// To verify each request independently
   218  			userConfig, err := getUserKubeConfig(s.cfg, r)
   219  			if err != nil {
   220  				if errors.Is(err, ErrUnauthorized) {
   221  					writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized)
   222  				}
   223  
   224  				return
   225  			}
   226  			getComplianceEvents(serverContext.DB, w, r, userConfig)
   227  		case http.MethodPost:
   228  			postComplianceEvent(serverContext, s.cfg, w, r)
   229  		default:
   230  			writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed)
   231  		}
   232  	})
   233  
   234  	mux.HandleFunc("/api/v1/compliance-events/", func(w http.ResponseWriter, r *http.Request) {
   235  		w.Header().Set("Content-Type", "application/json")
   236  
   237  		serverContext.Lock.RLock()
   238  		defer serverContext.Lock.RUnlock()
   239  
   240  		if serverContext.DB == nil || serverContext.DB.PingContext(r.Context()) != nil {
   241  			writeErrMsgJSON(w, "The database is unavailable", http.StatusInternalServerError)
   242  
   243  			return
   244  		}
   245  
   246  		if r.Method != http.MethodGet {
   247  			writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed)
   248  
   249  			return
   250  		}
   251  
   252  		// To verify each request independently
   253  		userConfig, err := getUserKubeConfig(s.cfg, r)
   254  		if err != nil {
   255  			if errors.Is(err, ErrUnauthorized) {
   256  				writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized)
   257  			}
   258  
   259  			return
   260  		}
   261  
   262  		getSingleComplianceEvent(serverContext.DB, w, r, userConfig)
   263  	})
   264  
   265  	mux.HandleFunc("/api/v1/reports/compliance-events", func(w http.ResponseWriter, r *http.Request) {
   266  		// This header is for error writings
   267  		w.Header().Set("Content-Type", "application/json")
   268  
   269  		// To verify each request independently
   270  		userConfig, err := getUserKubeConfig(s.cfg, r)
   271  		if err != nil {
   272  			if errors.Is(err, ErrUnauthorized) {
   273  				writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized)
   274  			}
   275  
   276  			return
   277  		}
   278  
   279  		if r.Method != http.MethodGet {
   280  			writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed)
   281  
   282  			return
   283  		}
   284  
   285  		getComplianceEventsCSV(serverContext.DB, w, r, userConfig)
   286  	})
   287  
   288  	serveErr := make(chan error)
   289  
   290  	go func() {
   291  		defer close(serveErr)
   292  
   293  		err := s.server.Serve(listener)
   294  		if err != nil && err != http.ErrServerClosed {
   295  			serveErr <- err
   296  		}
   297  	}()
   298  
   299  	select {
   300  	case <-ctx.Done():
   301  		err := s.server.Shutdown(context.Background())
   302  		if err != nil {
   303  			log.Error(err, "Failed to shutdown the compliance API server")
   304  		}
   305  
   306  		return nil
   307  	case err, closed := <-serveErr:
   308  		if err != nil {
   309  			return err
   310  		}
   311  
   312  		if closed {
   313  			return errors.New("the compliance API server unexpectedly shutdown without an error")
   314  		}
   315  
   316  		return nil
   317  	}
   318  }
   319  
   320  // splitQueryValue will parse a string and split on unescaped commas. Empty values are discarded.
   321  func splitQueryValue(value string) []string {
   322  	values := []string{}
   323  
   324  	var currentVal string
   325  	var previousChar rune
   326  
   327  	for _, char := range value {
   328  		if char == ',' {
   329  			if previousChar == '\\' {
   330  				// This comma was escaped, so remove the escape character and keep the comma. Runes are used in case
   331  				// unicode characters are present in currentVal.
   332  				runeCurrentVal := []rune(currentVal)
   333  				currentVal = string(runeCurrentVal[:len(runeCurrentVal)-1]) + ","
   334  			} else {
   335  				// The comma was not escaped so we encountered a new value.
   336  				if currentVal != "" {
   337  					values = append(values, currentVal)
   338  				}
   339  
   340  				currentVal = ""
   341  			}
   342  		} else {
   343  			// A non-special character was encountered so just append the character.
   344  			currentVal += string(char)
   345  		}
   346  
   347  		previousChar = char
   348  	}
   349  
   350  	if currentVal != "" {
   351  		values = append(values, currentVal)
   352  	}
   353  
   354  	return values
   355  }
   356  
   357  // parseQueryArgs will parse the HTTP request's query arguments and convert them to a usable format for constructing
   358  // the SQL query. All defaults are set and any invalid query arguments result in an error being returned.
   359  func parseQueryArgs(ctx context.Context, queryArgs url.Values, db *sql.DB,
   360  	userConfig *rest.Config, isCSV bool,
   361  ) (*queryOptions, error) {
   362  	parsed := &queryOptions{
   363  		Direction:    "desc",
   364  		Page:         1,
   365  		PerPage:      20,
   366  		Sort:         []string{"compliance_events.timestamp"},
   367  		ArrayFilters: map[string][]string{},
   368  		Filters:      map[string][]string{},
   369  		NullFilters:  []string{},
   370  	}
   371  
   372  	// Case return CSV file, default PerPage is 0. Unlimited
   373  	if isCSV {
   374  		parsed.PerPage = 0
   375  	}
   376  
   377  	for arg := range queryArgs {
   378  		valid := false
   379  
   380  		for _, validQueryArg := range validQueryArgs {
   381  			if arg == validQueryArg {
   382  				valid = true
   383  
   384  				break
   385  			}
   386  		}
   387  
   388  		if !valid {
   389  			return nil, ErrInvalidQueryArg
   390  		}
   391  
   392  		sqlName, hasSQLName := queryOptionsToSQL[arg]
   393  
   394  		value := queryArgs.Get(arg)
   395  		if value == "" && arg != "include_spec" {
   396  			// Only support null filters if it's a SQL column
   397  			if !hasSQLName {
   398  				return nil, fmt.Errorf("%w: %s must have a value", ErrInvalidQueryArgValue, arg)
   399  			}
   400  
   401  			parsed.NullFilters = append(parsed.NullFilters, sqlName)
   402  
   403  			continue
   404  		}
   405  
   406  		switch arg {
   407  		case "direction":
   408  			if value == "desc" {
   409  				parsed.Direction = "DESC"
   410  			} else if value == "asc" {
   411  				parsed.Direction = "ASC"
   412  			} else {
   413  				return nil, fmt.Errorf("%w: direction must be one of: asc, desc", ErrInvalidQueryArg)
   414  			}
   415  		case "include_spec":
   416  			if value != "" {
   417  				return nil, fmt.Errorf("%w: include_spec is a flag and does not accept a value", ErrInvalidQueryArg)
   418  			}
   419  
   420  			parsed.IncludeSpec = true
   421  		case "page":
   422  			var err error
   423  
   424  			parsed.Page, err = strconv.ParseUint(value, 10, 64)
   425  			if err != nil || parsed.Page == 0 {
   426  				return nil, fmt.Errorf("%w: page must be a positive integer", ErrInvalidQueryArg)
   427  			}
   428  		case "per_page":
   429  			var err error
   430  
   431  			parsed.PerPage, err = strconv.ParseUint(value, 10, 64)
   432  			if err != nil || parsed.PerPage == 0 || parsed.PerPage > 100 {
   433  				return nil, fmt.Errorf("%w: per_page must be a value between 1 and 100", ErrInvalidQueryArg)
   434  			}
   435  		case "sort":
   436  			sortArgs := splitQueryValue(value)
   437  
   438  			sortSQL := []string{}
   439  
   440  			for _, sortArg := range sortArgs {
   441  				sortOption, ok := queryOptionsToSQL[sortArg]
   442  				if !ok {
   443  					return nil, ErrInvalidSortOption
   444  				}
   445  
   446  				sortSQL = append(sortSQL, sortOption)
   447  			}
   448  
   449  			parsed.Sort = sortSQL
   450  		case "parent_policy.categories", "parent_policy.controls", "parent_policy.standards":
   451  			parsed.ArrayFilters[sqlName] = splitQueryValue(value)
   452  		case "event.message_includes":
   453  			// Escape the SQL LIKE operators because we aren't exposing that functionality.
   454  			escapedVal := strings.ReplaceAll(value, "%", `\%`)
   455  			escapedVal = strings.ReplaceAll(escapedVal, "_", `\_`)
   456  			// Add wildcards at the beginning and end of the search keyword for substring matching.
   457  			parsed.MessageIncludes = "%" + escapedVal + "%"
   458  		case "event.message_like":
   459  			parsed.MessageLike = value
   460  		case "event.timestamp_before":
   461  			var err error
   462  
   463  			parsed.TimestampBefore, err = time.Parse(time.RFC3339, value)
   464  			if err != nil {
   465  				return nil, fmt.Errorf(
   466  					"%w: event.timestamp_before must be in the format of RFC 3339", ErrInvalidQueryArgValue,
   467  				)
   468  			}
   469  		case "event.timestamp_after":
   470  			var err error
   471  
   472  			parsed.TimestampAfter, err = time.Parse(time.RFC3339, value)
   473  			if err != nil {
   474  				return nil, fmt.Errorf(
   475  					"%w: event.timestamp_after must be in the format of RFC 3339", ErrInvalidQueryArgValue,
   476  				)
   477  			}
   478  		default:
   479  			// Standard string filtering
   480  			parsed.Filters[sqlName] = splitQueryValue(value)
   481  		}
   482  	}
   483  
   484  	parsed, err := setAuthorizedClusters(ctx, db, parsed, userConfig)
   485  	if err != nil {
   486  		// ErrNoAccess needs queryOptions
   487  		return parsed, err
   488  	}
   489  
   490  	return parsed, nil
   491  }
   492  
   493  // setAuthorizedClusters verifies that if a cluster filter is provided,
   494  // the user has access to this filter. If no cluster filter is provided,
   495  // it sets the cluster filter to all managed clusters the user has access to.
   496  // If the user has no access, then ErrNoAccess is returned.
   497  func setAuthorizedClusters(ctx context.Context, db *sql.DB, parsed *queryOptions,
   498  	userConfig *rest.Config,
   499  ) (*queryOptions, error) {
   500  	unAuthorizedClusters := []string{}
   501  
   502  	// Get all managedCluster rules
   503  	allRules, err := getManagedClusterRules(userConfig, nil)
   504  	if err != nil {
   505  		return parsed, err
   506  	}
   507  
   508  	if slices.Contains(allRules["*"], "get") || slices.Contains(allRules["*"], "*") {
   509  		return parsed, nil
   510  	}
   511  
   512  	clusterIDs := parsed.Filters["clusters.cluster_id"]
   513  	// Temporarily reset clusters.cluster_id and repopulate with all known cluster IDs
   514  	parsed.Filters["clusters.cluster_id"] = []string{}
   515  
   516  	// Convert id to name
   517  	for _, id := range clusterIDs {
   518  		clusterName, err := getClusterNameFromID(ctx, db, id)
   519  		if err != nil {
   520  			if errors.Is(err, sql.ErrNoRows) {
   521  				// Filter out invalid cluster IDs from the query
   522  				continue
   523  			}
   524  
   525  			log.Error(err, "Failed to get cluster name from cluster ID", getPqErrKeyVals(err, "ID", id)...)
   526  
   527  			return parsed, err
   528  		}
   529  
   530  		if !getAccessByClusterName(allRules, clusterName) {
   531  			unAuthorizedClusters = append(unAuthorizedClusters, id)
   532  		} else {
   533  			parsed.Filters["clusters.cluster_id"] = append(parsed.Filters["clusters.cluster_id"], id)
   534  		}
   535  	}
   536  
   537  	parsedClusterNames := parsed.Filters["clusters.name"]
   538  	for _, clusterName := range parsedClusterNames {
   539  		if !getAccessByClusterName(allRules, clusterName) {
   540  			unAuthorizedClusters = append(unAuthorizedClusters, clusterName)
   541  		}
   542  	}
   543  
   544  	// There is no cluster.cluster_id or cluster.name query argument.
   545  	// In other words, the user requests all they have access to.
   546  	if len(clusterIDs) == 0 && len(parsedClusterNames) == 0 {
   547  		for mcName := range allRules {
   548  			// Add the cluster to the filter if the user has get authentication. Note that if the user has get access
   549  			// on all managed clusters, that gets handled at the beginning of the function.
   550  			if getAccessByClusterName(allRules, mcName) {
   551  				parsed.Filters["clusters.name"] = append(parsed.Filters["clusters.name"], mcName)
   552  			}
   553  		}
   554  	}
   555  
   556  	if len(unAuthorizedClusters) > 0 {
   557  		return parsed, fmt.Errorf("%w: the following cluster filters are not authorized: %s",
   558  			ErrForbidden, strings.Join(unAuthorizedClusters, ", "))
   559  	}
   560  
   561  	if len(parsed.Filters["clusters.name"]) == 0 && len(parsed.Filters["clusters.cluster_id"]) == 0 {
   562  		return parsed, ErrNoAccess
   563  	}
   564  
   565  	return parsed, nil
   566  }
   567  
   568  // generateGetComplianceEventsQuery will return a SELECT query with results ready to be parsed by
   569  // scanIntoComplianceEvent. The caller is responsible for adding filters to the query.
   570  func generateGetComplianceEventsQuery(includeSpec bool) string {
   571  	return fmt.Sprintf(`SELECT %s
   572  FROM
   573    compliance_events
   574    LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
   575    LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
   576    LEFT JOIN policies ON compliance_events.policy_id = policies.id`,
   577  		strings.Join(generateSelectedArgs(includeSpec), ", "),
   578  	)
   579  }
   580  
   581  func generateSelectedArgs(includeSpec bool) []string {
   582  	selectArgs := []string{
   583  		"compliance_events.id",
   584  		"compliance_events.compliance",
   585  		"compliance_events.message",
   586  		"compliance_events.metadata",
   587  		"compliance_events.reported_by",
   588  		"compliance_events.timestamp",
   589  		"clusters.cluster_id",
   590  		"clusters.name",
   591  		"parent_policies.id",
   592  		"parent_policies.name",
   593  		"parent_policies.namespace",
   594  		"parent_policies.categories",
   595  		"parent_policies.controls",
   596  		"parent_policies.standards",
   597  		"policies.id",
   598  		"policies.api_group",
   599  		"policies.kind",
   600  		"policies.name",
   601  		"policies.namespace",
   602  		"policies.severity",
   603  	}
   604  
   605  	if includeSpec {
   606  		selectArgs = append(selectArgs, "policies.spec")
   607  	}
   608  
   609  	return selectArgs
   610  }
   611  
   612  // generate Headers for CSV. "." replace by "_"
   613  // Example: parent_policies.namespace -> parent_policies_namespace
   614  func getCsvHeader(includeSpec bool) []string {
   615  	localSelectArgs := generateSelectedArgs(includeSpec)
   616  
   617  	for i, arg := range localSelectArgs {
   618  		localSelectArgs[i] = strings.ReplaceAll(arg, ".", "_")
   619  	}
   620  
   621  	return localSelectArgs
   622  }
   623  
// Scannable is the subset of *sql.Row and *sql.Rows used by scanIntoComplianceEvent, so that it can read from either
// a single-row or a multi-row query result.
type Scannable interface {
	Scan(dest ...any) error
}
   627  
   628  // scanIntoComplianceEvent will scan the row result from the SELECT query generated by generateGetComplianceEventsQuery
   629  // into a ComplianceEvent object.
   630  func scanIntoComplianceEvent(rows Scannable, includeSpec bool) (*ComplianceEvent, error) {
   631  	ce := ComplianceEvent{
   632  		Cluster:      Cluster{},
   633  		Event:        EventDetails{},
   634  		ParentPolicy: nil,
   635  		Policy:       Policy{},
   636  	}
   637  
   638  	ppID := sql.NullInt32{}
   639  	ppName := sql.NullString{}
   640  	ppNamespace := sql.NullString{}
   641  	ppCategories := pq.StringArray{}
   642  	ppControls := pq.StringArray{}
   643  	ppStandards := pq.StringArray{}
   644  
   645  	scanArgs := []any{
   646  		&ce.EventID,
   647  		&ce.Event.Compliance,
   648  		&ce.Event.Message,
   649  		&ce.Event.Metadata,
   650  		&ce.Event.ReportedBy,
   651  		&ce.Event.Timestamp,
   652  		&ce.Cluster.ClusterID,
   653  		&ce.Cluster.Name,
   654  		&ppID,
   655  		&ppName,
   656  		&ppNamespace,
   657  		&ppCategories,
   658  		&ppControls,
   659  		&ppStandards,
   660  		&ce.Policy.KeyID,
   661  		&ce.Policy.APIGroup,
   662  		&ce.Policy.Kind,
   663  		&ce.Policy.Name,
   664  		&ce.Policy.Namespace,
   665  		&ce.Policy.Severity,
   666  	}
   667  
   668  	if includeSpec {
   669  		scanArgs = append(scanArgs, &ce.Policy.Spec)
   670  	}
   671  
   672  	err := rows.Scan(scanArgs...)
   673  	if err != nil {
   674  		return nil, err
   675  	}
   676  
   677  	// The parent policy is optional but when it's set, the name is guaranteed to be set.
   678  	if ppName.String != "" {
   679  		ce.ParentPolicy = &ParentPolicy{
   680  			KeyID:      ppID.Int32,
   681  			Name:       ppName.String,
   682  			Namespace:  ppNamespace.String,
   683  			Categories: ppCategories,
   684  			Controls:   ppControls,
   685  			Standards:  ppStandards,
   686  		}
   687  	}
   688  
   689  	return &ce, nil
   690  }
   691  
   692  // getSingleComplianceEvent handles the GET API endpoint for a single compliance event by ID.
   693  func getSingleComplianceEvent(db *sql.DB, w http.ResponseWriter,
   694  	r *http.Request, config *rest.Config,
   695  ) {
   696  	eventIDStr := strings.TrimPrefix(r.URL.Path, "/api/v1/compliance-events/")
   697  
   698  	eventID, err := strconv.ParseUint(eventIDStr, 10, 64)
   699  	if err != nil {
   700  		writeErrMsgJSON(w, "The provided compliance event ID is invalid", http.StatusBadRequest)
   701  
   702  		return
   703  	}
   704  
   705  	query := fmt.Sprintf("%s\nWHERE compliance_events.id = $1;", generateGetComplianceEventsQuery(true))
   706  
   707  	row := db.QueryRowContext(r.Context(), query, eventID)
   708  	if row.Err() != nil {
   709  		log.Error(row.Err(), "Failed to query for the compliance event", "eventID", eventID)
   710  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   711  
   712  		return
   713  	}
   714  
   715  	complianceEvent, err := scanIntoComplianceEvent(row, true)
   716  	if err != nil {
   717  		if errors.Is(err, sql.ErrNoRows) {
   718  			writeErrMsgJSON(w, "The requested compliance event was not found", http.StatusNotFound)
   719  
   720  			return
   721  		}
   722  
   723  		log.Error(err, "Failed to unmarshal the database results", getPqErrKeyVals(err)...)
   724  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   725  
   726  		return
   727  	}
   728  
   729  	// Check auth for managedCluster GET verb
   730  	isAllowed, err := canGetManagedCluster(config, complianceEvent.Cluster.Name)
   731  	if err != nil {
   732  		log.Error(err, `Failed to get the "get" authorization for the cluster`,
   733  			"cluster", complianceEvent.Cluster.Name)
   734  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   735  
   736  		return
   737  	}
   738  
   739  	if !isAllowed {
   740  		writeErrMsgJSON(w, "Forbidden", http.StatusForbidden)
   741  
   742  		return
   743  	}
   744  
   745  	jsonResp, err := json.Marshal(complianceEvent)
   746  	if err != nil {
   747  		log.Error(err, "Failed marshal the compliance event", "eventID", eventID)
   748  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   749  
   750  		return
   751  	}
   752  
   753  	if _, err = w.Write(jsonResp); err != nil {
   754  		log.Error(err, "Error writing success response")
   755  	}
   756  }
   757  
   758  // getPqErrKeyVals is a helper to add additional database error details to a log message. additionalKeyVals is provided
   759  // as a convenience so that the keys don't need to be explicitly set to interface{} types when using the
   760  // `getPqErrKeyVals(err, "key1", "val1")...“ syntax.
   761  func getPqErrKeyVals(err error, additionalKeyVals ...interface{}) []interface{} {
   762  	var pqErr *pq.Error
   763  
   764  	if errors.As(err, &pqErr) {
   765  		return append(
   766  			[]interface{}{"dbMessage", pqErr.Message, "dbDetail", pqErr.Detail, "dbCode", pqErr.Code},
   767  			additionalKeyVals...,
   768  		)
   769  	}
   770  
   771  	return additionalKeyVals
   772  }
   773  
   774  func getClusterNameFromID(ctx context.Context, db *sql.DB, clusterID string) (name string, err error) {
   775  	err = db.QueryRowContext(ctx,
   776  		`SELECT name FROM clusters WHERE cluster_id = $1`, clusterID,
   777  	).Scan(&name)
   778  	if err != nil {
   779  		return "", err
   780  	}
   781  
   782  	return name, nil
   783  }
   784  
   785  // getWhereClause will convert the input queryOptions to a WHERE statement and return the filter values for a prepared
   786  // statement.
   787  func getWhereClause(options *queryOptions) (string, []any) {
   788  	filterSQL := []string{}
   789  	filterValues := []any{}
   790  
   791  	for sqlColumn, values := range options.Filters {
   792  		if len(values) == 0 {
   793  			continue
   794  		}
   795  
   796  		for i, value := range values {
   797  			filterValues = append(filterValues, value)
   798  			// For example: compliance_events.name=$1
   799  			filter := fmt.Sprintf("%s=$%d", sqlColumn, len(filterValues))
   800  			if i == 0 {
   801  				filterSQL = append(filterSQL, "("+filter)
   802  			} else {
   803  				filterSQL[len(filterSQL)-1] += " OR " + filter
   804  			}
   805  		}
   806  
   807  		filterSQL[len(filterSQL)-1] += ")"
   808  	}
   809  
   810  	for sqlColumn, values := range options.ArrayFilters {
   811  		if len(values) == 0 {
   812  			continue
   813  		}
   814  
   815  		for i, value := range values {
   816  			filterValues = append(filterValues, value)
   817  
   818  			// For example: $1=ANY(parent_policies.categories)
   819  			filter := fmt.Sprintf("$%d=ANY(%s)", len(filterValues), sqlColumn)
   820  			if i == 0 {
   821  				filterSQL = append(filterSQL, "("+filter)
   822  			} else {
   823  				filterSQL[len(filterSQL)-1] += " OR " + filter
   824  			}
   825  		}
   826  
   827  		filterSQL[len(filterSQL)-1] += ")"
   828  	}
   829  
   830  	for _, sqlColumn := range options.NullFilters {
   831  		filterSQL = append(filterSQL, fmt.Sprintf("%s IS NULL", sqlColumn))
   832  	}
   833  
   834  	if options.MessageIncludes != "" {
   835  		filterValues = append(filterValues, options.MessageIncludes)
   836  
   837  		filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.message LIKE $%d", len(filterValues)))
   838  	}
   839  
   840  	if options.MessageLike != "" {
   841  		filterValues = append(filterValues, options.MessageLike)
   842  
   843  		filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.message LIKE $%d", len(filterValues)))
   844  	}
   845  
   846  	if !options.TimestampAfter.IsZero() {
   847  		filterValues = append(filterValues, options.TimestampAfter)
   848  
   849  		filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.timestamp > $%d", len(filterValues)))
   850  	}
   851  
   852  	if !options.TimestampBefore.IsZero() {
   853  		filterValues = append(filterValues, options.TimestampBefore)
   854  
   855  		filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.timestamp < $%d", len(filterValues)))
   856  	}
   857  
   858  	var whereClause string
   859  
   860  	if len(filterSQL) > 0 {
   861  		// For example:
   862  		// WHERE (policy.name=$1) AND ($2=ANY(parent_policies.categories) OR $3=ANY(parent_policies.categories))
   863  		whereClause = "\nWHERE " + strings.Join(filterSQL, " AND ")
   864  	}
   865  
   866  	return whereClause, filterValues
   867  }
   868  
   869  // getComplianceEvents handles the list API endpoint for compliance events.
   870  func getComplianceEvents(db *sql.DB, w http.ResponseWriter,
   871  	r *http.Request, userConfig *rest.Config,
   872  ) {
   873  	queryArgs, err := parseQueryArgs(r.Context(), r.URL.Query(), db, userConfig, false)
   874  	if err != nil {
   875  		if errors.Is(err, ErrForbidden) {
   876  			writeErrMsgJSON(w, err.Error(), http.StatusForbidden)
   877  
   878  			return
   879  		}
   880  
   881  		if errors.Is(err, ErrNoAccess) {
   882  			response := ListResponse{
   883  				Data: []ComplianceEvent{},
   884  				Metadata: metadata{
   885  					Page:    queryArgs.Page,
   886  					Pages:   0,
   887  					PerPage: queryArgs.PerPage,
   888  					Total:   0,
   889  				},
   890  			}
   891  
   892  			jsonResp, err := json.Marshal(response)
   893  			if err != nil {
   894  				log.Error(err, "Failed to marshal an empty response")
   895  				writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   896  
   897  				return
   898  			}
   899  
   900  			if _, err = w.Write(jsonResp); err != nil {
   901  				log.Error(err, "Error writing empty response")
   902  				writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   903  			}
   904  
   905  			return
   906  		}
   907  
   908  		if errors.Is(err, ErrInvalidQueryArg) || errors.Is(err, ErrInvalidQueryArgValue) ||
   909  			errors.Is(err, ErrInvalidSortOption) {
   910  			writeErrMsgJSON(w, err.Error(), http.StatusBadRequest)
   911  
   912  			return
   913  		}
   914  
   915  		writeErrMsgJSON(w, err.Error(), http.StatusInternalServerError)
   916  
   917  		return
   918  	}
   919  
   920  	// Note that the where clause could be an empty string if not filters were passed in the query arguments.
   921  	whereClause, filterValues := getWhereClause(queryArgs)
   922  
   923  	query := getComplianceEventsQuery(whereClause, queryArgs)
   924  
   925  	rows, err := db.QueryContext(r.Context(), query, filterValues...)
   926  	if err == nil {
   927  		err = rows.Err()
   928  	}
   929  
   930  	if err != nil {
   931  		log.Error(err, "Failed to query for compliance events")
   932  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   933  
   934  		return
   935  	}
   936  
   937  	defer rows.Close()
   938  
   939  	complianceEvents := make([]ComplianceEvent, 0, queryArgs.PerPage)
   940  
   941  	for rows.Next() {
   942  		ce, err := scanIntoComplianceEvent(rows, queryArgs.IncludeSpec)
   943  		if err != nil {
   944  			log.Error(err, "Failed to unmarshal the database results")
   945  			writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   946  
   947  			return
   948  		}
   949  
   950  		complianceEvents = append(complianceEvents, *ce)
   951  	}
   952  
   953  	countQuery := `SELECT COUNT(*) FROM compliance_events
   954  LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
   955  LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
   956  LEFT JOIN policies ON compliance_events.policy_id = policies.id` + whereClause // #nosec G202
   957  
   958  	row := db.QueryRowContext(r.Context(), countQuery, filterValues...)
   959  
   960  	var total uint64
   961  
   962  	if err := row.Scan(&total); err != nil {
   963  		log.Error(err, "Failed to get the count of compliance events", getPqErrKeyVals(err)...)
   964  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   965  
   966  		return
   967  	}
   968  
   969  	pages := math.Ceil(float64(total) / float64(queryArgs.PerPage))
   970  
   971  	response := ListResponse{
   972  		Data: complianceEvents,
   973  		Metadata: metadata{
   974  			Page:    queryArgs.Page,
   975  			Pages:   uint64(pages),
   976  			PerPage: queryArgs.PerPage,
   977  			Total:   total,
   978  		},
   979  	}
   980  
   981  	jsonResp, err := json.Marshal(response)
   982  	if err != nil {
   983  		log.Error(err, "Failed to marshal the response")
   984  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   985  
   986  		return
   987  	}
   988  
   989  	if _, err = w.Write(jsonResp); err != nil {
   990  		log.Error(err, "Error writing success response")
   991  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
   992  
   993  		return
   994  	}
   995  }
   996  
// postComplianceEvent handles the API endpoint for recording a single compliance event. It reads the JSON request
// body into a ComplianceEvent and validates it. It then checks with canRecordComplianceEvent that the requester may
// record compliance events for the event's cluster, resolves the cluster, parent policy (if any), and policy
// database IDs, and inserts the event.
//
// Responses:
//   - 201 with the recorded event as JSON (with the policy spec removed) on success
//   - 400 if the body can't be read, isn't valid JSON, or fails Validate
//   - 401 if canRecordComplianceEvent returns ErrUnauthorized; 403 if it reports the request as not allowed
//   - 409 if Create returns errDuplicateComplianceEvent
//   - 500 for any other failure
//
// postComplianceEvent assumes you have a read lock already attained on serverContext.Lock. If the insert hits a
// foreign key violation, the read lock is briefly exchanged for the write lock to clear the ID caches. The read lock
// is reacquired before returning, so the caller still holds a read lock afterwards.
//
// NOTE(review): the request body is read with io.ReadAll and no size limit; confirm a limit is enforced elsewhere
// (e.g. http.MaxBytesReader), since the body comes from clients.
func postComplianceEvent(serverContext *ComplianceServerCtx, cfg *rest.Config, w http.ResponseWriter, r *http.Request) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		log.Error(err, "error reading request body")
		writeErrMsgJSON(w, "Could not read request body", http.StatusBadRequest)

		return
	}

	reqEvent := &ComplianceEvent{}

	if err := json.Unmarshal(body, reqEvent); err != nil {
		writeErrMsgJSON(w, "Incorrectly formatted request body, must be valid JSON", http.StatusBadRequest)

		return
	}

	if err := reqEvent.Validate(r.Context(), serverContext); err != nil {
		writeErrMsgJSON(w, err.Error(), http.StatusBadRequest)

		return
	}

	// The authorization check is scoped to the cluster name taken from the request body.
	allowed, err := canRecordComplianceEvent(cfg, reqEvent.Cluster.Name, r)
	if err != nil {
		if errors.Is(err, ErrUnauthorized) {
			writeErrMsgJSON(w, "Unauthorized", http.StatusUnauthorized)

			return
		}

		log.Error(err, "error determining if the user is authorized for recording compliance events")
		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

		return
	}

	if !allowed {
		// Logging is handled by canRecordComplianceEvent
		writeErrMsgJSON(w, "Forbidden", http.StatusForbidden)

		return
	}

	// Resolve the database IDs that the compliance event row references.
	clusterFK, err := GetClusterForeignKey(r.Context(), serverContext.DB, reqEvent.Cluster)
	if err != nil {
		log.Error(err, "error getting cluster foreign key", getPqErrKeyVals(err)...)
		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

		return
	}

	reqEvent.Event.ClusterID = clusterFK

	// The parent policy is optional; ParentPolicyID stays nil when it isn't provided.
	if reqEvent.ParentPolicy != nil {
		pfk, err := getParentPolicyForeignKey(r.Context(), serverContext, *reqEvent.ParentPolicy)
		if err != nil {
			log.Error(err, "error getting parent policy foreign key", getPqErrKeyVals(err)...)
			writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

			return
		}

		reqEvent.Event.ParentPolicyID = &pfk
	}

	policyFK, err := getPolicyForeignKey(r.Context(), serverContext, reqEvent.Policy)
	if err != nil {
		log.Error(err, "error getting policy foreign key", getPqErrKeyVals(err)...)
		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

		return
	}

	reqEvent.Event.PolicyID = policyFK

	err = reqEvent.Create(r.Context(), serverContext.DB)
	if err != nil {
		if errors.Is(err, errDuplicateComplianceEvent) {
			writeErrMsgJSON(w, "The compliance event already exists", http.StatusConflict)

			return
		}

		var pqErr *pq.Error

		if errors.As(err, &pqErr) && pqErr.Code == postgresForeignKeyViolationCode {
			// This can only happen if the cache is out of date due to data loss in the database because if the
			// database ID is provided, it is validated against the database.
			log.Info(
				"Encountered a foreign key violation. Assuming the database lost data, so the cache is "+
					"being cleared",
				"message", pqErr.Message,
				"detail", pqErr.Detail,
			)

			// Temporarily upgrade the lock to a write lock. sync.RWMutex can't be upgraded atomically, so the read
			// lock is released first and other goroutines may run before the write lock is acquired. The read lock
			// is taken again afterwards because the caller expects to still hold it.
			serverContext.Lock.RUnlock()
			serverContext.Lock.Lock()
			serverContext.ParentPolicyToID = sync.Map{}
			serverContext.PolicyToID = sync.Map{}
			// clusterKeyCache is package-level rather than part of serverContext, but is cleared for the same reason.
			clusterKeyCache = sync.Map{}
			serverContext.Lock.Unlock()
			serverContext.Lock.RLock()
		} else {
			log.Error(err, "error inserting compliance event", getPqErrKeyVals(err)...)
		}

		// Both cases respond with a generic message so database details aren't exposed to the client.
		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

		return
	}

	// remove the spec so it's not returned in the JSON.
	reqEvent.Policy.Spec = nil

	resp, err := json.Marshal(reqEvent)
	if err != nil {
		log.Error(err, "error marshaling reqEvent for the response")
		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

		return
	}

	w.WriteHeader(http.StatusCreated)

	// The 201 status has already been sent, so a write failure can only be logged.
	if _, err = w.Write(resp); err != nil {
		log.Error(err, "error writing success response")
	}
}
  1128  
  1129  func getComplianceEventsQuery(whereClause string, queryArgs *queryOptions) string {
  1130  	// Getting CSV without the page argument
  1131  	// Query should fetch all rows (unlimited)
  1132  	if queryArgs.PerPage == 0 {
  1133  		return fmt.Sprintf(`%s%s
  1134  		ORDER BY %s %s;`,
  1135  			generateGetComplianceEventsQuery(queryArgs.IncludeSpec),
  1136  			whereClause,
  1137  			strings.Join(queryArgs.Sort, ", "),
  1138  			queryArgs.Direction,
  1139  		)
  1140  	}
  1141  	// Example query
  1142  	//   SELECT compliance_events.id, compliance_events.compliance, ...
  1143  	//     FROM compliance_events
  1144  	//   LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
  1145  	//   LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
  1146  	//   LEFT JOIN policies ON compliance_events.policy_id = policies.id
  1147  	//   WHERE (policies.name=$1 OR policies.name=$2) AND (policies.kind=$3)
  1148  	//   ORDER BY compliance_events.timestamp desc
  1149  	//   LIMIT 20
  1150  	//   OFFSET 0 ROWS;
  1151  	return fmt.Sprintf(`%s%s
  1152  	ORDER BY %s %s
  1153  	LIMIT %d
  1154  	OFFSET %d ROWS;`,
  1155  		generateGetComplianceEventsQuery(queryArgs.IncludeSpec),
  1156  		whereClause,
  1157  		strings.Join(queryArgs.Sort, ", "),
  1158  		queryArgs.Direction,
  1159  		queryArgs.PerPage,
  1160  		(queryArgs.Page-1)*queryArgs.PerPage,
  1161  	)
  1162  }
  1163  
  1164  func setCSVResponseHeaders(w http.ResponseWriter) {
  1165  	w.Header().Set("Content-Disposition", "attachment; filename=reports.csv")
  1166  	w.Header().Set("Content-Type", "text/csv")
  1167  	// It's going to be divided into chunks. if the user don't get it all at once,
  1168  	// the user can receive one by one in the meantime
  1169  	w.Header().Set("Transfer-Encoding", "chunked")
  1170  }
  1171  
  1172  func getComplianceEventsCSV(db *sql.DB, w http.ResponseWriter, r *http.Request,
  1173  	userConfig *rest.Config,
  1174  ) {
  1175  	var writer *csv.Writer
  1176  
  1177  	queryArgs, queryArgsErr := parseQueryArgs(r.Context(), r.URL.Query(), db, userConfig, true)
  1178  	if queryArgs != nil {
  1179  		headers := getCsvHeader(queryArgs.IncludeSpec)
  1180  
  1181  		writer = csv.NewWriter(w)
  1182  
  1183  		err := writer.Write(headers)
  1184  		if err != nil {
  1185  			log.Error(err, "Failed to write csv header")
  1186  			writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
  1187  
  1188  			return
  1189  		}
  1190  	}
  1191  
  1192  	if queryArgsErr != nil {
  1193  		if errors.Is(queryArgsErr, ErrNoAccess) {
  1194  			setCSVResponseHeaders(w)
  1195  
  1196  			writer.Flush()
  1197  
  1198  			return
  1199  		}
  1200  
  1201  		if errors.Is(queryArgsErr, ErrInvalidQueryArg) || errors.Is(queryArgsErr, ErrInvalidQueryArgValue) ||
  1202  			errors.Is(queryArgsErr, ErrInvalidSortOption) {
  1203  			writeErrMsgJSON(w, queryArgsErr.Error(), http.StatusBadRequest)
  1204  
  1205  			return
  1206  		}
  1207  
  1208  		writeErrMsgJSON(w, queryArgsErr.Error(), http.StatusInternalServerError)
  1209  
  1210  		return
  1211  	}
  1212  
  1213  	// Note that the where clause could be an empty string if no filters were passed in the query arguments.
  1214  	whereClause, filterValues := getWhereClause(queryArgs)
  1215  
  1216  	query := getComplianceEventsQuery(whereClause, queryArgs)
  1217  
  1218  	rows, err := db.QueryContext(r.Context(), query, filterValues...)
  1219  	if err == nil {
  1220  		err = rows.Err()
  1221  	}
  1222  
  1223  	if err != nil {
  1224  		log.Error(err, "Failed to query for compliance events")
  1225  		writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
  1226  
  1227  		return
  1228  	}
  1229  
  1230  	defer rows.Close()
  1231  
  1232  	for rows.Next() {
  1233  		ce, err := scanIntoComplianceEvent(rows, queryArgs.IncludeSpec)
  1234  		if err != nil {
  1235  			log.Error(err, "Failed to unmarshal the database results")
  1236  			writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
  1237  
  1238  			return
  1239  		}
  1240  
  1241  		stringValues := convertToCsvLine(ce, queryArgs.IncludeSpec)
  1242  
  1243  		err = writer.Write(stringValues)
  1244  		if err != nil {
  1245  			log.Error(err, "Failed to write csv list")
  1246  			writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)
  1247  
  1248  			return
  1249  		}
  1250  	}
  1251  
  1252  	setCSVResponseHeaders(w)
  1253  
  1254  	writer.Flush()
  1255  }
  1256  
  1257  func convertToCsvLine(ce *ComplianceEvent, includeSpec bool) []string {
  1258  	nilString := ""
  1259  
  1260  	if ce.ParentPolicy == nil {
  1261  		ce.ParentPolicy = &ParentPolicy{
  1262  			KeyID:      0,
  1263  			Name:       "",
  1264  			Namespace:  "",
  1265  			Categories: nil,
  1266  			Controls:   nil,
  1267  			Standards:  nil,
  1268  		}
  1269  	}
  1270  
  1271  	if ce.Event.ReportedBy == nil {
  1272  		ce.Event.ReportedBy = &nilString
  1273  	}
  1274  
  1275  	if ce.Policy.Severity == nil {
  1276  		ce.Policy.Severity = &nilString
  1277  	}
  1278  
  1279  	if ce.Policy.Namespace == nil {
  1280  		ce.Policy.Namespace = &nilString
  1281  	}
  1282  
  1283  	values := []string{
  1284  		convertToString(ce.EventID),
  1285  		convertToString(ce.Event.Compliance),
  1286  		convertToString(ce.Event.Message),
  1287  		convertToString(ce.Event.Metadata),
  1288  		convertToString(*ce.Event.ReportedBy),
  1289  		convertToString(ce.Event.Timestamp),
  1290  		convertToString(ce.Cluster.ClusterID),
  1291  		convertToString(ce.Cluster.Name),
  1292  		convertToString(ce.ParentPolicy.KeyID),
  1293  		convertToString(ce.ParentPolicy.Name),
  1294  		convertToString(ce.ParentPolicy.Namespace),
  1295  		convertToString(ce.ParentPolicy.Categories),
  1296  		convertToString(ce.ParentPolicy.Controls),
  1297  		convertToString(ce.ParentPolicy.Standards),
  1298  		convertToString(ce.Policy.KeyID),
  1299  		convertToString(ce.Policy.APIGroup),
  1300  		convertToString(ce.Policy.Kind),
  1301  		convertToString(ce.Policy.Name),
  1302  		convertToString(*ce.Policy.Namespace),
  1303  		convertToString(*ce.Policy.Severity),
  1304  	}
  1305  
  1306  	if includeSpec {
  1307  		values = append(values, convertToString(ce.Policy.Spec))
  1308  	}
  1309  
  1310  	return values
  1311  }
  1312  
  1313  func convertToString(v interface{}) string {
  1314  	switch vv := v.(type) {
  1315  	case *string:
  1316  		if vv == nil {
  1317  			return ""
  1318  		}
  1319  
  1320  		return *vv
  1321  	case string:
  1322  		return vv
  1323  	case int32:
  1324  		// All int32 related id
  1325  		if int(vv) == 0 {
  1326  			return ""
  1327  		}
  1328  
  1329  		return strconv.Itoa(int(vv))
  1330  	case time.Time:
  1331  		return vv.String()
  1332  	case pq.StringArray:
  1333  		// nil will be []
  1334  		return strings.Join(vv, ", ")
  1335  	case bool:
  1336  		return strconv.FormatBool(vv)
  1337  	case JSONMap:
  1338  		if vv == nil {
  1339  			return ""
  1340  		}
  1341  
  1342  		jsonByte, err := json.MarshalIndent(vv, "", "  ")
  1343  		if err != nil {
  1344  			return ""
  1345  		}
  1346  
  1347  		return string(jsonByte)
  1348  	default:
  1349  		// case nil:
  1350  		return fmt.Sprintf("%v", vv)
  1351  	}
  1352  }
  1353  
  1354  // GetClusterForeignKey will return the database ID based on the cluster.ClusterID.
  1355  func GetClusterForeignKey(ctx context.Context, db *sql.DB, cluster Cluster) (int32, error) {
  1356  	// Check cache
  1357  	key, ok := clusterKeyCache.Load(cluster.ClusterID)
  1358  	if ok {
  1359  		return key.(int32), nil
  1360  	}
  1361  
  1362  	err := cluster.GetOrCreate(ctx, db)
  1363  	if err != nil {
  1364  		return 0, err
  1365  	}
  1366  
  1367  	clusterKeyCache.Store(cluster.ClusterID, cluster.KeyID)
  1368  
  1369  	return cluster.KeyID, nil
  1370  }
  1371  
  1372  func getParentPolicyForeignKey(
  1373  	ctx context.Context, complianceServerCtx *ComplianceServerCtx, parent ParentPolicy,
  1374  ) (int32, error) {
  1375  	if parent.KeyID != 0 {
  1376  		return parent.KeyID, nil
  1377  	}
  1378  
  1379  	// Check cache
  1380  	parKey := parent.Key()
  1381  
  1382  	key, ok := complianceServerCtx.ParentPolicyToID.Load(parKey)
  1383  	if ok {
  1384  		return key.(int32), nil
  1385  	}
  1386  
  1387  	err := parent.GetOrCreate(ctx, complianceServerCtx.DB)
  1388  	if err != nil {
  1389  		return 0, err
  1390  	}
  1391  
  1392  	complianceServerCtx.ParentPolicyToID.Store(parKey, parent.KeyID)
  1393  
  1394  	return parent.KeyID, nil
  1395  }
  1396  
  1397  func getPolicyForeignKey(ctx context.Context, complianceServerCtx *ComplianceServerCtx, pol Policy) (int32, error) {
  1398  	if pol.KeyID != 0 {
  1399  		return pol.KeyID, nil
  1400  	}
  1401  
  1402  	// Check cache
  1403  	polKey := pol.Key()
  1404  
  1405  	key, ok := complianceServerCtx.PolicyToID.Load(polKey)
  1406  	if ok {
  1407  		return key.(int32), nil
  1408  	}
  1409  
  1410  	err := pol.GetOrCreate(ctx, complianceServerCtx.DB)
  1411  	if err != nil {
  1412  		return 0, err
  1413  	}
  1414  
  1415  	complianceServerCtx.PolicyToID.Store(polKey, pol.KeyID)
  1416  
  1417  	return pol.KeyID, nil
  1418  }
  1419  
  1420  type errorMessage struct {
  1421  	Message string `json:"message"`
  1422  }
  1423  
  1424  // writeErrMsgJSON wraps the given message in JSON like `{"message": <>}` and
  1425  // writes the response, setting the header to the given code. Since this message
  1426  // will be read by the user, take care not to leak any sensitive details that
  1427  // might be in the error message.
  1428  func writeErrMsgJSON(w http.ResponseWriter, message string, code int) {
  1429  	msg := errorMessage{Message: message}
  1430  
  1431  	resp, err := json.Marshal(msg)
  1432  	if err != nil {
  1433  		log.Error(err, "error marshaling error message", "message", message)
  1434  	}
  1435  
  1436  	w.WriteHeader(code)
  1437  
  1438  	if _, err := w.Write(resp); err != nil {
  1439  		log.Error(err, "error writing error message")
  1440  	}
  1441  }