github.com/letsencrypt/boulder@v0.20251208.0/db/multi.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  )
     8  
     9  // MultiInserter makes it easy to construct a
    10  // `INSERT INTO table (...) VALUES ...;`
    11  // query which inserts multiple rows into the same table. It can also execute
    12  // the resulting query.
    13  type MultiInserter struct {
    14  	// These are validated by the constructor as containing only characters
    15  	// that are allowed in an unquoted identifier.
    16  	// https://mariadb.com/kb/en/identifier-names/#unquoted
    17  	table  string
    18  	fields []string
    19  
    20  	values [][]any
    21  }
    22  
    23  // NewMultiInserter creates a new MultiInserter, checking for reasonable table
    24  // name and list of fields.
    25  // Safety: `table` and `fields` must contain only strings that are known at
    26  // compile time. They must not contain user-controlled strings.
    27  func NewMultiInserter(table string, fields []string) (*MultiInserter, error) {
    28  	if len(table) == 0 || len(fields) == 0 {
    29  		return nil, fmt.Errorf("empty table name or fields list")
    30  	}
    31  
    32  	err := validMariaDBUnquotedIdentifier(table)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	for _, field := range fields {
    37  		err := validMariaDBUnquotedIdentifier(field)
    38  		if err != nil {
    39  			return nil, err
    40  		}
    41  	}
    42  
    43  	return &MultiInserter{
    44  		table:  table,
    45  		fields: fields,
    46  		values: make([][]any, 0),
    47  	}, nil
    48  }
    49  
    50  // Add registers another row to be included in the Insert query.
    51  func (mi *MultiInserter) Add(row []any) error {
    52  	if len(row) != len(mi.fields) {
    53  		return fmt.Errorf("field count mismatch, got %d, expected %d", len(row), len(mi.fields))
    54  	}
    55  	mi.values = append(mi.values, row)
    56  	return nil
    57  }
    58  
    59  // query returns the formatted query string, and the slice of arguments for
    60  // for borp to use in place of the query's question marks. Currently only
    61  // used by .Insert(), below.
    62  func (mi *MultiInserter) query() (string, []any) {
    63  	var questionsBuf strings.Builder
    64  	var queryArgs []any
    65  	for _, row := range mi.values {
    66  		// Safety: We are interpolating a string that will be used in a SQL
    67  		// query, but we constructed that string in this function and know it
    68  		// consists only of question marks joined with commas.
    69  		fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(len(mi.fields)))
    70  		queryArgs = append(queryArgs, row...)
    71  	}
    72  
    73  	questions := strings.TrimRight(questionsBuf.String(), ",")
    74  
    75  	// Safety: we are interpolating `mi.table` and `mi.fields` into an SQL
    76  	// query. We know they contain, respectively, a valid unquoted identifier
    77  	// and a slice of valid unquoted identifiers because we verified that in
    78  	// the constructor. We know the query overall has valid syntax because we
    79  	// generate it entirely within this function.
    80  	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions)
    81  
    82  	return query, queryArgs
    83  }
    84  
    85  // Insert inserts all the collected rows into the database represented by
    86  // `queryer`.
    87  func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error {
    88  	if len(mi.values) == 0 {
    89  		return nil
    90  	}
    91  
    92  	query, queryArgs := mi.query()
    93  	res, err := db.ExecContext(ctx, query, queryArgs...)
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	affected, err := res.RowsAffected()
    99  	if err != nil {
   100  		return err
   101  	}
   102  	if affected != int64(len(mi.values)) {
   103  		return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values))
   104  	}
   105  
   106  	return nil
   107  }