github.com/wfusion/gofusion@v1.1.14/db/candy.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"reflect"
     8  	"strings"
     9  
    10  	"github.com/pkg/errors"
    11  	"gorm.io/gorm"
    12  
    13  	"github.com/wfusion/gofusion/common/constant"
    14  	"github.com/wfusion/gofusion/common/utils"
    15  	"github.com/wfusion/gofusion/log"
    16  )
    17  
    18  type scanOption struct {
    19  	dbName string
    20  
    21  	cursors       []any
    22  	cursorWhere   any
    23  	cursorColumns []string
    24  
    25  	where           any
    26  	sqlAndArguments []any
    27  
    28  	order any
    29  
    30  	batch int
    31  	limit int
    32  
    33  	log log.Loggable
    34  }
    35  
    36  type scanOptionGeneric[T any, TS ~[]*T] struct {
    37  	dal DalInterface[T, TS]
    38  }
    39  
    40  func ScanDAL[T any, TS ~[]*T](dal DalInterface[T, TS]) utils.OptionFunc[scanOptionGeneric[T, TS]] {
    41  	return func(o *scanOptionGeneric[T, TS]) {
    42  		o.dal = dal
    43  	}
    44  }
    45  
    46  func ScanUse(dbName string) utils.OptionFunc[scanOption] {
    47  	return func(o *scanOption) {
    48  		o.dbName = dbName
    49  	}
    50  }
    51  
    52  func ScanWhere(where any, sqlAndArguments ...any) utils.OptionFunc[scanOption] {
    53  	return func(o *scanOption) {
    54  		o.where = where
    55  		o.sqlAndArguments = sqlAndArguments
    56  	}
    57  }
    58  
    59  func ScanCursor(cursorWhere any, cursorColumns []string, cursors ...any) utils.OptionFunc[scanOption] {
    60  	return func(o *scanOption) {
    61  		o.cursors = cursors
    62  		o.cursorWhere = cursorWhere
    63  		o.cursorColumns = cursorColumns
    64  	}
    65  }
    66  
    67  func ScanOrder(order any) utils.OptionFunc[scanOption] {
    68  	return func(o *scanOption) {
    69  		o.order = order
    70  	}
    71  }
    72  
    73  func ScanBatch(batch int) utils.OptionFunc[scanOption] {
    74  	return func(o *scanOption) {
    75  		o.batch = batch
    76  	}
    77  }
    78  
    79  func ScanLimit(limit int) utils.OptionFunc[scanOption] {
    80  	return func(o *scanOption) {
    81  		o.limit = limit
    82  	}
    83  }
    84  
    85  func ScanLog(log log.Loggable) utils.OptionFunc[scanOption] {
    86  	return func(o *scanOption) {
    87  		o.log = log
    88  	}
    89  }
    90  
    91  func Scan[T any, TS ~[]*T](ctx context.Context, cb func(TS) bool, opts ...utils.OptionExtender) (err error) {
    92  	var (
    93  		tx    *gorm.DB
    94  		mList TS
    95  	)
    96  
    97  	o := utils.ApplyOptions[useOption](opts...)
    98  	opt := utils.ApplyOptions[scanOption](opts...)
    99  	optG := utils.ApplyOptions[scanOptionGeneric[T, TS]](opts...)
   100  
   101  	// get db instance
   102  	switch {
   103  	case optG.dal != nil:
   104  		tx = optG.dal.ReadDB(ctx)
   105  	case opt.dbName != "":
   106  		tx = Use(ctx, opt.dbName, AppName(o.appName)).GetProxy()
   107  	default:
   108  		panic(errors.New("unknown which table to scan"))
   109  	}
   110  
   111  	// default values
   112  	if opt.cursors == nil {
   113  		opt.cursors = []any{0}
   114  	}
   115  	if opt.cursorWhere == nil {
   116  		opt.cursorWhere = "id > ?"
   117  	}
   118  	if len(opt.cursorColumns) == 0 {
   119  		opt.cursorColumns = []string{"id"}
   120  	}
   121  	if opt.order == nil {
   122  		opt.order = fmt.Sprintf("%s ASC", strings.Join(opt.cursorColumns, constant.Comma))
   123  	}
   124  	if opt.batch == 0 {
   125  		opt.batch = 100
   126  	}
   127  	if opt.limit == 0 {
   128  		opt.limit = math.MaxInt
   129  	}
   130  
   131  	count := 0
   132  	tx = tx.WithContext(ctx)
   133  	if opt.log != nil {
   134  		opt.log.Info(ctx, "scan begin [where[%s][%+v] cursor[%s][%+v] order[%s] limit[%v] batch[%v]]",
   135  			opt.where, opt.sqlAndArguments, opt.cursorWhere, opt.cursors, opt.order, opt.limit, opt.batch)
   136  
   137  		defer func() { opt.log.Info(ctx, "scan end [count[%v]]", count) }()
   138  	}
   139  
   140  	// scan
   141  	for hasMore := true; hasMore; hasMore = len(mList) >= opt.batch {
   142  		// init model slice
   143  		mList = make(TS, 0, opt.batch)
   144  
   145  		// db query
   146  		q := tx.Where(opt.cursorWhere, opt.cursors...)
   147  		if opt.where != nil {
   148  			q = q.Where(opt.where, opt.sqlAndArguments...)
   149  		}
   150  		if opt.order != nil {
   151  			q = q.Order(opt.order)
   152  		}
   153  		if err = q.Limit(opt.batch).Find(&mList).Error; err != nil {
   154  			if opt.log != nil {
   155  				opt.log.Warn(ctx, "scan quit because meet with error [err[%s]]", err)
   156  			}
   157  			break
   158  		}
   159  
   160  		if len(mList) > 0 {
   161  			// callback
   162  			if !cb(mList) {
   163  				if opt.log != nil {
   164  					opt.log.Info(ctx, "scan quit because callback return false")
   165  				}
   166  				break
   167  			}
   168  
   169  			// get next cursor
   170  			next := mList[len(mList)-1]
   171  			nextVal := reflect.Indirect(reflect.ValueOf(next))
   172  			for idx, col := range opt.cursorColumns {
   173  				fieldVal := nextVal.FieldByNameFunc(func(s string) bool { return strings.EqualFold(s, col) })
   174  				if !fieldVal.IsValid() {
   175  					fieldVal, _ = utils.GetGormColumnValue(next, col)
   176  				}
   177  				if !fieldVal.IsValid() {
   178  					err = errors.Errorf("scan cursor column value is not found [col[%s]]", col)
   179  					if opt.log != nil {
   180  						opt.log.Error(ctx, "%s", err)
   181  					}
   182  					return
   183  				}
   184  				opt.cursors[idx] = fieldVal.Interface()
   185  			}
   186  		}
   187  
   188  		// check if exceed max
   189  		if count += len(mList); count >= opt.limit {
   190  			if opt.log != nil {
   191  				opt.log.Info(ctx, "scan quit because reach max [count[%v] max[%v]]", count, opt.limit)
   192  			}
   193  			break
   194  		}
   195  	}
   196  
   197  	return
   198  }