github.com/RevenueMonster/sqlike@v1.0.6/sqlike/update.go (about)

     1  package sqlike
     2  
import (
	"context"
	"fmt"

	sqldialect "github.com/RevenueMonster/sqlike/sql/dialect"
	sqldriver "github.com/RevenueMonster/sqlike/sql/driver"
	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
	"github.com/RevenueMonster/sqlike/sqlike/actions"
	"github.com/RevenueMonster/sqlike/sqlike/logs"
	"github.com/RevenueMonster/sqlike/sqlike/options"
)
    13  
    14  // UpdateOne :
    15  func (tb *Table) UpdateOne(ctx context.Context, act actions.UpdateOneStatement, opts ...*options.UpdateOneOptions) (int64, error) {
    16  	x := new(actions.UpdateOneActions)
    17  	if act != nil {
    18  		*x = *(act.(*actions.UpdateOneActions))
    19  	}
    20  	opt := new(options.UpdateOneOptions)
    21  	if len(opts) > 0 && opts[0] != nil {
    22  		opt = opts[0]
    23  	}
    24  
    25  	x.Limit(1)
    26  	return update(
    27  		ctx,
    28  		tb.dbName,
    29  		tb.name,
    30  		tb.driver,
    31  		tb.dialect,
    32  		tb.logger,
    33  		&x.UpdateActions,
    34  		&opt.UpdateOptions,
    35  	)
    36  }
    37  
    38  // Update :
    39  func (tb *Table) Update(ctx context.Context, act actions.UpdateStatement, opts ...*options.UpdateOptions) (int64, error) {
    40  	x := new(actions.UpdateActions)
    41  	if act != nil {
    42  		*x = *(act.(*actions.UpdateActions))
    43  	}
    44  	opt := new(options.UpdateOptions)
    45  	if len(opts) > 0 && opts[0] != nil {
    46  		opt = opts[0]
    47  	}
    48  	return update(
    49  		ctx,
    50  		tb.dbName,
    51  		tb.name,
    52  		tb.driver,
    53  		tb.dialect,
    54  		tb.logger,
    55  		x,
    56  		opt,
    57  	)
    58  }
    59  
    60  func update(ctx context.Context, dbName, tbName string, driver sqldriver.Driver, dialect sqldialect.Dialect, logger logs.Logger, act *actions.UpdateActions, opt *options.UpdateOptions) (int64, error) {
    61  	if act.Database == "" {
    62  		act.Database = dbName
    63  	}
    64  	if act.Table == "" {
    65  		act.Table = tbName
    66  	}
    67  	if len(act.Values) < 1 {
    68  		return 0, ErrNoValueUpdate
    69  	}
    70  	stmt := sqlstmt.AcquireStmt(dialect)
    71  	defer sqlstmt.ReleaseStmt(stmt)
    72  	if err := dialect.Update(stmt, act); err != nil {
    73  		return 0, err
    74  	}
    75  	result, err := sqldriver.Execute(
    76  		ctx,
    77  		driver,
    78  		stmt,
    79  		getLogger(logger, opt.Debug),
    80  	)
    81  	if err != nil {
    82  		return 0, err
    83  	}
    84  	return result.RowsAffected()
    85  }