github.com/RevenueMonster/sqlike@v1.0.6/sqlike/find.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  
     7  	"github.com/RevenueMonster/sqlike/reflext"
     8  	"github.com/RevenueMonster/sqlike/sql/codec"
     9  	sqldialect "github.com/RevenueMonster/sqlike/sql/dialect"
    10  	sqldriver "github.com/RevenueMonster/sqlike/sql/driver"
    11  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    12  	"github.com/RevenueMonster/sqlike/sqlike/actions"
    13  	"github.com/RevenueMonster/sqlike/sqlike/logs"
    14  	"github.com/RevenueMonster/sqlike/sqlike/options"
    15  	"github.com/RevenueMonster/sqlike/sqlike/primitive"
    16  )
    17  
    18  // SingleResult : single result is an interface implementing apis as similar as driver.Result
    19  type SingleResult interface {
    20  	Scan(dest ...interface{}) error
    21  	Decode(dest interface{}) error
    22  	Columns() []string
    23  	ColumnTypes() ([]*sql.ColumnType, error)
    24  	Error() error
    25  }
    26  
    27  // FindOne : find single record on the table, you should alway check the return error to ensure it have result return.
    28  func (tb *Table) FindOne(ctx context.Context, act actions.SelectOneStatement, opts ...*options.FindOneOptions) SingleResult {
    29  	x := new(actions.FindOneActions)
    30  	if act != nil {
    31  		*x = *(act.(*actions.FindOneActions))
    32  	}
    33  	opt := new(options.FindOneOptions)
    34  	if len(opts) > 0 && opts[0] != nil {
    35  		opt = opts[0]
    36  	}
    37  	x.Limit(1)
    38  	rslt := find(
    39  		ctx,
    40  		tb.dbName,
    41  		tb.name,
    42  		tb.client.cache,
    43  		tb.codec,
    44  		tb.driver,
    45  		tb.dialect,
    46  		tb.logger,
    47  		&x.FindActions,
    48  		&opt.FindOptions,
    49  		opt.FindOptions.LockMode,
    50  	)
    51  	rslt.close = true
    52  	if rslt.err != nil {
    53  		return rslt
    54  	}
    55  	if !rslt.Next() {
    56  		rslt.err = sql.ErrNoRows
    57  	}
    58  	return rslt
    59  }
    60  
    61  // Find : find multiple records on the table.
    62  func (tb *Table) Find(ctx context.Context, act actions.SelectStatement, opts ...*options.FindOptions) (*Result, error) {
    63  	x := new(actions.FindActions)
    64  	if act != nil {
    65  		*x = *(act.(*actions.FindActions))
    66  	}
    67  	opt := new(options.FindOptions)
    68  	if len(opts) > 0 && opts[0] != nil {
    69  		opt = opts[0]
    70  	}
    71  	// has limit and limit value is zero
    72  	if !opt.NoLimit && x.Count < 1 {
    73  		x.Limit(100)
    74  	}
    75  	csr := find(
    76  		ctx,
    77  		tb.dbName,
    78  		tb.name,
    79  		tb.client.cache,
    80  		tb.codec,
    81  		tb.driver,
    82  		tb.dialect,
    83  		tb.logger,
    84  		x,
    85  		opt,
    86  		opt.LockMode,
    87  	)
    88  	if csr.err != nil {
    89  		return nil, csr.err
    90  	}
    91  	return csr, nil
    92  }
    93  
    94  func find(ctx context.Context, dbName, tbName string, cache reflext.StructMapper, cdc codec.Codecer, driver sqldriver.Driver, dialect sqldialect.Dialect, logger logs.Logger, act *actions.FindActions, opt *options.FindOptions, lock options.LockMode) *Result {
    95  	if act.Database == "" {
    96  		act.Database = dbName
    97  	}
    98  	if act.Table == "" {
    99  		act.Table = tbName
   100  	}
   101  
   102  	groups := extractResolution(ctx)
   103  	if len(groups) > 0 {
   104  		if len(act.Conditions.Values) > 0 {
   105  			act.Conditions.Values = append(act.Conditions.Values, primitive.And)
   106  		}
   107  
   108  		for _, group := range groups {
   109  			act.Conditions.Values = append(act.Conditions.Values, group.Values...)
   110  		}
   111  	}
   112  
   113  	rslt := new(Result)
   114  	rslt.cache = cache
   115  	rslt.codec = cdc
   116  
   117  	stmt := sqlstmt.AcquireStmt(dialect)
   118  	defer sqlstmt.ReleaseStmt(stmt)
   119  	if err := dialect.Select(stmt, act, lock); err != nil {
   120  		rslt.err = err
   121  		return rslt
   122  	}
   123  	rows, err := sqldriver.Query(
   124  		ctx,
   125  		driver,
   126  		stmt,
   127  		getLogger(logger, opt.Debug),
   128  	)
   129  	if err != nil {
   130  		rslt.err = err
   131  		return rslt
   132  	}
   133  	rslt.rows = rows
   134  	rslt.columnTypes, rslt.err = rows.ColumnTypes()
   135  	if rslt.err != nil {
   136  		defer rslt.rows.Close()
   137  	}
   138  	for _, col := range rslt.columnTypes {
   139  		rslt.columns = append(rslt.columns, col.Name())
   140  	}
   141  	return rslt
   142  }