open-cluster-management.io/governance-policy-propagator@v0.13.0/controllers/complianceeventsapi/server.go (about) 1 package complianceeventsapi 2 3 import ( 4 "context" 5 "crypto/tls" 6 "database/sql" 7 "encoding/csv" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "io" 12 stdlog "log" 13 "math" 14 "net" 15 "net/http" 16 "net/url" 17 "reflect" 18 "slices" 19 "sort" 20 "strconv" 21 "strings" 22 "sync" 23 "time" 24 25 "github.com/lib/pq" 26 "k8s.io/client-go/rest" 27 ) 28 29 // init dynamically parses the database columns of each struct type to create a mapping of user provided sort/filter 30 // options to the equivalent SQL column. ErrInvalidSortOption, ErrInvalidQueryArg, and validQueryArgs are also defined 31 // with the available sort/query options to choose from. 32 func init() { 33 tableToStruct := map[string]any{ 34 "clusters": Cluster{}, 35 "compliance_events": EventDetails{}, 36 "parent_policies": ParentPolicy{}, 37 "policies": Policy{}, 38 } 39 40 tableNameToJSONName := map[string]string{ 41 "clusters": "cluster", 42 "compliance_events": "event", 43 "parent_policies": "parent_policy", 44 "policies": "policy", 45 } 46 47 // ID is a special case since it's displayed at the top-level in the JSON but is actually in the compliance_events 48 // table. 49 queryOptionsToSQL = map[string]string{"id": "compliance_events.id"} 50 sortOptionsKeys := []string{"id"} 51 52 for tableName, tableStruct := range tableToStruct { 53 structType := reflect.TypeOf(tableStruct) 54 for i := 0; i < structType.NumField(); i++ { 55 structField := structType.Field(i) 56 57 jsonField := structField.Tag.Get("json") 58 if jsonField == "" || jsonField == "-" { 59 continue 60 } 61 62 // This removes additional text in tag such as capturing `spec` from `json:"spec,omitempty"`. 63 jsonField = strings.SplitN(jsonField, ",", 2)[0] 64 65 dbColumn := structField.Tag.Get("db") 66 if dbColumn == "" { 67 continue 68 } 69 70 // Skip JSONB columns as sortable options 71 if tableName == "policies" && dbColumn == "spec" { 72 continue 73 } 74 75 if tableName == "compliance_events" && dbColumn == "metadata" { 76 continue 77 } 78 79 queryOption := fmt.Sprintf("%s.%s", tableNameToJSONName[tableName], jsonField) 80 81 sortOptionsKeys = append(sortOptionsKeys, queryOption) 82 validQueryArgs = append(validQueryArgs, queryOption) 83 84 queryOptionsToSQL[queryOption] = fmt.Sprintf("%s.%s", tableName, dbColumn) 85 } 86 } 87 88 sort.Strings(sortOptionsKeys) 89 90 ErrInvalidSortOption = fmt.Errorf( 91 "an invalid sort option was provided, choose from: %s", strings.Join(sortOptionsKeys, ", "), 92 ) 93 94 validQueryArgs = []string{ 95 "direction", 96 "event.message_includes", 97 "event.message_like", 98 "event.timestamp_after", 99 "event.timestamp_before", 100 "include_spec", 101 "page", 102 "per_page", 103 "sort", 104 } 105 106 validQueryArgs = append( 107 validQueryArgs, 108 // sortOptionsKeys are all filterable columns. 109 sortOptionsKeys..., 110 ) 111 112 sort.Strings(validQueryArgs) 113 114 ErrInvalidQueryArg = fmt.Errorf( 115 "an invalid query argument was provided, choose from: %s", strings.Join(validQueryArgs, ", "), 116 ) 117 } 118 119 const ( 120 postgresForeignKeyViolationCode = "23503" 121 postgresUniqueViolationCode = "23505" 122 ) 123 124 var ( 125 clusterKeyCache sync.Map 126 queryOptionsToSQL map[string]string 127 validQueryArgs []string 128 ErrInvalidSortOption error 129 ErrInvalidQueryArgValue = errors.New("invalid query argument") 130 ErrInvalidQueryArg error 131 ErrUnauthorized = errors.New("not authorized") 132 ErrForbidden = errors.New("the request is not allowed") 133 // The user has no access to any managed cluster 134 ErrNoAccess = errors.New("the user has no access") 135 ) 136 137 type ComplianceAPIServer struct { 138 server *http.Server 139 addr string 140 cert *tls.Certificate 141 cfg *rest.Config 142 } 143 144 func NewComplianceAPIServer(listenAddress string, cfg *rest.Config, cert *tls.Certificate) *ComplianceAPIServer { 145 return &ComplianceAPIServer{ 146 addr: listenAddress, 147 cert: cert, 148 cfg: cfg, 149 } 150 } 151 152 type serverErrorLogWriter struct{} 153 154 func (*serverErrorLogWriter) Write(p []byte) (int, error) { 155 m := string(p) 156 157 // The OpenShift router (haproxy) seems to perform TCP checks to see if the connection is available. When it does 158 // this, it resets the connection when done, which causes a log message every 5 seconds, so this will filter it out. 159 if strings.HasPrefix(m, "http: TLS handshake error") && strings.HasSuffix(m, ": connection reset by peer\n") { 160 log.V(2).Info(m) 161 } else { 162 log.Info(m) 163 } 164 165 return len(p), nil 166 } 167 168 func newServerErrorLog() *stdlog.Logger { 169 return stdlog.New(&serverErrorLogWriter{}, "", 0) 170 } 171 172 // Start starts the HTTP server and blocks until ctx is closed or there was an error starting the 173 // HTTP server. 174 func (s *ComplianceAPIServer) Start(ctx context.Context, serverContext *ComplianceServerCtx) error { 175 mux := http.NewServeMux() 176 177 s.server = &http.Server{ 178 Addr: s.addr, 179 Handler: mux, 180 181 // need to investigate ideal values for these 182 ReadTimeout: 15 * time.Second, 183 WriteTimeout: 15 * time.Second, 184 IdleTimeout: 15 * time.Second, 185 ErrorLog: newServerErrorLog(), 186 } 187 188 listener, err := net.Listen("tcp", s.addr) 189 if err != nil { 190 return err 191 } 192 193 if s.cert != nil { 194 s.server.TLSConfig = &tls.Config{ 195 MinVersion: tls.VersionTLS12, 196 Certificates: []tls.Certificate{*s.cert}, 197 } 198 199 listener = tls.NewListener(listener, s.server.TLSConfig) 200 } 201 202 // register handlers here 203 mux.HandleFunc("/api/v1/compliance-events", func(w http.ResponseWriter, r *http.Request) { 204 w.Header().Set("Content-Type", "application/json") 205 206 serverContext.Lock.RLock() 207 defer serverContext.Lock.RUnlock() 208 209 if serverContext.DB == nil || serverContext.DB.PingContext(r.Context()) != nil { 210 writeErrMsgJSON(w, "The database is unavailable", http.StatusInternalServerError) 211 212 return 213 } 214 215 switch r.Method { 216 case http.MethodGet: 217 // To verify each request independently 218 userConfig, err := getUserKubeConfig(s.cfg, r) 219 if err != nil { 220 if errors.Is(err, ErrUnauthorized) { 221 writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized) 222 } 223 224 return 225 } 226 getComplianceEvents(serverContext.DB, w, r, userConfig) 227 case http.MethodPost: 228 postComplianceEvent(serverContext, s.cfg, w, r) 229 default: 230 writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed) 231 } 232 }) 233 234 mux.HandleFunc("/api/v1/compliance-events/", func(w http.ResponseWriter, r *http.Request) { 235 w.Header().Set("Content-Type", "application/json") 236 237 serverContext.Lock.RLock() 238 defer serverContext.Lock.RUnlock() 239 240 if serverContext.DB == nil || serverContext.DB.PingContext(r.Context()) != nil { 241 writeErrMsgJSON(w, "The database is unavailable", http.StatusInternalServerError) 242 243 return 244 } 245 246 if r.Method != http.MethodGet { 247 writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed) 248 249 return 250 } 251 252 // To verify each request independently 253 userConfig, err := getUserKubeConfig(s.cfg, r) 254 if err != nil { 255 if errors.Is(err, ErrUnauthorized) { 256 writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized) 257 } 258 259 return 260 } 261 262 getSingleComplianceEvent(serverContext.DB, w, r, userConfig) 263 }) 264 265 mux.HandleFunc("/api/v1/reports/compliance-events", func(w http.ResponseWriter, r *http.Request) { 266 // This header is for error writings 267 w.Header().Set("Content-Type", "application/json") 268 269 // To verify each request independently 270 userConfig, err := getUserKubeConfig(s.cfg, r) 271 if err != nil { 272 if errors.Is(err, ErrUnauthorized) { 273 writeErrMsgJSON(w, "The Authorization header is not set", http.StatusUnauthorized) 274 } 275 276 return 277 } 278 279 if r.Method != http.MethodGet { 280 writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed) 281 282 return 283 } 284 285 getComplianceEventsCSV(serverContext.DB, w, r, userConfig) 286 }) 287 288 serveErr := make(chan error) 289 290 go func() { 291 defer close(serveErr) 292 293 err := s.server.Serve(listener) 294 if err != nil && err != http.ErrServerClosed { 295 serveErr <- err 296 } 297 }() 298 299 select { 300 case <-ctx.Done(): 301 err := s.server.Shutdown(context.Background()) 302 if err != nil { 303 log.Error(err, "Failed to shutdown the compliance API server") 304 } 305 306 return nil 307 case err, closed := <-serveErr: 308 if err != nil { 309 return err 310 } 311 312 if closed { 313 return errors.New("the compliance API server unexpectedly shutdown without an error") 314 } 315 316 return nil 317 } 318 } 319 320 // splitQueryValue will parse a string and split on unescaped commas. Empty values are discarded. 321 func splitQueryValue(value string) []string { 322 values := []string{} 323 324 var currentVal string 325 var previousChar rune 326 327 for _, char := range value { 328 if char == ',' { 329 if previousChar == '\\' { 330 // This comma was escaped, so remove the escape character and keep the comma. Runes are used in case 331 // unicode characters are present in currentVal. 332 runeCurrentVal := []rune(currentVal) 333 currentVal = string(runeCurrentVal[:len(runeCurrentVal)-1]) + "," 334 } else { 335 // The comma was not escaped so we encountered a new value. 336 if currentVal != "" { 337 values = append(values, currentVal) 338 } 339 340 currentVal = "" 341 } 342 } else { 343 // A non-special character was encountered so just append the character. 344 currentVal += string(char) 345 } 346 347 previousChar = char 348 } 349 350 if currentVal != "" { 351 values = append(values, currentVal) 352 } 353 354 return values 355 } 356 357 // parseQueryArgs will parse the HTTP request's query arguments and convert them to a usable format for constructing 358 // the SQL query. All defaults are set and any invalid query arguments result in an error being returned. 359 func parseQueryArgs(ctx context.Context, queryArgs url.Values, db *sql.DB, 360 userConfig *rest.Config, isCSV bool, 361 ) (*queryOptions, error) { 362 parsed := &queryOptions{ 363 Direction: "desc", 364 Page: 1, 365 PerPage: 20, 366 Sort: []string{"compliance_events.timestamp"}, 367 ArrayFilters: map[string][]string{}, 368 Filters: map[string][]string{}, 369 NullFilters: []string{}, 370 } 371 372 // Case return CSV file, default PerPage is 0. Unlimited 373 if isCSV { 374 parsed.PerPage = 0 375 } 376 377 for arg := range queryArgs { 378 valid := false 379 380 for _, validQueryArg := range validQueryArgs { 381 if arg == validQueryArg { 382 valid = true 383 384 break 385 } 386 } 387 388 if !valid { 389 return nil, ErrInvalidQueryArg 390 } 391 392 sqlName, hasSQLName := queryOptionsToSQL[arg] 393 394 value := queryArgs.Get(arg) 395 if value == "" && arg != "include_spec" { 396 // Only support null filters if it's a SQL column 397 if !hasSQLName { 398 return nil, fmt.Errorf("%w: %s must have a value", ErrInvalidQueryArgValue, arg) 399 } 400 401 parsed.NullFilters = append(parsed.NullFilters, sqlName) 402 403 continue 404 } 405 406 switch arg { 407 case "direction": 408 if value == "desc" { 409 parsed.Direction = "DESC" 410 } else if value == "asc" { 411 parsed.Direction = "ASC" 412 } else { 413 return nil, fmt.Errorf("%w: direction must be one of: asc, desc", ErrInvalidQueryArg) 414 } 415 case "include_spec": 416 if value != "" { 417 return nil, fmt.Errorf("%w: include_spec is a flag and does not accept a value", ErrInvalidQueryArg) 418 } 419 420 parsed.IncludeSpec = true 421 case "page": 422 var err error 423 424 parsed.Page, err = strconv.ParseUint(value, 10, 64) 425 if err != nil || parsed.Page == 0 { 426 return nil, fmt.Errorf("%w: page must be a positive integer", ErrInvalidQueryArg) 427 } 428 case "per_page": 429 var err error 430 431 parsed.PerPage, err = strconv.ParseUint(value, 10, 64) 432 if err != nil || parsed.PerPage == 0 || parsed.PerPage > 100 { 433 return nil, fmt.Errorf("%w: per_page must be a value between 1 and 100", ErrInvalidQueryArg) 434 } 435 case "sort": 436 sortArgs := splitQueryValue(value) 437 438 sortSQL := []string{} 439 440 for _, sortArg := range sortArgs { 441 sortOption, ok := queryOptionsToSQL[sortArg] 442 if !ok { 443 return nil, ErrInvalidSortOption 444 } 445 446 sortSQL = append(sortSQL, sortOption) 447 } 448 449 parsed.Sort = sortSQL 450 case "parent_policy.categories", "parent_policy.controls", "parent_policy.standards": 451 parsed.ArrayFilters[sqlName] = splitQueryValue(value) 452 case "event.message_includes": 453 // Escape the SQL LIKE operators because we aren't exposing that functionality. 454 escapedVal := strings.ReplaceAll(value, "%", `\%`) 455 escapedVal = strings.ReplaceAll(escapedVal, "_", `\_`) 456 // Add wildcards at the beginning and end of the search keyword for substring matching. 457 parsed.MessageIncludes = "%" + escapedVal + "%" 458 case "event.message_like": 459 parsed.MessageLike = value 460 case "event.timestamp_before": 461 var err error 462 463 parsed.TimestampBefore, err = time.Parse(time.RFC3339, value) 464 if err != nil { 465 return nil, fmt.Errorf( 466 "%w: event.timestamp_before must be in the format of RFC 3339", ErrInvalidQueryArgValue, 467 ) 468 } 469 case "event.timestamp_after": 470 var err error 471 472 parsed.TimestampAfter, err = time.Parse(time.RFC3339, value) 473 if err != nil { 474 return nil, fmt.Errorf( 475 "%w: event.timestamp_after must be in the format of RFC 3339", ErrInvalidQueryArgValue, 476 ) 477 } 478 default: 479 // Standard string filtering 480 parsed.Filters[sqlName] = splitQueryValue(value) 481 } 482 } 483 484 parsed, err := setAuthorizedClusters(ctx, db, parsed, userConfig) 485 if err != nil { 486 // ErrNoAccess needs queryOptions 487 return parsed, err 488 } 489 490 return parsed, nil 491 } 492 493 // setAuthorizedClusters verifies that if a cluster filter is provided, 494 // the user has access to this filter. If no cluster filter is provided, 495 // it sets the cluster filter to all managed clusters the user has access to. 496 // If the user has no access, then ErrNoAccess is returned. 497 func setAuthorizedClusters(ctx context.Context, db *sql.DB, parsed *queryOptions, 498 userConfig *rest.Config, 499 ) (*queryOptions, error) { 500 unAuthorizedClusters := []string{} 501 502 // Get all managedCluster rules 503 allRules, err := getManagedClusterRules(userConfig, nil) 504 if err != nil { 505 return parsed, err 506 } 507 508 if slices.Contains(allRules["*"], "get") || slices.Contains(allRules["*"], "*") { 509 return parsed, nil 510 } 511 512 clusterIDs := parsed.Filters["clusters.cluster_id"] 513 // Temporarily reset clusters.cluster_id and repopulate with all known cluster IDs 514 parsed.Filters["clusters.cluster_id"] = []string{} 515 516 // Convert id to name 517 for _, id := range clusterIDs { 518 clusterName, err := getClusterNameFromID(ctx, db, id) 519 if err != nil { 520 if errors.Is(err, sql.ErrNoRows) { 521 // Filter out invalid cluster IDs from the query 522 continue 523 } 524 525 log.Error(err, "Failed to get cluster name from cluster ID", getPqErrKeyVals(err, "ID", id)...) 526 527 return parsed, err 528 } 529 530 if !getAccessByClusterName(allRules, clusterName) { 531 unAuthorizedClusters = append(unAuthorizedClusters, id) 532 } else { 533 parsed.Filters["clusters.cluster_id"] = append(parsed.Filters["clusters.cluster_id"], id) 534 } 535 } 536 537 parsedClusterNames := parsed.Filters["clusters.name"] 538 for _, clusterName := range parsedClusterNames { 539 if !getAccessByClusterName(allRules, clusterName) { 540 unAuthorizedClusters = append(unAuthorizedClusters, clusterName) 541 } 542 } 543 544 // There is no cluster.cluster_id or cluster.name query argument. 545 // In other words, the user requests all they have access to. 546 if len(clusterIDs) == 0 && len(parsedClusterNames) == 0 { 547 for mcName := range allRules { 548 // Add the cluster to the filter if the user has get authentication. Note that if the user has get access 549 // on all managed clusters, that gets handled at the beginning of the function. 550 if getAccessByClusterName(allRules, mcName) { 551 parsed.Filters["clusters.name"] = append(parsed.Filters["clusters.name"], mcName) 552 } 553 } 554 } 555 556 if len(unAuthorizedClusters) > 0 { 557 return parsed, fmt.Errorf("%w: the following cluster filters are not authorized: %s", 558 ErrForbidden, strings.Join(unAuthorizedClusters, ", ")) 559 } 560 561 if len(parsed.Filters["clusters.name"]) == 0 && len(parsed.Filters["clusters.cluster_id"]) == 0 { 562 return parsed, ErrNoAccess 563 } 564 565 return parsed, nil 566 } 567 568 // generateGetComplianceEventsQuery will return a SELECT query with results ready to be parsed by 569 // scanIntoComplianceEvent. The caller is responsible for adding filters to the query. 570 func generateGetComplianceEventsQuery(includeSpec bool) string { 571 return fmt.Sprintf(`SELECT %s 572 FROM 573 compliance_events 574 LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id 575 LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id 576 LEFT JOIN policies ON compliance_events.policy_id = policies.id`, 577 strings.Join(generateSelectedArgs(includeSpec), ", "), 578 ) 579 } 580 581 func generateSelectedArgs(includeSpec bool) []string { 582 selectArgs := []string{ 583 "compliance_events.id", 584 "compliance_events.compliance", 585 "compliance_events.message", 586 "compliance_events.metadata", 587 "compliance_events.reported_by", 588 "compliance_events.timestamp", 589 "clusters.cluster_id", 590 "clusters.name", 591 "parent_policies.id", 592 "parent_policies.name", 593 "parent_policies.namespace", 594 "parent_policies.categories", 595 "parent_policies.controls", 596 "parent_policies.standards", 597 "policies.id", 598 "policies.api_group", 599 "policies.kind", 600 "policies.name", 601 "policies.namespace", 602 "policies.severity", 603 } 604 605 if includeSpec { 606 selectArgs = append(selectArgs, "policies.spec") 607 } 608 609 return selectArgs 610 } 611 612 // generate Headers for CSV. "." replace by "_" 613 // Example: parent_policies.namespace -> parent_policies_namespace 614 func getCsvHeader(includeSpec bool) []string { 615 localSelectArgs := generateSelectedArgs(includeSpec) 616 617 for i, arg := range localSelectArgs { 618 localSelectArgs[i] = strings.ReplaceAll(arg, ".", "_") 619 } 620 621 return localSelectArgs 622 } 623 624 type Scannable interface { 625 Scan(dest ...any) error 626 } 627 628 // scanIntoComplianceEvent will scan the row result from the SELECT query generated by generateGetComplianceEventsQuery 629 // into a ComplianceEvent object. 630 func scanIntoComplianceEvent(rows Scannable, includeSpec bool) (*ComplianceEvent, error) { 631 ce := ComplianceEvent{ 632 Cluster: Cluster{}, 633 Event: EventDetails{}, 634 ParentPolicy: nil, 635 Policy: Policy{}, 636 } 637 638 ppID := sql.NullInt32{} 639 ppName := sql.NullString{} 640 ppNamespace := sql.NullString{} 641 ppCategories := pq.StringArray{} 642 ppControls := pq.StringArray{} 643 ppStandards := pq.StringArray{} 644 645 scanArgs := []any{ 646 &ce.EventID, 647 &ce.Event.Compliance, 648 &ce.Event.Message, 649 &ce.Event.Metadata, 650 &ce.Event.ReportedBy, 651 &ce.Event.Timestamp, 652 &ce.Cluster.ClusterID, 653 &ce.Cluster.Name, 654 &ppID, 655 &ppName, 656 &ppNamespace, 657 &ppCategories, 658 &ppControls, 659 &ppStandards, 660 &ce.Policy.KeyID, 661 &ce.Policy.APIGroup, 662 &ce.Policy.Kind, 663 &ce.Policy.Name, 664 &ce.Policy.Namespace, 665 &ce.Policy.Severity, 666 } 667 668 if includeSpec { 669 scanArgs = append(scanArgs, &ce.Policy.Spec) 670 } 671 672 err := rows.Scan(scanArgs...) 673 if err != nil { 674 return nil, err 675 } 676 677 // The parent policy is optional but when it's set, the name is guaranteed to be set. 678 if ppName.String != "" { 679 ce.ParentPolicy = &ParentPolicy{ 680 KeyID: ppID.Int32, 681 Name: ppName.String, 682 Namespace: ppNamespace.String, 683 Categories: ppCategories, 684 Controls: ppControls, 685 Standards: ppStandards, 686 } 687 } 688 689 return &ce, nil 690 } 691 692 // getSingleComplianceEvent handles the GET API endpoint for a single compliance event by ID. 693 func getSingleComplianceEvent(db *sql.DB, w http.ResponseWriter, 694 r *http.Request, config *rest.Config, 695 ) { 696 eventIDStr := strings.TrimPrefix(r.URL.Path, "/api/v1/compliance-events/") 697 698 eventID, err := strconv.ParseUint(eventIDStr, 10, 64) 699 if err != nil { 700 writeErrMsgJSON(w, "The provided compliance event ID is invalid", http.StatusBadRequest) 701 702 return 703 } 704 705 query := fmt.Sprintf("%s\nWHERE compliance_events.id = $1;", generateGetComplianceEventsQuery(true)) 706 707 row := db.QueryRowContext(r.Context(), query, eventID) 708 if row.Err() != nil { 709 log.Error(row.Err(), "Failed to query for the compliance event", "eventID", eventID) 710 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 711 712 return 713 } 714 715 complianceEvent, err := scanIntoComplianceEvent(row, true) 716 if err != nil { 717 if errors.Is(err, sql.ErrNoRows) { 718 writeErrMsgJSON(w, "The requested compliance event was not found", http.StatusNotFound) 719 720 return 721 } 722 723 log.Error(err, "Failed to unmarshal the database results", getPqErrKeyVals(err)...) 724 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 725 726 return 727 } 728 729 // Check auth for managedCluster GET verb 730 isAllowed, err := canGetManagedCluster(config, complianceEvent.Cluster.Name) 731 if err != nil { 732 log.Error(err, `Failed to get the "get" authorization for the cluster`, 733 "cluster", complianceEvent.Cluster.Name) 734 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 735 736 return 737 } 738 739 if !isAllowed { 740 writeErrMsgJSON(w, "Forbidden", http.StatusForbidden) 741 742 return 743 } 744 745 jsonResp, err := json.Marshal(complianceEvent) 746 if err != nil { 747 log.Error(err, "Failed marshal the compliance event", "eventID", eventID) 748 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 749 750 return 751 } 752 753 if _, err = w.Write(jsonResp); err != nil { 754 log.Error(err, "Error writing success response") 755 } 756 } 757 758 // getPqErrKeyVals is a helper to add additional database error details to a log message. additionalKeyVals is provided 759 // as a convenience so that the keys don't need to be explicitly set to interface{} types when using the 760 // `getPqErrKeyVals(err, "key1", "val1")...“ syntax. 761 func getPqErrKeyVals(err error, additionalKeyVals ...interface{}) []interface{} { 762 var pqErr *pq.Error 763 764 if errors.As(err, &pqErr) { 765 return append( 766 []interface{}{"dbMessage", pqErr.Message, "dbDetail", pqErr.Detail, "dbCode", pqErr.Code}, 767 additionalKeyVals..., 768 ) 769 } 770 771 return additionalKeyVals 772 } 773 774 func getClusterNameFromID(ctx context.Context, db *sql.DB, clusterID string) (name string, err error) { 775 err = db.QueryRowContext(ctx, 776 `SELECT name FROM clusters WHERE cluster_id = $1`, clusterID, 777 ).Scan(&name) 778 if err != nil { 779 return "", err 780 } 781 782 return name, nil 783 } 784 785 // getWhereClause will convert the input queryOptions to a WHERE statement and return the filter values for a prepared 786 // statement. 787 func getWhereClause(options *queryOptions) (string, []any) { 788 filterSQL := []string{} 789 filterValues := []any{} 790 791 for sqlColumn, values := range options.Filters { 792 if len(values) == 0 { 793 continue 794 } 795 796 for i, value := range values { 797 filterValues = append(filterValues, value) 798 // For example: compliance_events.name=$1 799 filter := fmt.Sprintf("%s=$%d", sqlColumn, len(filterValues)) 800 if i == 0 { 801 filterSQL = append(filterSQL, "("+filter) 802 } else { 803 filterSQL[len(filterSQL)-1] += " OR " + filter 804 } 805 } 806 807 filterSQL[len(filterSQL)-1] += ")" 808 } 809 810 for sqlColumn, values := range options.ArrayFilters { 811 if len(values) == 0 { 812 continue 813 } 814 815 for i, value := range values { 816 filterValues = append(filterValues, value) 817 818 // For example: $1=ANY(parent_policies.categories) 819 filter := fmt.Sprintf("$%d=ANY(%s)", len(filterValues), sqlColumn) 820 if i == 0 { 821 filterSQL = append(filterSQL, "("+filter) 822 } else { 823 filterSQL[len(filterSQL)-1] += " OR " + filter 824 } 825 } 826 827 filterSQL[len(filterSQL)-1] += ")" 828 } 829 830 for _, sqlColumn := range options.NullFilters { 831 filterSQL = append(filterSQL, fmt.Sprintf("%s IS NULL", sqlColumn)) 832 } 833 834 if options.MessageIncludes != "" { 835 filterValues = append(filterValues, options.MessageIncludes) 836 837 filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.message LIKE $%d", len(filterValues))) 838 } 839 840 if options.MessageLike != "" { 841 filterValues = append(filterValues, options.MessageLike) 842 843 filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.message LIKE $%d", len(filterValues))) 844 } 845 846 if !options.TimestampAfter.IsZero() { 847 filterValues = append(filterValues, options.TimestampAfter) 848 849 filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.timestamp > $%d", len(filterValues))) 850 } 851 852 if !options.TimestampBefore.IsZero() { 853 filterValues = append(filterValues, options.TimestampBefore) 854 855 filterSQL = append(filterSQL, fmt.Sprintf("compliance_events.timestamp < $%d", len(filterValues))) 856 } 857 858 var whereClause string 859 860 if len(filterSQL) > 0 { 861 // For example: 862 // WHERE (policy.name=$1) AND ($2=ANY(parent_policies.categories) OR $3=ANY(parent_policies.categories)) 863 whereClause = "\nWHERE " + strings.Join(filterSQL, " AND ") 864 } 865 866 return whereClause, filterValues 867 } 868 869 // getComplianceEvents handles the list API endpoint for compliance events. 870 func getComplianceEvents(db *sql.DB, w http.ResponseWriter, 871 r *http.Request, userConfig *rest.Config, 872 ) { 873 queryArgs, err := parseQueryArgs(r.Context(), r.URL.Query(), db, userConfig, false) 874 if err != nil { 875 if errors.Is(err, ErrForbidden) { 876 writeErrMsgJSON(w, err.Error(), http.StatusForbidden) 877 878 return 879 } 880 881 if errors.Is(err, ErrNoAccess) { 882 response := ListResponse{ 883 Data: []ComplianceEvent{}, 884 Metadata: metadata{ 885 Page: queryArgs.Page, 886 Pages: 0, 887 PerPage: queryArgs.PerPage, 888 Total: 0, 889 }, 890 } 891 892 jsonResp, err := json.Marshal(response) 893 if err != nil { 894 log.Error(err, "Failed to marshal an empty response") 895 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 896 897 return 898 } 899 900 if _, err = w.Write(jsonResp); err != nil { 901 log.Error(err, "Error writing empty response") 902 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 903 } 904 905 return 906 } 907 908 if errors.Is(err, ErrInvalidQueryArg) || errors.Is(err, ErrInvalidQueryArgValue) || 909 errors.Is(err, ErrInvalidSortOption) { 910 writeErrMsgJSON(w, err.Error(), http.StatusBadRequest) 911 912 return 913 } 914 915 writeErrMsgJSON(w, err.Error(), http.StatusInternalServerError) 916 917 return 918 } 919 920 // Note that the where clause could be an empty string if not filters were passed in the query arguments. 921 whereClause, filterValues := getWhereClause(queryArgs) 922 923 query := getComplianceEventsQuery(whereClause, queryArgs) 924 925 rows, err := db.QueryContext(r.Context(), query, filterValues...) 926 if err == nil { 927 err = rows.Err() 928 } 929 930 if err != nil { 931 log.Error(err, "Failed to query for compliance events") 932 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 933 934 return 935 } 936 937 defer rows.Close() 938 939 complianceEvents := make([]ComplianceEvent, 0, queryArgs.PerPage) 940 941 for rows.Next() { 942 ce, err := scanIntoComplianceEvent(rows, queryArgs.IncludeSpec) 943 if err != nil { 944 log.Error(err, "Failed to unmarshal the database results") 945 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 946 947 return 948 } 949 950 complianceEvents = append(complianceEvents, *ce) 951 } 952 953 countQuery := `SELECT COUNT(*) FROM compliance_events 954 LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id 955 LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id 956 LEFT JOIN policies ON compliance_events.policy_id = policies.id` + whereClause // #nosec G202 957 958 row := db.QueryRowContext(r.Context(), countQuery, filterValues...) 959 960 var total uint64 961 962 if err := row.Scan(&total); err != nil { 963 log.Error(err, "Failed to get the count of compliance events", getPqErrKeyVals(err)...) 964 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 965 966 return 967 } 968 969 pages := math.Ceil(float64(total) / float64(queryArgs.PerPage)) 970 971 response := ListResponse{ 972 Data: complianceEvents, 973 Metadata: metadata{ 974 Page: queryArgs.Page, 975 Pages: uint64(pages), 976 PerPage: queryArgs.PerPage, 977 Total: total, 978 }, 979 } 980 981 jsonResp, err := json.Marshal(response) 982 if err != nil { 983 log.Error(err, "Failed to marshal the response") 984 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 985 986 return 987 } 988 989 if _, err = w.Write(jsonResp); err != nil { 990 log.Error(err, "Error writing success response") 991 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 992 993 return 994 } 995 } 996 997 // postComplianceEvent assumes you have a read lock already attained. 998 func postComplianceEvent(serverContext *ComplianceServerCtx, cfg *rest.Config, w http.ResponseWriter, r *http.Request) { 999 body, err := io.ReadAll(r.Body) 1000 if err != nil { 1001 log.Error(err, "error reading request body") 1002 writeErrMsgJSON(w, "Could not read request body", http.StatusBadRequest) 1003 1004 return 1005 } 1006 1007 reqEvent := &ComplianceEvent{} 1008 1009 if err := json.Unmarshal(body, reqEvent); err != nil { 1010 writeErrMsgJSON(w, "Incorrectly formatted request body, must be valid JSON", http.StatusBadRequest) 1011 1012 return 1013 } 1014 1015 if err := reqEvent.Validate(r.Context(), serverContext); err != nil { 1016 writeErrMsgJSON(w, err.Error(), http.StatusBadRequest) 1017 1018 return 1019 } 1020 1021 allowed, err := canRecordComplianceEvent(cfg, reqEvent.Cluster.Name, r) 1022 if err != nil { 1023 if errors.Is(err, ErrUnauthorized) { 1024 writeErrMsgJSON(w, "Unauthorized", http.StatusUnauthorized) 1025 1026 return 1027 } 1028 1029 log.Error(err, "error determining if the user is authorized for recording compliance events") 1030 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1031 1032 return 1033 } 1034 1035 if !allowed { 1036 // Logging is handled by canRecordComplianceEvent 1037 writeErrMsgJSON(w, "Forbidden", http.StatusForbidden) 1038 1039 return 1040 } 1041 1042 clusterFK, err := GetClusterForeignKey(r.Context(), serverContext.DB, reqEvent.Cluster) 1043 if err != nil { 1044 log.Error(err, "error getting cluster foreign key", getPqErrKeyVals(err)...) 1045 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1046 1047 return 1048 } 1049 1050 reqEvent.Event.ClusterID = clusterFK 1051 1052 if reqEvent.ParentPolicy != nil { 1053 pfk, err := getParentPolicyForeignKey(r.Context(), serverContext, *reqEvent.ParentPolicy) 1054 if err != nil { 1055 log.Error(err, "error getting parent policy foreign key", getPqErrKeyVals(err)...) 1056 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1057 1058 return 1059 } 1060 1061 reqEvent.Event.ParentPolicyID = &pfk 1062 } 1063 1064 policyFK, err := getPolicyForeignKey(r.Context(), serverContext, reqEvent.Policy) 1065 if err != nil { 1066 log.Error(err, "error getting policy foreign key", getPqErrKeyVals(err)...) 1067 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1068 1069 return 1070 } 1071 1072 reqEvent.Event.PolicyID = policyFK 1073 1074 err = reqEvent.Create(r.Context(), serverContext.DB) 1075 if err != nil { 1076 if errors.Is(err, errDuplicateComplianceEvent) { 1077 writeErrMsgJSON(w, "The compliance event already exists", http.StatusConflict) 1078 1079 return 1080 } 1081 1082 var pqErr *pq.Error 1083 1084 if errors.As(err, &pqErr) && pqErr.Code == postgresForeignKeyViolationCode { 1085 // This can only happen if the cache is out of date due to data loss in the database because if the 1086 // database ID is provided, it is validated against the database. 1087 log.Info( 1088 "Encountered a foreign key violation. Assuming the database lost data, so the cache is "+ 1089 "being cleared", 1090 "message", pqErr.Message, 1091 "detail", pqErr.Detail, 1092 ) 1093 1094 // Temporarily upgrade the lock to a write lock 1095 serverContext.Lock.RUnlock() 1096 serverContext.Lock.Lock() 1097 serverContext.ParentPolicyToID = sync.Map{} 1098 serverContext.PolicyToID = sync.Map{} 1099 clusterKeyCache = sync.Map{} 1100 serverContext.Lock.Unlock() 1101 serverContext.Lock.RLock() 1102 } else { 1103 log.Error(err, "error inserting compliance event", getPqErrKeyVals(err)...) 1104 } 1105 1106 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1107 1108 return 1109 } 1110 1111 // remove the spec so it's not returned in the JSON. 1112 reqEvent.Policy.Spec = nil 1113 1114 resp, err := json.Marshal(reqEvent) 1115 if err != nil { 1116 log.Error(err, "error marshaling reqEvent for the response") 1117 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1118 1119 return 1120 } 1121 1122 w.WriteHeader(http.StatusCreated) 1123 1124 if _, err = w.Write(resp); err != nil { 1125 log.Error(err, "error writing success response") 1126 } 1127 } 1128 1129 func getComplianceEventsQuery(whereClause string, queryArgs *queryOptions) string { 1130 // Getting CSV without the page argument 1131 // Query should fetch all rows (unlimited) 1132 if queryArgs.PerPage == 0 { 1133 return fmt.Sprintf(`%s%s 1134 ORDER BY %s %s;`, 1135 generateGetComplianceEventsQuery(queryArgs.IncludeSpec), 1136 whereClause, 1137 strings.Join(queryArgs.Sort, ", "), 1138 queryArgs.Direction, 1139 ) 1140 } 1141 // Example query 1142 // SELECT compliance_events.id, compliance_events.compliance, ... 1143 // FROM compliance_events 1144 // LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id 1145 // LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id 1146 // LEFT JOIN policies ON compliance_events.policy_id = policies.id 1147 // WHERE (policies.name=$1 OR policies.name=$2) AND (policies.kind=$3) 1148 // ORDER BY compliance_events.timestamp desc 1149 // LIMIT 20 1150 // OFFSET 0 ROWS; 1151 return fmt.Sprintf(`%s%s 1152 ORDER BY %s %s 1153 LIMIT %d 1154 OFFSET %d ROWS;`, 1155 generateGetComplianceEventsQuery(queryArgs.IncludeSpec), 1156 whereClause, 1157 strings.Join(queryArgs.Sort, ", "), 1158 queryArgs.Direction, 1159 queryArgs.PerPage, 1160 (queryArgs.Page-1)*queryArgs.PerPage, 1161 ) 1162 } 1163 1164 func setCSVResponseHeaders(w http.ResponseWriter) { 1165 w.Header().Set("Content-Disposition", "attachment; filename=reports.csv") 1166 w.Header().Set("Content-Type", "text/csv") 1167 // It's going to be divided into chunks. if the user don't get it all at once, 1168 // the user can receive one by one in the meantime 1169 w.Header().Set("Transfer-Encoding", "chunked") 1170 } 1171 1172 func getComplianceEventsCSV(db *sql.DB, w http.ResponseWriter, r *http.Request, 1173 userConfig *rest.Config, 1174 ) { 1175 var writer *csv.Writer 1176 1177 queryArgs, queryArgsErr := parseQueryArgs(r.Context(), r.URL.Query(), db, userConfig, true) 1178 if queryArgs != nil { 1179 headers := getCsvHeader(queryArgs.IncludeSpec) 1180 1181 writer = csv.NewWriter(w) 1182 1183 err := writer.Write(headers) 1184 if err != nil { 1185 log.Error(err, "Failed to write csv header") 1186 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1187 1188 return 1189 } 1190 } 1191 1192 if queryArgsErr != nil { 1193 if errors.Is(queryArgsErr, ErrNoAccess) { 1194 setCSVResponseHeaders(w) 1195 1196 writer.Flush() 1197 1198 return 1199 } 1200 1201 if errors.Is(queryArgsErr, ErrInvalidQueryArg) || errors.Is(queryArgsErr, ErrInvalidQueryArgValue) || 1202 errors.Is(queryArgsErr, ErrInvalidSortOption) { 1203 writeErrMsgJSON(w, queryArgsErr.Error(), http.StatusBadRequest) 1204 1205 return 1206 } 1207 1208 writeErrMsgJSON(w, queryArgsErr.Error(), http.StatusInternalServerError) 1209 1210 return 1211 } 1212 1213 // Note that the where clause could be an empty string if no filters were passed in the query arguments. 1214 whereClause, filterValues := getWhereClause(queryArgs) 1215 1216 query := getComplianceEventsQuery(whereClause, queryArgs) 1217 1218 rows, err := db.QueryContext(r.Context(), query, filterValues...) 1219 if err == nil { 1220 err = rows.Err() 1221 } 1222 1223 if err != nil { 1224 log.Error(err, "Failed to query for compliance events") 1225 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1226 1227 return 1228 } 1229 1230 defer rows.Close() 1231 1232 for rows.Next() { 1233 ce, err := scanIntoComplianceEvent(rows, queryArgs.IncludeSpec) 1234 if err != nil { 1235 log.Error(err, "Failed to unmarshal the database results") 1236 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1237 1238 return 1239 } 1240 1241 stringValues := convertToCsvLine(ce, queryArgs.IncludeSpec) 1242 1243 err = writer.Write(stringValues) 1244 if err != nil { 1245 log.Error(err, "Failed to write csv list") 1246 writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError) 1247 1248 return 1249 } 1250 } 1251 1252 setCSVResponseHeaders(w) 1253 1254 writer.Flush() 1255 } 1256 1257 func convertToCsvLine(ce *ComplianceEvent, includeSpec bool) []string { 1258 nilString := "" 1259 1260 if ce.ParentPolicy == nil { 1261 ce.ParentPolicy = &ParentPolicy{ 1262 KeyID: 0, 1263 Name: "", 1264 Namespace: "", 1265 Categories: nil, 1266 Controls: nil, 1267 Standards: nil, 1268 } 1269 } 1270 1271 if ce.Event.ReportedBy == nil { 1272 ce.Event.ReportedBy = &nilString 1273 } 1274 1275 if ce.Policy.Severity == nil { 1276 ce.Policy.Severity = &nilString 1277 } 1278 1279 if ce.Policy.Namespace == nil { 1280 ce.Policy.Namespace = &nilString 1281 } 1282 1283 values := []string{ 1284 convertToString(ce.EventID), 1285 convertToString(ce.Event.Compliance), 1286 convertToString(ce.Event.Message), 1287 convertToString(ce.Event.Metadata), 1288 convertToString(*ce.Event.ReportedBy), 1289 convertToString(ce.Event.Timestamp), 1290 convertToString(ce.Cluster.ClusterID), 1291 convertToString(ce.Cluster.Name), 1292 convertToString(ce.ParentPolicy.KeyID), 1293 convertToString(ce.ParentPolicy.Name), 1294 convertToString(ce.ParentPolicy.Namespace), 1295 convertToString(ce.ParentPolicy.Categories), 1296 convertToString(ce.ParentPolicy.Controls), 1297 convertToString(ce.ParentPolicy.Standards), 1298 convertToString(ce.Policy.KeyID), 1299 convertToString(ce.Policy.APIGroup), 1300 convertToString(ce.Policy.Kind), 1301 convertToString(ce.Policy.Name), 1302 convertToString(*ce.Policy.Namespace), 1303 convertToString(*ce.Policy.Severity), 1304 } 1305 1306 if includeSpec { 1307 values = append(values, convertToString(ce.Policy.Spec)) 1308 } 1309 1310 return values 1311 } 1312 1313 func convertToString(v interface{}) string { 1314 switch vv := v.(type) { 1315 case *string: 1316 if vv == nil { 1317 return "" 1318 } 1319 1320 return *vv 1321 case string: 1322 return vv 1323 case int32: 1324 // All int32 related id 1325 if int(vv) == 0 { 1326 return "" 1327 } 1328 1329 return strconv.Itoa(int(vv)) 1330 case time.Time: 1331 return vv.String() 1332 case pq.StringArray: 1333 // nil will be [] 1334 return strings.Join(vv, ", ") 1335 case bool: 1336 return strconv.FormatBool(vv) 1337 case JSONMap: 1338 if vv == nil { 1339 return "" 1340 } 1341 1342 jsonByte, err := json.MarshalIndent(vv, "", " ") 1343 if err != nil { 1344 return "" 1345 } 1346 1347 return string(jsonByte) 1348 default: 1349 // case nil: 1350 return fmt.Sprintf("%v", vv) 1351 } 1352 } 1353 1354 // GetClusterForeignKey will return the database ID based on the cluster.ClusterID. 1355 func GetClusterForeignKey(ctx context.Context, db *sql.DB, cluster Cluster) (int32, error) { 1356 // Check cache 1357 key, ok := clusterKeyCache.Load(cluster.ClusterID) 1358 if ok { 1359 return key.(int32), nil 1360 } 1361 1362 err := cluster.GetOrCreate(ctx, db) 1363 if err != nil { 1364 return 0, err 1365 } 1366 1367 clusterKeyCache.Store(cluster.ClusterID, cluster.KeyID) 1368 1369 return cluster.KeyID, nil 1370 } 1371 1372 func getParentPolicyForeignKey( 1373 ctx context.Context, complianceServerCtx *ComplianceServerCtx, parent ParentPolicy, 1374 ) (int32, error) { 1375 if parent.KeyID != 0 { 1376 return parent.KeyID, nil 1377 } 1378 1379 // Check cache 1380 parKey := parent.Key() 1381 1382 key, ok := complianceServerCtx.ParentPolicyToID.Load(parKey) 1383 if ok { 1384 return key.(int32), nil 1385 } 1386 1387 err := parent.GetOrCreate(ctx, complianceServerCtx.DB) 1388 if err != nil { 1389 return 0, err 1390 } 1391 1392 complianceServerCtx.ParentPolicyToID.Store(parKey, parent.KeyID) 1393 1394 return parent.KeyID, nil 1395 } 1396 1397 func getPolicyForeignKey(ctx context.Context, complianceServerCtx *ComplianceServerCtx, pol Policy) (int32, error) { 1398 if pol.KeyID != 0 { 1399 return pol.KeyID, nil 1400 } 1401 1402 // Check cache 1403 polKey := pol.Key() 1404 1405 key, ok := complianceServerCtx.PolicyToID.Load(polKey) 1406 if ok { 1407 return key.(int32), nil 1408 } 1409 1410 err := pol.GetOrCreate(ctx, complianceServerCtx.DB) 1411 if err != nil { 1412 return 0, err 1413 } 1414 1415 complianceServerCtx.PolicyToID.Store(polKey, pol.KeyID) 1416 1417 return pol.KeyID, nil 1418 } 1419 1420 type errorMessage struct { 1421 Message string `json:"message"` 1422 } 1423 1424 // writeErrMsgJSON wraps the given message in JSON like `{"message": <>}` and 1425 // writes the response, setting the header to the given code. Since this message 1426 // will be read by the user, take care not to leak any sensitive details that 1427 // might be in the error message. 1428 func writeErrMsgJSON(w http.ResponseWriter, message string, code int) { 1429 msg := errorMessage{Message: message} 1430 1431 resp, err := json.Marshal(msg) 1432 if err != nil { 1433 log.Error(err, "error marshaling error message", "message", message) 1434 } 1435 1436 w.WriteHeader(code) 1437 1438 if _, err := w.Write(resp); err != nil { 1439 log.Error(err, "error writing error message") 1440 } 1441 }