github.com/blend/go-sdk@v1.20220411.3/db/query.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"database/sql"
    12  	"reflect"
    13  
    14  	"github.com/blend/go-sdk/ex"
    15  )
    16  
    17  // --------------------------------------------------------------------------------
    18  // Query Result
    19  // --------------------------------------------------------------------------------
    20  
// Query is the intermediate result of a query: the statement and arguments
// to execute, the invocation that executes them, and any error captured
// while the query was being built.
type Query struct {
	// Invocation supplies the DB handle and context used by QueryContext, and
	// its finish hook is called with the outcome of every execution method.
	Invocation *Invocation
	// Statement is the SQL text passed to QueryContext.
	Statement  string
	// Err, when non-nil, short-circuits execution: the statement is never
	// sent to the database and this error is returned instead.
	Err        error
	// Args are passed to QueryContext alongside Statement.
	Args       []interface{}
}
    28  
    29  // Do runs a given query, yielding the raw results.
    30  func (q *Query) Do() (rows *sql.Rows, err error) {
    31  	defer func() {
    32  		err = q.finish(recover(), err)
    33  	}()
    34  	rows, err = q.query()
    35  	return
    36  }
    37  
    38  // Any returns if there are any results for the query.
    39  func (q *Query) Any() (found bool, err error) {
    40  	var rows *sql.Rows
    41  	defer func() {
    42  		err = q.finish(recover(), err)
    43  		err = q.rowsClose(rows, err)
    44  	}()
    45  	rows, err = q.query()
    46  	if err != nil {
    47  		return
    48  	}
    49  	found = rows.Next()
    50  	return
    51  }
    52  
    53  // None returns if there are no results for the query.
    54  func (q *Query) None() (notFound bool, err error) {
    55  	var rows *sql.Rows
    56  	defer func() {
    57  		err = q.finish(recover(), err)
    58  		err = q.rowsClose(rows, err)
    59  	}()
    60  	rows, err = q.query()
    61  	if err != nil {
    62  		return
    63  	}
    64  	notFound = !rows.Next()
    65  	return
    66  }
    67  
    68  // Scan writes the results to a given set of local variables.
    69  // It returns if the query produced a row, and returns `ErrTooManyRows` if there
    70  // are multiple row results.
    71  func (q *Query) Scan(args ...interface{}) (found bool, err error) {
    72  	var rows *sql.Rows
    73  	defer func() {
    74  		err = q.finish(recover(), err)
    75  		err = q.rowsClose(rows, err)
    76  	}()
    77  
    78  	rows, err = q.query()
    79  	if err != nil {
    80  		return
    81  	}
    82  	found, err = Scan(rows, args...)
    83  	return
    84  }
    85  
    86  // Out writes the query result to a single object via. reflection mapping. If there is more than one result, the first
    87  // result is mapped to to object, and ErrTooManyRows is returned. Out() will apply column values for any colums
    88  // in the row result to the object, potentially zeroing existing values out.
    89  func (q *Query) Out(object interface{}) (found bool, err error) {
    90  	var rows *sql.Rows
    91  	defer func() {
    92  		err = q.finish(recover(), err)
    93  		err = q.rowsClose(rows, err)
    94  	}()
    95  
    96  	rows, err = q.query()
    97  	if err != nil {
    98  		return
    99  	}
   100  	found, err = Out(rows, object)
   101  	return
   102  }
   103  
   104  // OutMany writes the query results to a slice of objects.
   105  func (q *Query) OutMany(collection interface{}) (err error) {
   106  	var rows *sql.Rows
   107  	defer func() {
   108  		err = q.finish(recover(), err)
   109  		err = q.rowsClose(rows, err)
   110  	}()
   111  
   112  	rows, err = q.query()
   113  	if err != nil {
   114  		return
   115  	}
   116  	err = OutMany(rows, collection)
   117  	return
   118  }
   119  
   120  // Each executes the consumer for each result of the query (one to many).
   121  func (q *Query) Each(consumer RowsConsumer) (err error) {
   122  	var rows *sql.Rows
   123  	defer func() {
   124  		err = q.finish(recover(), err)
   125  		err = q.rowsClose(rows, err)
   126  	}()
   127  
   128  	rows, err = q.query()
   129  	if err != nil {
   130  		return
   131  	}
   132  
   133  	err = Each(rows, consumer)
   134  	return
   135  }
   136  
   137  // First executes the consumer for the first result of a query.
   138  // It returns `ErrTooManyRows` if more than one result is returned.
   139  func (q *Query) First(consumer RowsConsumer) (found bool, err error) {
   140  	var rows *sql.Rows
   141  	defer func() {
   142  		err = q.finish(recover(), err)
   143  		err = q.rowsClose(rows, err)
   144  	}()
   145  	rows, err = q.query()
   146  	if err != nil {
   147  		return
   148  	}
   149  	found, err = First(rows, consumer)
   150  	return
   151  }
   152  
   153  // --------------------------------------------------------------------------------
   154  // helpers
   155  // --------------------------------------------------------------------------------
   156  
   157  func (q *Query) rowsClose(rows *sql.Rows, err error) error {
   158  	if rows == nil {
   159  		return err
   160  	}
   161  	return ex.Nest(err, rows.Close())
   162  }
   163  
   164  func (q *Query) query() (rows *sql.Rows, err error) {
   165  	// fast abort if there was an issue ahead of returning the query.
   166  	if q.Err != nil {
   167  		err = q.Err
   168  		return
   169  	}
   170  
   171  	var queryError error
   172  	db := q.Invocation.DB
   173  	ctx := q.Invocation.Context
   174  	rows, queryError = db.QueryContext(ctx, q.Statement, q.Args...)
   175  	if queryError != nil && !ex.Is(queryError, sql.ErrNoRows) {
   176  		err = Error(queryError)
   177  	}
   178  	return
   179  }
   180  
   181  func (q *Query) finish(r interface{}, err error) error {
   182  	return q.Invocation.finish(q.Statement, r, nil, err)
   183  }
   184  
   185  // Out reads a given rows set out into an object reference.
   186  func Out(rows *sql.Rows, object interface{}) (found bool, err error) {
   187  	sliceType := ReflectType(object)
   188  	if sliceType.Kind() != reflect.Struct {
   189  		err = Error(ErrDestinationNotStruct)
   190  		return
   191  	}
   192  	columnMeta := Columns(object)
   193  	if rows.Next() {
   194  		found = true
   195  		if populatable, ok := object.(Populatable); ok {
   196  			err = populatable.Populate(rows)
   197  		} else {
   198  			err = PopulateByName(object, rows, columnMeta)
   199  		}
   200  		if err != nil {
   201  			return
   202  		}
   203  	} else if err = Zero(object); err != nil {
   204  		return
   205  	}
   206  	if rows.Next() {
   207  		err = Error(ErrTooManyRows)
   208  	}
   209  	return
   210  }
   211  
   212  // OutMany reads a given result set into a given collection.
   213  func OutMany(rows *sql.Rows, collection interface{}) (err error) {
   214  	sliceType := ReflectType(collection)
   215  	if sliceType.Kind() != reflect.Slice {
   216  		err = Error(ErrCollectionNotSlice)
   217  		return
   218  	}
   219  
   220  	sliceInnerType := ReflectSliceType(collection)
   221  	collectionValue := ReflectValue(collection)
   222  	v := makeNew(sliceInnerType)
   223  	meta := ColumnsFromType(newColumnCacheKey(sliceInnerType), sliceInnerType)
   224  
   225  	isPopulatable := IsPopulatable(v)
   226  
   227  	var didSetRows bool
   228  	for rows.Next() {
   229  		newObj := makeNew(sliceInnerType)
   230  		if isPopulatable {
   231  			err = AsPopulatable(newObj).Populate(rows)
   232  		} else {
   233  			err = PopulateByName(newObj, rows, meta)
   234  		}
   235  		if err != nil {
   236  			return
   237  		}
   238  
   239  		newObjValue := ReflectValue(newObj)
   240  		collectionValue.Set(reflect.Append(collectionValue, newObjValue))
   241  		didSetRows = true
   242  	}
   243  
   244  	// this initializes the slice if we didn't add elements to it.
   245  	if !didSetRows {
   246  		collectionValue.Set(reflect.MakeSlice(sliceType, 0, 0))
   247  	}
   248  	return
   249  }
   250  
   251  // Each iterates over a given result set, calling the rows consumer.
   252  func Each(rows *sql.Rows, consumer RowsConsumer) (err error) {
   253  	for rows.Next() {
   254  		if err = consumer(rows); err != nil {
   255  			err = Error(err)
   256  			return
   257  		}
   258  	}
   259  	return
   260  }
   261  
   262  // First returns the first result of a result set to a consumer.
   263  // If there are more than one row in the result, they are ignored.
   264  func First(rows *sql.Rows, consumer RowsConsumer) (found bool, err error) {
   265  	if found = rows.Next(); found {
   266  		if err = consumer(rows); err != nil {
   267  			return
   268  		}
   269  	}
   270  	return
   271  }
   272  
   273  // Scan reads the first row from a resultset and scans it to a given set of args.
   274  // If more than one row is returned it will return ErrTooManyRows.
   275  func Scan(rows *sql.Rows, args ...interface{}) (found bool, err error) {
   276  	if rows.Next() {
   277  		found = true
   278  		if err = rows.Scan(args...); err != nil {
   279  			err = Error(err)
   280  			return
   281  		}
   282  	}
   283  	if rows.Next() {
   284  		err = Error(ErrTooManyRows)
   285  	}
   286  	return
   287  }