
     1  package postgres
     3  import (
     4  	"context"
     5  	"crypto/md5"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"sort"
    10  	"strconv"
    11  	"time"
    13  	""
    14  	""
    15  	""
    16  	""
    17  	""
    18  	""
    20  	""
    21  	""
    22  	""
    23  )
    25  var (
    26  	updateEnrichmentsCounter = promauto.NewCounterVec(
    27  		prometheus.CounterOpts{
    28  			Namespace: "claircore",
    29  			Subsystem: "vulnstore",
    30  			Name:      "updateenrichments_total",
    31  			Help:      "Total number of database queries issued in the UpdateEnrichments method.",
    32  		},
    33  		[]string{"query"},
    34  	)
    35  	updateEnrichmentsDuration = promauto.NewHistogramVec(
    36  		prometheus.HistogramOpts{
    37  			Namespace: "claircore",
    38  			Subsystem: "vulnstore",
    39  			Name:      "updateenrichments_duration_seconds",
    40  			Help:      "Duration of all queries issued in the UpdateEnrichments method.",
    41  		},
    42  		[]string{"query"},
    43  	)
    44  	getEnrichmentsCounter = promauto.NewCounterVec(
    45  		prometheus.CounterOpts{
    46  			Namespace: "claircore",
    47  			Subsystem: "vulnstore",
    48  			Name:      "getenrichments_total",
    49  			Help:      "Total number of database queries issued in the get method.",
    50  		},
    51  		[]string{"query", "success"},
    52  	)
    53  	getEnrichmentsDuration = promauto.NewHistogramVec(
    54  		prometheus.HistogramOpts{
    55  			Namespace: "claircore",
    56  			Subsystem: "vulnstore",
    57  			Name:      "getenrichments_duration_seconds",
    58  			Help:      "Duration of all queries issued in the get method.",
    59  		},
    60  		[]string{"query", "success"},
    61  	)
    62  )
    // UpdateEnrichmentsIter is the iterator-based form of UpdateEnrichments.
// It tags ctx with this method's component name and hands the records
// produced by it to updateEnrichments. On success it returns the ref of the
// newly created update operation.
func (s *MatcherStore) UpdateEnrichmentsIter(ctx context.Context, updater string, fp driver.Fingerprint, it datastore.EnrichmentIter) (uuid.UUID, error) {
	tagged := zlog.ContextWithValues(ctx, "component", "datastore/postgres/MatcherStore.UpdateEnrichmentsIter")
	return s.updateEnrichments(tagged, updater, fp, it)
}
    // UpdateEnrichments creates a new update operation for updater, inserts the
// provided EnrichmentRecord(s) and associates them with that operation, and
// refreshes the latest_update_operations view so that enrichments from
// previous updates are no longer the ones clients query. It returns the ref
// of the new update operation.
//
// Records are handed to the insert path by pointer into es, so the caller's
// records may be modified there (their Tags are sorted during hashing).
func (s *MatcherStore) UpdateEnrichments(ctx context.Context, updater string, fp driver.Fingerprint, es []driver.EnrichmentRecord) (uuid.UUID, error) {
	ctx = zlog.ContextWithValues(ctx, "component", "datastore/postgres/MatcherStore.UpdateEnrichments")
	// Adapt the slice to the iterator shape expected by updateEnrichments,
	// yielding a pointer to each element (not a copy) and stopping early if
	// the consumer asks to.
	records := func(yield func(*driver.EnrichmentRecord, error) bool) {
		for i := 0; i < len(es); i++ {
			if keepGoing := yield(&es[i], nil); !keepGoing {
				return
			}
		}
	}
	return s.updateEnrichments(ctx, updater, fp, records)
}
    // updateEnrichments records a new 'enrichment' update_operation for the
// updater name and fingerprint fp. It then inserts every record yielded by
// it, associates each with that operation, and refreshes the
// latest_update_operations materialized view. On success it returns the
// ref of the new update_operation row.
//
// Enrichment rows are keyed by their content hash (hashEnrichment). A record
// whose (hash_kind, hash) already exists is not inserted again. The
// association is made by looking the row up by (hash_kind, hash, updater).
//
// Any error from the iterator, from queueing, from flushing the batch or
// from committing aborts the update and is returned wrapped with context.
func (s *MatcherStore) updateEnrichments(ctx context.Context, name string, fp driver.Fingerprint, it datastore.EnrichmentIter) (uuid.UUID, error) {
	const (
		create = `
INSERT
INTO
	update_operation (updater, fingerprint, kind)
VALUES
	($1, $2, 'enrichment')
RETURNING
	id, ref;`
		insert = `
INSERT
INTO
	enrichment (hash_kind, hash, updater, tags, data)
VALUES
	($1, $2, $3, $4, $5)
ON CONFLICT
	(hash_kind, hash)
DO
	NOTHING;`
		assoc = `
INSERT
INTO
	uo_enrich (enrich, updater, uo, date)
VALUES
	(
		(
			SELECT
				id
			FROM
				enrichment
			WHERE
				hash_kind = $1
				AND hash = $2
				AND updater = $3
		),
		$3,
		$4,
		transaction_timestamp()
	)
ON CONFLICT
DO
	NOTHING;`
		refreshView = `REFRESH MATERIALIZED VIEW CONCURRENTLY latest_update_operations;`
	)
	ctx = zlog.ContextWithValues(ctx, "component", "datastore/postgres/UpdateEnrichments")

	var id uint64
	var ref uuid.UUID

	start := time.Now()
	// NOTE(review): this row is created directly on the pool, before (and
	// outside) the transaction begun below, so it is not rolled back if a
	// later step fails — confirm that is intended.
	if err := s.pool.QueryRow(ctx, create, name, string(fp)).Scan(&id, &ref); err != nil {
		return uuid.Nil, fmt.Errorf("failed to create update_operation: %w", err)
	}
	updateEnrichmentsCounter.WithLabelValues("create").Add(1)
	updateEnrichmentsDuration.WithLabelValues("create").Observe(time.Since(start).Seconds())

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return uuid.Nil, fmt.Errorf("unable to start transaction: %w", err)
	}
	// Deferred so every early return below abandons the transaction.
	defer tx.Rollback(ctx)

	zlog.Debug(ctx).
		Str("ref", ref.String()).
		Msg("update_operation created")

	batch := microbatch.NewInsert(tx, 2000, time.Minute)
	start = time.Now()
	enCt := 0
	// The iterator callback cannot return an error, so the first failure is
	// stored here and iteration is stopped by returning false. A dedicated
	// variable avoids the `if err := ...` shadowing that previously dropped
	// association-queue errors.
	var loopErr error
	it(func(en *driver.EnrichmentRecord, yieldErr error) bool {
		if yieldErr != nil {
			loopErr = yieldErr
			return false
		}
		enCt++
		hashKind, hash := hashEnrichment(en)
		if err := batch.Queue(ctx, insert,
			hashKind, hash, name, en.Tags, en.Enrichment,
		); err != nil {
			loopErr = fmt.Errorf("failed to queue enrichment: %w", err)
			return false
		}
		// Queued after the insert; presumably relies on the batch executing
		// queries in order so the assoc subselect can find the row.
		if err := batch.Queue(ctx, assoc, hashKind, hash, name, id); err != nil {
			loopErr = fmt.Errorf("failed to queue association: %w", err)
			return false
		}
		return true
	})
	if loopErr != nil {
		return uuid.Nil, fmt.Errorf("iterating on enrichments: %w", loopErr)
	}
	if err := batch.Done(ctx); err != nil {
		return uuid.Nil, fmt.Errorf("failed to finish batch enrichment insert: %w", err)
	}
	updateEnrichmentsCounter.WithLabelValues("insert_batch").Add(1)
	updateEnrichmentsDuration.WithLabelValues("insert_batch").Observe(time.Since(start).Seconds())

	if err := tx.Commit(ctx); err != nil {
		return uuid.Nil, fmt.Errorf("failed to commit transaction: %w", err)
	}
	// Refreshed on the pool after commit so the view can see the new rows.
	if _, err := s.pool.Exec(ctx, refreshView); err != nil {
		return uuid.Nil, fmt.Errorf("could not refresh latest_update_operations: %w", err)
	}
	zlog.Debug(ctx).
		Stringer("ref", ref).
		Int("inserted", enCt).
		Msg("update_operation committed")
	return ref, nil
}
   // hashEnrichment returns the hash kind ("md5") and digest used to key an
// enrichment row. The digest covers every tag in sorted order, each followed
// by a NUL byte, and then the raw Enrichment bytes.
//
// r.Tags is sorted in place, so the caller's record is modified.
func hashEnrichment(r *driver.EnrichmentRecord) (k string, d []byte) {
	sort.Strings(r.Tags)
	digest := md5.New()
	// Writes to a hash.Hash never return an error, so results are ignored.
	for i := range r.Tags {
		io.WriteString(digest, r.Tags[i])
		digest.Write([]byte{0})
	}
	digest.Write(r.Enrichment)
	k, d = "md5", digest.Sum(nil)
	return k, d
}
   // GetEnrichment returns the enrichment records associated with the
// 'enrichment' update operation that latest_update_operations lists for the
// named updater, restricted to records whose tags overlap tags.
//
// On success the returned slice is non-nil (possibly empty). On error it is
// nil and the error is wrapped with context. Every call is counted and timed
// in the getenrichments metrics, labelled by whether it succeeded.
func (s *MatcherStore) GetEnrichment(ctx context.Context, name string, tags []string) (res []driver.EnrichmentRecord, err error) {
	const query = `
WITH
	latest
		AS (
			SELECT
				id
			FROM
				latest_update_operations
			WHERE
				updater = $1
			AND
				kind = 'enrichment'
			LIMIT 1
		)
SELECT
	e.tags, e.data
FROM
	enrichment AS e,
	uo_enrich AS uo,
	latest
WHERE
	uo.uo = latest.id
	AND uo.enrich = e.id
	AND e.tags && $2::text[];`
	ctx = zlog.ContextWithValues(ctx, "component", "datastore/postgres/GetEnrichment")
	// Both metric callbacks read the named result err when the deferred
	// calls run, so they observe the call's final outcome.
	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
		getEnrichmentsDuration.WithLabelValues("query", strconv.FormatBool(errors.Is(err, nil))).Observe(v)
	}))
	defer timer.ObserveDuration()
	defer func() {
		getEnrichmentsCounter.WithLabelValues("query", strconv.FormatBool(errors.Is(err, nil))).Inc()
	}()

	var (
		c    *pgxpool.Conn
		rows pgx.Rows
	)
	c, err = s.pool.Acquire(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to acquire connection: %w", err)
	}
	defer c.Release()
	rows, err = c.Query(ctx, query, name, tags)
	if err != nil {
		return nil, fmt.Errorf("failed to query enrichments: %w", err)
	}
	defer rows.Close()

	res = make([]driver.EnrichmentRecord, 0, 8) // Guess at capacity.
	for rows.Next() {
		res = append(res, driver.EnrichmentRecord{})
		r := &res[len(res)-1]
		// Column order matches the SELECT list: e.tags, e.data.
		if err = rows.Scan(&r.Tags, &r.Enrichment); err != nil {
			return nil, fmt.Errorf("failed to scan enrichment: %w", err)
		}
	}
	// Do not hand back a partially read result alongside an error.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read enrichment rows: %w", err)
	}
	return res, nil
}