github.com/RevenueMonster/sqlike@v1.0.6/sqlike/paginate.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"github.com/RevenueMonster/sqlike/reflext"
     8  	"github.com/RevenueMonster/sqlike/sql/expr"
     9  	"github.com/RevenueMonster/sqlike/sqlike/actions"
    10  	"github.com/RevenueMonster/sqlike/sqlike/options"
    11  	"github.com/RevenueMonster/sqlike/sqlike/primitive"
    12  )
    13  
    14  // ErrInvalidCursor :
    15  var ErrInvalidCursor = errors.New("sqlike: invalid cursor")
    16  
    17  // Paginate :
    18  func (tb *Table) Paginate(ctx context.Context, act actions.PaginateStatement, opts ...*options.PaginateOptions) (*Paginator, error) {
    19  	x := new(actions.PaginateActions)
    20  	if act != nil {
    21  		*x = *(act.(*actions.PaginateActions))
    22  	}
    23  	if x.Table == "" {
    24  		x.Table = tb.name
    25  	}
    26  	opt := new(options.PaginateOptions)
    27  	if len(opts) > 0 && opts[0] != nil {
    28  		opt = opts[0]
    29  	}
    30  	// sort by primary key
    31  	length := len(x.Sorts)
    32  	fields := make([]interface{}, length+1)
    33  	sort := expr.Asc(tb.pk)
    34  	if length > 0 {
    35  		x := x.Sorts[length-1].(primitive.Sort)
    36  		if x.Order == primitive.Descending {
    37  			sort = expr.Desc(tb.pk)
    38  		}
    39  	}
    40  	x.Sorts = append(x.Sorts, sort)
    41  	for i, sf := range x.Sorts {
    42  		fields[i] = sf.(primitive.Sort).Field
    43  	}
    44  	if x.Count == 1 {
    45  		return nil, errors.New("sqlike: pagination required more than 1 limit")
    46  	}
    47  	if x.Count == 0 {
    48  		x.Count = 100
    49  	}
    50  	return &Paginator{
    51  		ctx:    ctx,
    52  		table:  tb,
    53  		fields: fields,
    54  		action: x.FindActions,
    55  		option: &opt.FindOptions,
    56  	}, nil
    57  }
    58  
    59  // Paginator :
    60  type Paginator struct {
    61  	ctx    context.Context
    62  	table  *Table
    63  	fields []interface{}
    64  	values []interface{}
    65  	action actions.FindActions
    66  	option *options.FindOptions
    67  	err    error
    68  }
    69  
    70  // NextCursor :
    71  func (pg *Paginator) NextCursor(ctx context.Context, cursor interface{}) (err error) {
    72  	if pg.err != nil {
    73  		return pg.err
    74  	}
    75  	if cursor == nil || reflext.IsZero(reflext.ValueOf(cursor)) {
    76  		return ErrInvalidCursor
    77  	}
    78  	fa := actions.FindOne().Select(pg.fields...).Where(
    79  		expr.Equal(pg.table.pk, cursor),
    80  	).(*actions.FindOneActions)
    81  	fa.Limit(1)
    82  	result := find(
    83  		ctx,
    84  		pg.table.dbName,
    85  		pg.table.name,
    86  		pg.table.client.cache,
    87  		pg.table.codec,
    88  		pg.table.driver,
    89  		pg.table.dialect,
    90  		pg.table.logger,
    91  		&fa.FindActions,
    92  		&options.FindOptions{Debug: pg.option.Debug},
    93  		options.NoLock,
    94  	)
    95  	// prevent memory leak
    96  	defer result.Close()
    97  	pg.values, err = result.nextValues()
    98  	return
    99  }
   100  
   101  // All :
   102  func (pg *Paginator) All(results interface{}) error {
   103  	if pg.err != nil {
   104  		return pg.err
   105  	}
   106  	result := find(
   107  		pg.ctx,
   108  		pg.table.dbName,
   109  		pg.table.name,
   110  		pg.table.client.cache,
   111  		pg.table.codec,
   112  		pg.table.driver,
   113  		pg.table.dialect,
   114  		pg.table.logger,
   115  		pg.buildAction(),
   116  		pg.option,
   117  		options.NoLock,
   118  	)
   119  	return result.All(results)
   120  }
   121  
   122  func (pg *Paginator) buildAction() *actions.FindActions {
   123  	action := pg.action
   124  	if len(pg.values) < 1 {
   125  		return &action
   126  	}
   127  	length := len(pg.fields)
   128  	filters := make([]interface{}, 0, length)
   129  	fields := make([]interface{}, 0)
   130  	for i, sf := range action.Sorts {
   131  		var v primitive.C
   132  		val := toString(pg.values[i])
   133  		x := sf.(primitive.Sort)
   134  		if i == length-1 {
   135  			if x.Order == primitive.Ascending {
   136  				fields = append(fields, expr.GreaterOrEqual(x.Field, val))
   137  			} else {
   138  				fields = append(fields, expr.LesserOrEqual(x.Field, val))
   139  			}
   140  			continue
   141  		}
   142  		if x.Order == primitive.Ascending {
   143  			filters = append(filters, expr.GreaterOrEqual(x.Field, val))
   144  			v = expr.GreaterThan(x.Field, val)
   145  		} else {
   146  			filters = append(filters, expr.LesserOrEqual(x.Field, val))
   147  			v = expr.LesserThan(x.Field, val)
   148  		}
   149  		fields = append(fields, v)
   150  	}
   151  	filters = append(filters, expr.Or(fields...))
   152  	if len(action.Conditions.Values) > 0 {
   153  		action.Conditions.Values = append(action.Conditions.Values, primitive.And)
   154  	}
   155  	action.Conditions.Values = append(action.Conditions.Values, expr.And(filters...))
   156  	return &action
   157  }
   158  
   159  func toString(v interface{}) interface{} {
   160  	switch vi := v.(type) {
   161  	case []byte:
   162  		return string(vi)
   163  	default:
   164  		return vi
   165  	}
   166  }