github.com/letsencrypt/boulder@v0.20251208.0/db/gorm.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strings"
    10  )
    11  
    12  // Characters allowed in an unquoted identifier by MariaDB.
    13  // https://mariadb.com/kb/en/identifier-names/#unquoted
    14  var mariaDBUnquotedIdentifierRE = regexp.MustCompile("^[0-9a-zA-Z$_]+$")
    15  
    16  func validMariaDBUnquotedIdentifier(s string) error {
    17  	if !mariaDBUnquotedIdentifierRE.MatchString(s) {
    18  		return fmt.Errorf("invalid MariaDB identifier %q", s)
    19  	}
    20  
    21  	allNumeric := true
    22  	startsNumeric := false
    23  	for i, c := range []byte(s) {
    24  		if c < '0' || c > '9' {
    25  			if startsNumeric && len(s) > i && s[i] == 'e' {
    26  				return fmt.Errorf("MariaDB identifier looks like floating point: %q", s)
    27  			}
    28  			allNumeric = false
    29  			break
    30  		}
    31  		startsNumeric = true
    32  	}
    33  	if allNumeric {
    34  		return fmt.Errorf("MariaDB identifier contains only numerals: %q", s)
    35  	}
    36  	return nil
    37  }
    38  
    39  // NewMappedSelector returns an object which can be used to automagically query
    40  // the provided type-mapped database for rows of the parameterized type.
    41  func NewMappedSelector[T any](executor MappedExecutor) (MappedSelector[T], error) {
    42  	var throwaway T
    43  	t := reflect.TypeOf(throwaway)
    44  
    45  	// We use a very strict mapping of struct fields to table columns here:
    46  	// - The struct must not have any embedded structs, only named fields.
    47  	// - The struct field names must be case-insensitively identical to the
    48  	//   column names (no struct tags necessary).
    49  	// - The struct field names must be case-insensitively unique.
    50  	// - Every field of the struct must correspond to a database column.
    51  	//   - Note that the reverse is not true: it's perfectly okay for there to be
    52  	//     database columns which do not correspond to fields in the struct; those
    53  	//     columns will be ignored.
    54  	// TODO: In the future, when we replace borp's TableMap with our own, this
    55  	// check should be performed at the time the mapping is declared.
    56  	columns := make([]string, 0)
    57  	seen := make(map[string]struct{})
    58  	for i := range t.NumField() {
    59  		field := t.Field(i)
    60  		if field.Anonymous {
    61  			return nil, fmt.Errorf("struct contains anonymous embedded struct %q", field.Name)
    62  		}
    63  		column := strings.ToLower(t.Field(i).Name)
    64  		err := validMariaDBUnquotedIdentifier(column)
    65  		if err != nil {
    66  			return nil, fmt.Errorf("struct field maps to unsafe db column name %q", column)
    67  		}
    68  		if _, found := seen[column]; found {
    69  			return nil, fmt.Errorf("struct fields map to duplicate column name %q", column)
    70  		}
    71  		seen[column] = struct{}{}
    72  		columns = append(columns, column)
    73  	}
    74  
    75  	return &mappedSelector[T]{wrapped: executor, columns: columns}, nil
    76  }
    77  
// mappedSelector is the MappedSelector[T] implementation returned by
// NewMappedSelector.
type mappedSelector[T any] struct {
	// wrapped is the executor used both to look up T's table mapping and to
	// run the generated queries.
	wrapped MappedExecutor
	// columns holds the lowercased field names of T, in struct field
	// declaration order. They are selected in this order, and rows.Get scans
	// the results back into T's fields by the same positional index.
	columns []string
}
    82  
    83  // QueryContext performs a SELECT on the appropriate table for T. It combines the best
    84  // features of borp, the go stdlib, and generics, using the type parameter of
    85  // the typeSelector object to automatically look up the proper table name and
    86  // columns to select. It returns an iterable which yields fully-populated
    87  // objects of the parameterized type directly. The given clauses MUST be only
    88  // the bits of a sql query from "WHERE ..." onwards; if they contain any of the
    89  // "SELECT ... FROM ..." portion of the query it will result in an error. The
    90  // args take the same kinds of values as borp's SELECT: either one argument per
    91  // positional placeholder, or a map of placeholder names to their arguments
    92  // (see https://pkg.go.dev/github.com/letsencrypt/borp#readme-ad-hoc-sql).
    93  //
    94  // The caller is responsible for calling `Rows.Close()` when they are done with
    95  // the query. The caller is also responsible for ensuring that the clauses
    96  // argument does not contain any user-influenced input.
    97  func (ts mappedSelector[T]) QueryContext(ctx context.Context, clauses string, args ...any) (Rows[T], error) {
    98  	// Look up the table to use based on the type of this TypeSelector.
    99  	var throwaway T
   100  	tableMap, err := ts.wrapped.TableFor(reflect.TypeOf(throwaway), false)
   101  	if err != nil {
   102  		return nil, fmt.Errorf("database model type not mapped to table name: %w", err)
   103  	}
   104  
   105  	return ts.QueryFrom(ctx, tableMap.TableName, clauses, args...)
   106  }
   107  
   108  // QueryFrom is the same as Query, but it additionally takes a table name to
   109  // select from, rather than automatically computing the table name from borp's
   110  // DbMap.
   111  //
   112  // The caller is responsible for calling `Rows.Close()` when they are done with
   113  // the query. The caller is also responsible for ensuring that the clauses
   114  // argument does not contain any user-influenced input.
   115  func (ts mappedSelector[T]) QueryFrom(ctx context.Context, tablename string, clauses string, args ...any) (Rows[T], error) {
   116  	err := validMariaDBUnquotedIdentifier(tablename)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	// Construct the query from the column names, table name, and given clauses.
   122  	// Note that the column names here are in the order given by
   123  	query := fmt.Sprintf(
   124  		"SELECT %s FROM %s %s",
   125  		strings.Join(ts.columns, ", "),
   126  		tablename,
   127  		clauses,
   128  	)
   129  
   130  	r, err := ts.wrapped.QueryContext(ctx, query, args...)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("reading db: %w", err)
   133  	}
   134  
   135  	return &rows[T]{wrapped: r, numCols: len(ts.columns)}, nil
   136  }
   137  
// rows is a wrapper around the stdlib's sql.Rows, but with a more
// type-safe method to get actual row content.
type rows[T any] struct {
	// wrapped is the underlying result set returned by the executor.
	wrapped *sql.Rows
	// numCols is the number of selected columns; Get scans them, in order,
	// into the first numCols fields of T.
	numCols int
}
   144  
   145  // ForEach calls the given function with each model object retrieved by
   146  // repeatedly calling .Get(). It closes the rows object when it hits an error
   147  // or finishes iterating over the rows, so it can only be called once. This is
   148  // the intended way to use the result of QueryContext or QueryFrom; the other
   149  // methods on this type are lower-level and intended for advanced use only.
   150  func (r rows[T]) ForEach(do func(*T) error) (err error) {
   151  	defer func() {
   152  		// Close the row reader when we exit. Use the named error return to combine
   153  		// any error from normal execution with any error from closing.
   154  		closeErr := r.Close()
   155  		if closeErr != nil && err != nil {
   156  			err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
   157  		} else if closeErr != nil {
   158  			err = closeErr
   159  		}
   160  		// If closeErr is nil, then just leaving the existing named return alone
   161  		// will do the right thing.
   162  	}()
   163  
   164  	for r.Next() {
   165  		row, err := r.Get()
   166  		if err != nil {
   167  			return fmt.Errorf("reading row: %w", err)
   168  		}
   169  
   170  		err = do(row)
   171  		if err != nil {
   172  			return err
   173  		}
   174  	}
   175  
   176  	err = r.Err()
   177  	if err != nil {
   178  		return fmt.Errorf("iterating over row reader: %w", err)
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  // Next is a wrapper around sql.Rows.Next(). It must be called before every call
   185  // to Get(), including the first.
   186  func (r rows[T]) Next() bool {
   187  	return r.wrapped.Next()
   188  }
   189  
   190  // Get is a wrapper around sql.Rows.Scan(). Rather than populating an arbitrary
   191  // number of &interface{} arguments, it returns a populated object of the
   192  // parameterized type.
   193  func (r rows[T]) Get() (*T, error) {
   194  	result := new(T)
   195  	v := reflect.ValueOf(result)
   196  
   197  	// Because sql.Rows.Scan(...) takes a variadic number of individual targets to
   198  	// read values into, build a slice that can be splatted into the call. Use the
   199  	// pre-computed list of in-order column names to populate it.
   200  	scanTargets := make([]any, r.numCols)
   201  	for i := range scanTargets {
   202  		field := v.Elem().Field(i)
   203  		scanTargets[i] = field.Addr().Interface()
   204  	}
   205  
   206  	err := r.wrapped.Scan(scanTargets...)
   207  	if err != nil {
   208  		return nil, fmt.Errorf("reading db row: %w", err)
   209  	}
   210  
   211  	return result, nil
   212  }
   213  
   214  // Err is a wrapper around sql.Rows.Err(). It should be checked immediately
   215  // after Next() returns false for any reason.
   216  func (r rows[T]) Err() error {
   217  	return r.wrapped.Err()
   218  }
   219  
   220  // Close is a wrapper around sql.Rows.Close(). It must be called when the caller
   221  // is done reading rows, regardless of success or error.
   222  func (r rows[T]) Close() error {
   223  	return r.wrapped.Close()
   224  }