github.com/RevenueMonster/sqlike@v1.0.6/sqlike/modify.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  
     8  	"github.com/RevenueMonster/sqlike/reflext"
     9  	sqldialect "github.com/RevenueMonster/sqlike/sql/dialect"
    10  	sqldriver "github.com/RevenueMonster/sqlike/sql/driver"
    11  	"github.com/RevenueMonster/sqlike/sql/expr"
    12  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    13  	"github.com/RevenueMonster/sqlike/sqlike/actions"
    14  	"github.com/RevenueMonster/sqlike/sqlike/logs"
    15  	"github.com/RevenueMonster/sqlike/sqlike/options"
    16  )
    17  
    18  // ModifyOne :
    19  func (tb *Table) ModifyOne(ctx context.Context, update interface{}, opts ...*options.ModifyOneOptions) error {
    20  	return modifyOne(
    21  		ctx,
    22  		tb.dbName,
    23  		tb.name,
    24  		tb.pk,
    25  		tb.client.cache,
    26  		tb.dialect,
    27  		tb.driver,
    28  		tb.logger,
    29  		update,
    30  		opts,
    31  	)
    32  }
    33  
    34  func modifyOne(ctx context.Context, dbName, tbName, pk string, cache reflext.StructMapper, dialect sqldialect.Dialect, driver sqldriver.Driver, logger logs.Logger, update interface{}, opts []*options.ModifyOneOptions) error {
    35  	v := reflext.ValueOf(update)
    36  	if !v.IsValid() {
    37  		return ErrInvalidInput
    38  	}
    39  
    40  	t := v.Type()
    41  	if !reflext.IsKind(t, reflect.Ptr) {
    42  		return ErrUnaddressableEntity
    43  	}
    44  
    45  	if v.IsNil() {
    46  		return ErrNilEntity
    47  	}
    48  
    49  	cdc := cache.CodecByType(t)
    50  	opt := new(options.ModifyOneOptions)
    51  	if len(opts) > 0 && opts[0] != nil {
    52  		opt = opts[0]
    53  	}
    54  
    55  	fields := skipColumns(cdc.Properties(), opt.Omits)
    56  	x := new(actions.UpdateActions)
    57  	x.Table = tbName
    58  
    59  	var pkv = [2]interface{}{}
    60  	for _, sf := range fields {
    61  		fv := cache.FieldByIndexesReadOnly(v, sf.Index())
    62  		if _, ok := sf.Tag().LookUp("primary_key"); ok {
    63  			if pkv[0] != nil {
    64  				x.Set(expr.ColumnValue(pkv[0].(string), pkv[1]))
    65  			}
    66  			pkv[0] = sf.Name()
    67  			pkv[1] = fv.Interface()
    68  			continue
    69  		}
    70  		if sf.Name() == pk && pkv[0] == nil {
    71  			pkv[0] = sf.Name()
    72  			pkv[1] = fv.Interface()
    73  			continue
    74  		}
    75  		x.Set(expr.ColumnValue(sf.Name(), fv.Interface()))
    76  	}
    77  
    78  	if pkv[0] == nil {
    79  		return errors.New("sqlike: missing primary key field")
    80  	}
    81  
    82  	x.Where(expr.Equal(pkv[0], pkv[1]))
    83  	x.Limit(1)
    84  	x.Table = tbName
    85  	x.Database = dbName
    86  
    87  	stmt := sqlstmt.AcquireStmt(dialect)
    88  	defer sqlstmt.ReleaseStmt(stmt)
    89  	if err := dialect.Update(stmt, x); err != nil {
    90  		return err
    91  	}
    92  
    93  	result, err := sqldriver.Execute(
    94  		ctx,
    95  		driver,
    96  		stmt,
    97  		getLogger(logger, opt.Debug),
    98  	)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	if !opt.NoStrict {
   103  		affected, err := result.RowsAffected()
   104  		if err != nil {
   105  			return err
   106  		}
   107  		if affected < 1 {
   108  			return ErrNoRecordAffected
   109  		}
   110  	}
   111  	return nil
   112  }