github.com/letsencrypt/boulder@v0.20251208.0/db/gorm.go (about) 1 package db 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "regexp" 9 "strings" 10 ) 11 12 // Characters allowed in an unquoted identifier by MariaDB. 13 // https://mariadb.com/kb/en/identifier-names/#unquoted 14 var mariaDBUnquotedIdentifierRE = regexp.MustCompile("^[0-9a-zA-Z$_]+$") 15 16 func validMariaDBUnquotedIdentifier(s string) error { 17 if !mariaDBUnquotedIdentifierRE.MatchString(s) { 18 return fmt.Errorf("invalid MariaDB identifier %q", s) 19 } 20 21 allNumeric := true 22 startsNumeric := false 23 for i, c := range []byte(s) { 24 if c < '0' || c > '9' { 25 if startsNumeric && len(s) > i && s[i] == 'e' { 26 return fmt.Errorf("MariaDB identifier looks like floating point: %q", s) 27 } 28 allNumeric = false 29 break 30 } 31 startsNumeric = true 32 } 33 if allNumeric { 34 return fmt.Errorf("MariaDB identifier contains only numerals: %q", s) 35 } 36 return nil 37 } 38 39 // NewMappedSelector returns an object which can be used to automagically query 40 // the provided type-mapped database for rows of the parameterized type. 41 func NewMappedSelector[T any](executor MappedExecutor) (MappedSelector[T], error) { 42 var throwaway T 43 t := reflect.TypeOf(throwaway) 44 45 // We use a very strict mapping of struct fields to table columns here: 46 // - The struct must not have any embedded structs, only named fields. 47 // - The struct field names must be case-insensitively identical to the 48 // column names (no struct tags necessary). 49 // - The struct field names must be case-insensitively unique. 50 // - Every field of the struct must correspond to a database column. 51 // - Note that the reverse is not true: it's perfectly okay for there to be 52 // database columns which do not correspond to fields in the struct; those 53 // columns will be ignored. 
54 // TODO: In the future, when we replace borp's TableMap with our own, this 55 // check should be performed at the time the mapping is declared. 56 columns := make([]string, 0) 57 seen := make(map[string]struct{}) 58 for i := range t.NumField() { 59 field := t.Field(i) 60 if field.Anonymous { 61 return nil, fmt.Errorf("struct contains anonymous embedded struct %q", field.Name) 62 } 63 column := strings.ToLower(t.Field(i).Name) 64 err := validMariaDBUnquotedIdentifier(column) 65 if err != nil { 66 return nil, fmt.Errorf("struct field maps to unsafe db column name %q", column) 67 } 68 if _, found := seen[column]; found { 69 return nil, fmt.Errorf("struct fields map to duplicate column name %q", column) 70 } 71 seen[column] = struct{}{} 72 columns = append(columns, column) 73 } 74 75 return &mappedSelector[T]{wrapped: executor, columns: columns}, nil 76 } 77 78 type mappedSelector[T any] struct { 79 wrapped MappedExecutor 80 columns []string 81 } 82 83 // QueryContext performs a SELECT on the appropriate table for T. It combines the best 84 // features of borp, the go stdlib, and generics, using the type parameter of 85 // the typeSelector object to automatically look up the proper table name and 86 // columns to select. It returns an iterable which yields fully-populated 87 // objects of the parameterized type directly. The given clauses MUST be only 88 // the bits of a sql query from "WHERE ..." onwards; if they contain any of the 89 // "SELECT ... FROM ..." portion of the query it will result in an error. The 90 // args take the same kinds of values as borp's SELECT: either one argument per 91 // positional placeholder, or a map of placeholder names to their arguments 92 // (see https://pkg.go.dev/github.com/letsencrypt/borp#readme-ad-hoc-sql). 93 // 94 // The caller is responsible for calling `Rows.Close()` when they are done with 95 // the query. 
The caller is also responsible for ensuring that the clauses 96 // argument does not contain any user-influenced input. 97 func (ts mappedSelector[T]) QueryContext(ctx context.Context, clauses string, args ...any) (Rows[T], error) { 98 // Look up the table to use based on the type of this TypeSelector. 99 var throwaway T 100 tableMap, err := ts.wrapped.TableFor(reflect.TypeOf(throwaway), false) 101 if err != nil { 102 return nil, fmt.Errorf("database model type not mapped to table name: %w", err) 103 } 104 105 return ts.QueryFrom(ctx, tableMap.TableName, clauses, args...) 106 } 107 108 // QueryFrom is the same as Query, but it additionally takes a table name to 109 // select from, rather than automatically computing the table name from borp's 110 // DbMap. 111 // 112 // The caller is responsible for calling `Rows.Close()` when they are done with 113 // the query. The caller is also responsible for ensuring that the clauses 114 // argument does not contain any user-influenced input. 115 func (ts mappedSelector[T]) QueryFrom(ctx context.Context, tablename string, clauses string, args ...any) (Rows[T], error) { 116 err := validMariaDBUnquotedIdentifier(tablename) 117 if err != nil { 118 return nil, err 119 } 120 121 // Construct the query from the column names, table name, and given clauses. 122 // Note that the column names here are in the order given by 123 query := fmt.Sprintf( 124 "SELECT %s FROM %s %s", 125 strings.Join(ts.columns, ", "), 126 tablename, 127 clauses, 128 ) 129 130 r, err := ts.wrapped.QueryContext(ctx, query, args...) 131 if err != nil { 132 return nil, fmt.Errorf("reading db: %w", err) 133 } 134 135 return &rows[T]{wrapped: r, numCols: len(ts.columns)}, nil 136 } 137 138 // rows is a wrapper around the stdlib's sql.rows, but with a more 139 // type-safe method to get actual row content. 
type rows[T any] struct {
	// wrapped is the underlying stdlib result set being read.
	wrapped *sql.Rows
	// numCols is how many leading fields of T are filled from each row; it is
	// set to the number of columns in the SELECT that produced wrapped.
	numCols int
}

// ForEach invokes do once per row, passing a freshly populated *T obtained via
// Get. Iteration stops at the first error from Get, from do, or from the
// underlying reader. The reader is closed on every exit path, so ForEach may
// only be called once. This is the intended way to consume the result of
// QueryContext or QueryFrom; the remaining methods are lower-level and meant
// for advanced use only.
func (r rows[T]) ForEach(do func(*T) error) (err error) {
	// Release the reader no matter how we leave. A close failure is merged
	// into the error already being returned, via the named result.
	defer func() {
		closeErr := r.Close()
		switch {
		case closeErr == nil:
			// Nothing to add; err already holds the right result.
		case err == nil:
			err = closeErr
		default:
			err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
		}
	}()

	for r.Next() {
		item, getErr := r.Get()
		if getErr != nil {
			return fmt.Errorf("reading row: %w", getErr)
		}
		if doErr := do(item); doErr != nil {
			return doErr
		}
	}

	if iterErr := r.Err(); iterErr != nil {
		return fmt.Errorf("iterating over row reader: %w", iterErr)
	}
	return nil
}

// Next advances to the next row by delegating to sql.Rows.Next. It must be
// called before every call to Get, including the first.
func (r rows[T]) Next() bool {
	return r.wrapped.Next()
}

// Get scans the current row into a newly allocated T and returns it. Unlike
// sql.Rows.Scan, the caller does not supply destinations; the first numCols
// fields of T, in declaration order, receive the row's columns.
func (r rows[T]) Get() (*T, error) {
	result := new(T)
	elem := reflect.ValueOf(result).Elem()

	// sql.Rows.Scan is variadic, so collect a pointer to each target field
	// and splat the slice into the call.
	targets := make([]any, 0, r.numCols)
	for idx := 0; idx < r.numCols; idx++ {
		targets = append(targets, elem.Field(idx).Addr().Interface())
	}

	if scanErr := r.wrapped.Scan(targets...); scanErr != nil {
		return nil, fmt.Errorf("reading db row: %w", scanErr)
	}
	return result, nil
}

// Err reports any error encountered during iteration, by delegating to
// sql.Rows.Err. Check it as soon as Next returns false.
func (r rows[T]) Err() error {
	return r.wrapped.Err()
}

// Close releases the underlying result set by delegating to sql.Rows.Close.
// Callers must invoke it once they finish reading, whether or not an error
// occurred.
func (r rows[T]) Close() error {
	return r.wrapped.Close()
}