go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/query.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"database/sql"
    12  	"errors"
    13  	"reflect"
    14  
    15  	"go.charczuk.com/sdk/errutil"
    16  )
    17  
    18  // Query is the intermediate result of a query.
    19  type Query struct {
    20  	inv  *Invocation
    21  	err  error
    22  	stmt string
    23  	args []any
    24  }
    25  
    26  // Do runs a given query, yielding the raw results.
    27  func (q *Query) Do() (rows *sql.Rows, err error) {
    28  	defer func() {
    29  		err = q.finish(recover(), nil, err)
    30  	}()
    31  	rows, err = q.query()
    32  	return
    33  }
    34  
    35  // Any returns if there are any results for the query.
    36  func (q *Query) Any() (found bool, err error) {
    37  	var rows *sql.Rows
    38  	defer func() {
    39  		err = q.finish(recover(), nil, err)
    40  		err = q.rowsClose(rows, err)
    41  	}()
    42  	rows, err = q.query()
    43  	if err != nil {
    44  		return
    45  	}
    46  	found = rows.Next()
    47  	return
    48  }
    49  
    50  // None returns if there are no results for the query.
    51  func (q *Query) None() (notFound bool, err error) {
    52  	var rows *sql.Rows
    53  	defer func() {
    54  		err = q.finish(recover(), nil, err)
    55  		err = q.rowsClose(rows, err)
    56  	}()
    57  	rows, err = q.query()
    58  	if err != nil {
    59  		return
    60  	}
    61  	notFound = !rows.Next()
    62  	return
    63  }
    64  
    65  // Scan writes the results to a given set of local variables.
    66  // It returns if the query produced a row, and returns `ErrTooManyRows` if there
    67  // are multiple row results.
    68  func (q *Query) Scan(args ...any) (found bool, err error) {
    69  	var rows *sql.Rows
    70  	defer func() {
    71  		err = q.finish(recover(), nil, err)
    72  		err = q.rowsClose(rows, err)
    73  	}()
    74  
    75  	rows, err = q.query()
    76  	if err != nil {
    77  		return
    78  	}
    79  	found, err = Scan(rows, args...)
    80  	return
    81  }
    82  
    83  // Out writes the query result to a single object via. reflection mapping. If there is more than one result, the first
    84  // result is mapped to to object, and ErrTooManyRows is returned. Out() will apply column values for any colums
    85  // in the row result to the object, potentially zeroing existing values out.
    86  func (q *Query) Out(object any) (found bool, err error) {
    87  	var rows *sql.Rows
    88  	defer func() {
    89  		err = q.finish(recover(), nil, err)
    90  		err = q.rowsClose(rows, err)
    91  	}()
    92  
    93  	rows, err = q.query()
    94  	if err != nil {
    95  		return
    96  	}
    97  	sliceType := reflectType(object)
    98  	if sliceType.Kind() != reflect.Struct {
    99  		err = ErrDestinationNotStruct
   100  		return
   101  	}
   102  	columnMeta := q.inv.conn.TypeMeta(object)
   103  	if rows.Next() {
   104  		found = true
   105  		if populatable, ok := object.(Populatable); ok {
   106  			err = populatable.Populate(rows)
   107  		} else {
   108  			err = PopulateByName(object, rows, columnMeta)
   109  		}
   110  		if err != nil {
   111  			return
   112  		}
   113  	} else if err = Zero(object); err != nil {
   114  		return
   115  	}
   116  	if rows.Next() {
   117  		err = ErrTooManyRows
   118  	}
   119  	return
   120  }
   121  
   122  // OutMany writes the query results to a slice of objects.
   123  func (q *Query) OutMany(collection any) (err error) {
   124  	var rows *sql.Rows
   125  	defer func() {
   126  		// err = q.finish(nil, nil, err)
   127  		err = q.finish(recover(), nil, err)
   128  		err = q.rowsClose(rows, err)
   129  	}()
   130  
   131  	rows, err = q.query()
   132  	if err != nil {
   133  		return
   134  	}
   135  
   136  	sliceType := reflectType(collection)
   137  	if sliceType.Kind() != reflect.Slice {
   138  		err = ErrCollectionNotSlice
   139  		return
   140  	}
   141  
   142  	sliceInnerType := reflectSliceType(collection)
   143  	collectionValue := reflectValue(collection)
   144  	v := makeNew(sliceInnerType)
   145  	isStruct := sliceInnerType.Kind() == reflect.Struct
   146  	var meta *TypeMeta
   147  	if isStruct {
   148  		meta = q.inv.conn.TypeMetaFromType(newColumnCacheKey(sliceInnerType), sliceInnerType)
   149  	}
   150  
   151  	isPopulatable := IsPopulatable(v)
   152  
   153  	var didSetRows bool
   154  	for rows.Next() {
   155  		newObj := makeNew(sliceInnerType)
   156  		if isPopulatable {
   157  			err = newObj.(Populatable).Populate(rows)
   158  		} else if isStruct {
   159  			err = PopulateByName(newObj, rows, meta)
   160  		} else {
   161  			err = rows.Scan(newObj)
   162  		}
   163  		if err != nil {
   164  			return
   165  		}
   166  
   167  		newObjValue := reflectValue(newObj)
   168  		collectionValue.Set(reflect.Append(collectionValue, newObjValue))
   169  		didSetRows = true
   170  	}
   171  
   172  	// this initializes the slice if we didn't add elements to it.
   173  	if !didSetRows {
   174  		collectionValue.Set(reflect.MakeSlice(sliceType, 0, 0))
   175  	}
   176  	return
   177  }
   178  
   179  // Each executes the consumer for each result of the query (one to many).
   180  func (q *Query) Each(consumer RowsConsumer) (err error) {
   181  	var rows *sql.Rows
   182  	defer func() {
   183  		err = q.finish(recover(), nil, err)
   184  		err = q.rowsClose(rows, err)
   185  	}()
   186  
   187  	rows, err = q.query()
   188  	if err != nil {
   189  		return
   190  	}
   191  
   192  	err = Each(rows, consumer)
   193  	return
   194  }
   195  
   196  // First executes the consumer for the first result of a query.
   197  // It returns `ErrTooManyRows` if more than one result is returned.
   198  func (q *Query) First(consumer RowsConsumer) (found bool, err error) {
   199  	var rows *sql.Rows
   200  	defer func() {
   201  		err = q.finish(recover(), nil, err)
   202  		err = q.rowsClose(rows, err)
   203  	}()
   204  	rows, err = q.query()
   205  	if err != nil {
   206  		return
   207  	}
   208  	found, err = First(rows, consumer)
   209  	return
   210  }
   211  
   212  // --------------------------------------------------------------------------------
   213  // helpers
   214  // --------------------------------------------------------------------------------
   215  
   216  func (q *Query) rowsClose(rows *sql.Rows, err error) error {
   217  	if rows == nil {
   218  		return err
   219  	}
   220  	if closeErr := rows.Close(); closeErr != nil {
   221  		return errutil.Append(err, closeErr)
   222  	}
   223  	return err
   224  }
   225  
   226  func (q *Query) query() (rows *sql.Rows, err error) {
   227  	if q.err != nil {
   228  		err = q.err
   229  		return
   230  	}
   231  
   232  	var queryError error
   233  	dbc := q.inv.db
   234  	ctx := q.inv.ctx
   235  	rows, queryError = dbc.QueryContext(ctx, q.stmt, q.args...)
   236  	if queryError != nil && !errors.Is(queryError, sql.ErrNoRows) {
   237  		err = queryError
   238  	}
   239  	return
   240  }
   241  
   242  func (q *Query) finish(r any, res sql.Result, err error) error {
   243  	return q.inv.finish(q.stmt, r, res, err)
   244  }
   245  
   246  // Each iterates over a given result set, calling the rows consumer.
   247  func Each(rows *sql.Rows, consumer RowsConsumer) (err error) {
   248  	for rows.Next() {
   249  		if err = consumer(rows); err != nil {
   250  			return
   251  		}
   252  	}
   253  	return
   254  }
   255  
   256  // First returns the first result of a result set to a consumer.
   257  // If there are more than one row in the result, they are ignored.
   258  func First(rows *sql.Rows, consumer RowsConsumer) (found bool, err error) {
   259  	if found = rows.Next(); found {
   260  		if err = consumer(rows); err != nil {
   261  			return
   262  		}
   263  	}
   264  	return
   265  }
   266  
   267  // Scan reads the first row from a resultset and scans it to a given set of args.
   268  // If more than one row is returned it will return ErrTooManyRows.
   269  func Scan(rows *sql.Rows, args ...any) (found bool, err error) {
   270  	if rows.Next() {
   271  		found = true
   272  		if err = rows.Scan(args...); err != nil {
   273  			return
   274  		}
   275  	}
   276  	if rows.Next() {
   277  		err = ErrTooManyRows
   278  	}
   279  	return
   280  }