gitee.com/h79/goutils@v1.22.10/dao/wrapper/base.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"gitee.com/h79/goutils/common"
     7  	"gitee.com/h79/goutils/dao/option"
     8  	"gorm.io/gorm"
     9  	"reflect"
    10  )
    11  
// Base bundles a gorm handle with per-query settings (scan destination and
// ordering) that the query helpers defined on it consult.
type Base struct {
	// DB is the gorm handle every query helper starts from.
	DB      *gorm.DB
	// ctx is attached to the sessions created by NewDB.
	ctx     context.Context
	// cancel is invoked by Cancel; presumably it cancels ctx — TODO confirm
	// where it is assigned (not visible here).
	cancel  context.CancelFunc
	// order, when non-nil, is passed to gorm's Order by First and Find.
	order   interface{}
	// Results is the scan destination used by First and Find; set it via WithResult.
	Results interface{}
}
    19  
    20  func (b *Base) NewDB() *gorm.DB {
    21  	return b.DB.Session(&gorm.Session{NewDB: true, Context: b.ctx})
    22  }
    23  
    24  func (b *Base) WithResult(result interface{}) *Base {
    25  	if !IsValid(result) {
    26  		panic("the result is need ptr or slice")
    27  	}
    28  	b.Results = result
    29  	return b
    30  }
    31  
    32  func (b *Base) WithOrder(order interface{}) *Base {
    33  	b.order = order
    34  	return b
    35  }
    36  
    37  func (b *Base) Cancel() {
    38  	b.cancel()
    39  }
    40  
    41  // Get 获取
    42  func (b *Base) Get(model interface{}) (result interface{}, err error) {
    43  	err = b.DB.Model(model).First(&result).Error
    44  	return
    45  }
    46  
    47  // Gets 获取批量结果
    48  func (b *Base) Gets(model interface{}) (results []*interface{}, err error) {
    49  	err = b.DB.Model(model).Find(&results).Error
    50  	return
    51  }
    52  
    53  func (b *Base) Count(model interface{}, count *int64) error {
    54  	return b.DB.Model(model).Count(count).Error
    55  }
    56  
    57  func (b *Base) Create(model interface{}, opts ...option.QueryFunc) error {
    58  	// 根据 `map` 更新属性
    59  	// db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "active": false})
    60  	q := option.Query{
    61  		Q: make(map[string]interface{}),
    62  	}
    63  	for _, o := range opts {
    64  		o.Apply(&q)
    65  	}
    66  	db := b.DB.Model(model)
    67  	return db.Create(q.Q).Error
    68  }
    69  
    70  func (b *Base) Delete(model interface{}, cond IQuery) error {
    71  	db := b.DB
    72  	if cond.Is() {
    73  		db = db.Where(cond.Query(), cond.Value()...)
    74  	}
    75  	return db.Delete(model).Error
    76  }
    77  
    78  func (b *Base) UpdateColumn(model interface{}, where IQuery, column string, value interface{}) error {
    79  	return b.Updates(model, where, option.QueryFunc(func(q *option.Query) {
    80  		q.Q[column] = value
    81  	}))
    82  }
    83  
    84  func (b *Base) Updates(model interface{}, where IQuery, opts ...option.QueryFunc) error {
    85  	// 根据 `map` 更新属性
    86  	// db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "active": false})
    87  	q := option.Query{
    88  		Q: make(map[string]interface{}),
    89  	}
    90  	for _, o := range opts {
    91  		o.Apply(&q)
    92  	}
    93  	db := b.DB.Model(model)
    94  	if where != nil && where.Is() {
    95  		db = db.Where(where.Query(), where.Value()...)
    96  	}
    97  	return db.Updates(q.Q).Error
    98  }
    99  
   100  func (b *Base) First(model interface{}, where IQuery, sel ISelect) error {
   101  	//db.Model(User{ID: 10}).First(&result)
   102  	// SELECT * FROM users WHERE id = 10;
   103  	db := b.DB.Model(model)
   104  	if where != nil && where.Is() {
   105  		db = db.Where(where.Query(), where.Value()...)
   106  	}
   107  	if sel != nil && sel.Is() {
   108  		db = db.Select(sel.Query(), sel.Value()...)
   109  	}
   110  	if !common.IsNil(b.order) {
   111  		db = db.Order(b.order)
   112  	}
   113  	return db.First(b.Results).Error
   114  }
   115  
   116  func (b *Base) Find(model interface{}, where IQuery, sel ISelect) error {
   117  
   118  	db := b.DB.Model(model)
   119  
   120  	if where != nil && where.Is() {
   121  		db = db.Where(where.Query(), where.Value()...)
   122  	}
   123  	if sel != nil && sel.Is() {
   124  		db = db.Select(sel.Query(), sel.Value()...)
   125  	}
   126  	if !common.IsNil(b.order) {
   127  		db = db.Order(b.order)
   128  	}
   129  	return db.Find(b.Results).Error
   130  }
   131  
   132  func IsNotFound(db *gorm.DB) int {
   133  	if errors.Is(db.Error, gorm.ErrRecordNotFound) {
   134  		return 1
   135  	}
   136  	if db.Error != nil {
   137  		return -1
   138  	}
   139  	return 0
   140  }
   141  
   142  func IsValid(r interface{}) bool {
   143  	if r == nil {
   144  		return false
   145  	}
   146  	val := reflect.Indirect(reflect.ValueOf(r))
   147  	k := val.Kind()
   148  	switch k {
   149  	case reflect.Struct:
   150  		fallthrough
   151  	case reflect.Slice:
   152  		return true
   153  	}
   154  	return false
   155  }