github.com/letsencrypt/boulder@v0.20251208.0/db/multi.go (about) 1 package db 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 ) 8 9 // MultiInserter makes it easy to construct a 10 // `INSERT INTO table (...) VALUES ...;` 11 // query which inserts multiple rows into the same table. It can also execute 12 // the resulting query. 13 type MultiInserter struct { 14 // These are validated by the constructor as containing only characters 15 // that are allowed in an unquoted identifier. 16 // https://mariadb.com/kb/en/identifier-names/#unquoted 17 table string 18 fields []string 19 20 values [][]any 21 } 22 23 // NewMultiInserter creates a new MultiInserter, checking for reasonable table 24 // name and list of fields. 25 // Safety: `table` and `fields` must contain only strings that are known at 26 // compile time. They must not contain user-controlled strings. 27 func NewMultiInserter(table string, fields []string) (*MultiInserter, error) { 28 if len(table) == 0 || len(fields) == 0 { 29 return nil, fmt.Errorf("empty table name or fields list") 30 } 31 32 err := validMariaDBUnquotedIdentifier(table) 33 if err != nil { 34 return nil, err 35 } 36 for _, field := range fields { 37 err := validMariaDBUnquotedIdentifier(field) 38 if err != nil { 39 return nil, err 40 } 41 } 42 43 return &MultiInserter{ 44 table: table, 45 fields: fields, 46 values: make([][]any, 0), 47 }, nil 48 } 49 50 // Add registers another row to be included in the Insert query. 51 func (mi *MultiInserter) Add(row []any) error { 52 if len(row) != len(mi.fields) { 53 return fmt.Errorf("field count mismatch, got %d, expected %d", len(row), len(mi.fields)) 54 } 55 mi.values = append(mi.values, row) 56 return nil 57 } 58 59 // query returns the formatted query string, and the slice of arguments for 60 // for borp to use in place of the query's question marks. Currently only 61 // used by .Insert(), below. 62 func (mi *MultiInserter) query() (string, []any) { 63 var questionsBuf strings.Builder 64 var queryArgs []any 65 for _, row := range mi.values { 66 // Safety: We are interpolating a string that will be used in a SQL 67 // query, but we constructed that string in this function and know it 68 // consists only of question marks joined with commas. 69 fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(len(mi.fields))) 70 queryArgs = append(queryArgs, row...) 71 } 72 73 questions := strings.TrimRight(questionsBuf.String(), ",") 74 75 // Safety: we are interpolating `mi.table` and `mi.fields` into an SQL 76 // query. We know they contain, respectively, a valid unquoted identifier 77 // and a slice of valid unquoted identifiers because we verified that in 78 // the constructor. We know the query overall has valid syntax because we 79 // generate it entirely within this function. 80 query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions) 81 82 return query, queryArgs 83 } 84 85 // Insert inserts all the collected rows into the database represented by 86 // `queryer`. 87 func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error { 88 if len(mi.values) == 0 { 89 return nil 90 } 91 92 query, queryArgs := mi.query() 93 res, err := db.ExecContext(ctx, query, queryArgs...) 94 if err != nil { 95 return err 96 } 97 98 affected, err := res.RowsAffected() 99 if err != nil { 100 return err 101 } 102 if affected != int64(len(mi.values)) { 103 return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values)) 104 } 105 106 return nil 107 }