github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/mutation.go (about)

     1  package dal
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"time"
     7  
     8  	"github.com/octohelm/storage/internal/sql/scanner"
     9  	"github.com/octohelm/storage/pkg/datatypes"
    10  	"github.com/octohelm/storage/pkg/sqlbuilder"
    11  )
    12  
    13  func Prepare[T any](v *T) Mutation[T] {
    14  	if m, ok := any(v).(ModelWithCreationTime); ok {
    15  		m.MarkCreatedAt()
    16  	}
    17  
    18  	return &mutation[T]{
    19  		target: v,
    20  		feature: feature{
    21  			softDelete: true,
    22  		},
    23  	}
    24  }
    25  
// Mutation is a fluent builder for a single write (insert, update or delete)
// against the table of the prepared target; the statement is built and run by
// Save.
//
// The implementation in this file (mutation[T]) uses value receivers for the
// builder methods, so each call returns a modified copy rather than changing
// the receiver.
type Mutation[T any] interface {
	// IncludesZero names columns whose struct fields are written even when
	// they hold the zero value. Each call replaces the previous list.
	IncludesZero(zeroFields ...sqlbuilder.Column) Mutation[T]

	// ForDelete makes Save delete instead of insert/update; the options are
	// applied to the builder. A delete is only issued when Where is set. With
	// soft delete enabled and a target implementing ModelWithSoftDelete, rows
	// are flagged through an UPDATE instead of being removed.
	ForDelete(opts ...OptionFunc) Mutation[T]
	// ForUpdateSet overrides the SET list used when Save updates; without it
	// the assignments are derived from the target's field values.
	ForUpdateSet(assignments ...sqlbuilder.Assignment) Mutation[T]

	// Where sets the row condition. When present, Save updates (or, with
	// ForDelete, deletes) matching rows instead of inserting.
	Where(where sqlbuilder.SqlExpr) Mutation[T]

	// OnConflict sets the conflict-target columns for INSERT ... ON CONFLICT.
	OnConflict(cols sqlbuilder.ColumnCollection) Mutation[T]
	// DoNothing clears any DoUpdateSet columns. The generated SQL does not
	// contain a literal DO NOTHING: the conflict columns are re-assigned from
	// EXCLUDED instead, because sqlite does not return rows for
	// ON CONFLICT DO NOTHING.
	DoNothing() Mutation[T]
	// DoUpdateSet lists columns to overwrite from EXCLUDED on conflict.
	DoUpdateSet(cols ...sqlbuilder.Column) Mutation[T]
	// DoWith supplies a custom ON CONFLICT clause builder; while set it takes
	// precedence over the DoUpdateSet columns.
	DoWith(func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition) Mutation[T]

	// Returning requests a RETURNING clause; with no columns it returns *.
	Returning(cols ...sqlbuilder.SqlExpr) Mutation[T]
	// Scan sets the destination that rows produced by Returning are scanned into.
	Scan(recv any) Mutation[T]

	// Save looks up the session via SessionFor(ctx, target), then builds and
	// executes the statement.
	Save(ctx context.Context) error
}
    44  
// mutation is the concrete Mutation returned by Prepare. Builder methods take
// it by value and return a pointer to the modified copy; Save and the
// statement helpers use a pointer receiver.
type mutation[T any] struct {
	// target is the model being written; it is also probed for
	// ModelWithSoftDelete / ModelWithCreationTime.
	target             *T
	// recv is the scan destination for RETURNING rows (set via Scan).
	recv               any
	// zeroFieldsIncludes lists columns written even when zero (IncludesZero).
	zeroFieldsIncludes []sqlbuilder.Column

	// assignmentsForUpdate overrides the derived SET list (ForUpdateSet).
	assignmentsForUpdate sqlbuilder.Assignments
	// where is the row condition; non-nil makes Save update instead of insert,
	// and a delete is only issued when it is set.
	where                sqlbuilder.SqlExpr

	// conflict, onConflictDoWith and onConflictDoUpdateSet configure the
	// INSERT ... ON CONFLICT clause.
	conflict              sqlbuilder.ColumnCollection
	onConflictDoWith      func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition
	onConflictDoUpdateSet []sqlbuilder.Column

	// returning: nil means no RETURNING clause; a non-nil empty slice means
	// RETURNING *.
	returning []sqlbuilder.SqlExpr

	// forDelete makes Save delete instead of insert/update.
	forDelete bool

	// feature carries toggles such as softDelete (enabled by Prepare).
	feature
}
    63  
// DeleteFunc is a callback type that takes no arguments and returns nothing.
// NOTE(review): it is not referenced in this file; confirm its callers
// elsewhere in the package.
type DeleteFunc func()
    65  
    66  func (c mutation[T]) IncludesZero(zeroFields ...sqlbuilder.Column) Mutation[T] {
    67  	c.zeroFieldsIncludes = zeroFields
    68  	return &c
    69  }
    70  
    71  func (c mutation[T]) ForDelete(fns ...OptionFunc) Mutation[T] {
    72  	c.forDelete = true
    73  	for i := range fns {
    74  		fns[i](&c)
    75  	}
    76  	return &c
    77  }
    78  
    79  func (c mutation[T]) ForUpdateSet(assignments ...sqlbuilder.Assignment) Mutation[T] {
    80  	c.assignmentsForUpdate = assignments
    81  	return &c
    82  }
    83  
    84  func (c mutation[T]) Where(where sqlbuilder.SqlExpr) Mutation[T] {
    85  	c.where = where
    86  	return &c
    87  }
    88  
    89  func (c mutation[T]) OnConflict(cols sqlbuilder.ColumnCollection) Mutation[T] {
    90  	c.conflict = cols
    91  	return &c
    92  }
    93  
    94  func (c mutation[T]) DoNothing() Mutation[T] {
    95  	c.onConflictDoUpdateSet = nil
    96  	return &c
    97  }
    98  
    99  func (c mutation[T]) DoWith(fn func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition) Mutation[T] {
   100  	c.onConflictDoWith = fn
   101  	return &c
   102  }
   103  
   104  func (c mutation[T]) DoUpdateSet(cols ...sqlbuilder.Column) Mutation[T] {
   105  	c.onConflictDoUpdateSet = cols
   106  	return &c
   107  }
   108  
   109  func (c mutation[T]) Returning(cols ...sqlbuilder.SqlExpr) Mutation[T] {
   110  	if len(cols) != 0 {
   111  		c.returning = cols
   112  	} else {
   113  		c.returning = make([]sqlbuilder.SqlExpr, 0)
   114  	}
   115  	return &c
   116  }
   117  
   118  func (c mutation[T]) Scan(recv any) Mutation[T] {
   119  	c.recv = recv
   120  	return &c
   121  }
   122  
   123  func (c *mutation[T]) Save(ctx context.Context) error {
   124  	s := SessionFor(ctx, c.target)
   125  	if c.forDelete {
   126  		return c.del(ctx, s.T(c.target), s)
   127  	}
   128  	return c.insertOrUpdate(ctx, s.T(c.target), s)
   129  }
   130  
   131  func (c *mutation[T]) buildWhere(t sqlbuilder.Table) sqlbuilder.SqlCondition {
   132  	if c.where == nil {
   133  		return nil
   134  	}
   135  	where := c.where
   136  	if c.feature.softDelete {
   137  		if soft, ok := any(c.target).(ModelWithSoftDelete); ok {
   138  			f, notDeletedValue := soft.SoftDeleteFieldAndZeroValue()
   139  			return sqlbuilder.And(
   140  				where,
   141  				t.F(f).Expr("# = ?", notDeletedValue),
   142  			)
   143  		}
   144  	}
   145  	return sqlbuilder.AsCond(where)
   146  }
   147  
   148  func (c *mutation[T]) del(ctx context.Context, t sqlbuilder.Table, s Session) error {
   149  	where := c.buildWhere(t)
   150  	if where == nil {
   151  		// never delete without condition
   152  		return nil
   153  	}
   154  
   155  	var stmt sqlbuilder.SqlExpr
   156  
   157  	additions, hasReturning := c.withReturning(t, nil)
   158  
   159  	if c.feature.softDelete {
   160  		if soft, ok := any(c.target).(ModelWithSoftDelete); ok {
   161  			soft.MarkDeletedAt()
   162  
   163  			f, _ := soft.SoftDeleteFieldAndZeroValue()
   164  
   165  			var softDeleteValue driver.Value
   166  			if v, ok := ctx.(SoftDeleteValueGetter); ok {
   167  				softDeleteValue = v.GetDeletedAt()
   168  			} else {
   169  				softDeleteValue = datatypes.Timestamp(time.Now())
   170  			}
   171  
   172  			col := t.F(f)
   173  			stmt = sqlbuilder.Update(t).Where(where, additions...).Set(
   174  				sqlbuilder.ColumnsAndValues(col, softDeleteValue),
   175  			)
   176  		}
   177  	}
   178  
   179  	if stmt == nil {
   180  		stmt = sqlbuilder.Delete().From(t, append([]sqlbuilder.Addition{sqlbuilder.Where(where)}, additions...)...)
   181  	}
   182  
   183  	return c.exec(ctx, s, stmt, hasReturning)
   184  }
   185  
   186  func (c *mutation[T]) insertOrUpdate(ctx context.Context, t sqlbuilder.Table, s Session) error {
   187  	additions := make([]sqlbuilder.Addition, 0)
   188  
   189  	if c.conflict != nil && c.conflict.Len() > 0 {
   190  		onConflict := sqlbuilder.OnConflict(c.conflict)
   191  
   192  		if onConflictDoWith := c.onConflictDoWith; onConflictDoWith != nil {
   193  			additions = append(additions, onConflictDoWith(onConflict))
   194  		} else {
   195  			cols := c.onConflictDoUpdateSet
   196  			if cols == nil {
   197  				// FIXME ugly hack
   198  				// sqlite will not RETURNING when ON CONFLICT DO NOTHING
   199  				c.conflict.RangeCol(func(col sqlbuilder.Column, idx int) bool {
   200  					cols = append(cols, col)
   201  					return true
   202  				})
   203  			}
   204  
   205  			assignments := make([]sqlbuilder.Assignment, len(cols))
   206  
   207  			for idx, col := range cols {
   208  				assignments[idx] = sqlbuilder.ColumnsAndValues(
   209  					col, col.Expr("EXCLUDED.?", sqlbuilder.Expr(col.Name())),
   210  				)
   211  			}
   212  
   213  			onConflict = onConflict.DoUpdateSet(assignments...)
   214  			additions = append(additions, onConflict)
   215  		}
   216  	}
   217  
   218  	additions, hasReturning := c.withReturning(t, additions)
   219  
   220  	zeroFieldsIncludes := make([]string, len(c.zeroFieldsIncludes))
   221  
   222  	for i := range zeroFieldsIncludes {
   223  		zeroFieldsIncludes[i] = c.zeroFieldsIncludes[i].FieldName()
   224  	}
   225  
   226  	fieldValues := sqlbuilder.FieldValuesFromStructByNonZero(c.target, zeroFieldsIncludes...)
   227  
   228  	var stmt sqlbuilder.SqlExpr
   229  
   230  	if where := c.buildWhere(t); where != nil {
   231  		assignmentsForUpdate := c.assignmentsForUpdate
   232  		if len(assignmentsForUpdate) == 0 {
   233  			assignmentsForUpdate = sqlbuilder.AssignmentsByFieldValues(t, fieldValues)
   234  		}
   235  		stmt = sqlbuilder.Update(t).
   236  			Where(where, additions...).
   237  			Set(assignmentsForUpdate...)
   238  	} else {
   239  		cols, vals := sqlbuilder.ColumnsAndValuesByFieldValues(t, fieldValues)
   240  		stmt = sqlbuilder.Insert().Into(t, additions...).
   241  			Values(cols, vals...)
   242  	}
   243  
   244  	return c.exec(ctx, s, stmt, hasReturning)
   245  }
   246  
   247  func (c *mutation[T]) exec(ctx context.Context, s Session, stmt sqlbuilder.SqlExpr, hasReturning bool) error {
   248  	if hasReturning {
   249  		rows, err := s.Adapter().Query(ctx, stmt)
   250  		if err != nil {
   251  			return err
   252  		}
   253  		return scanner.Scan(ctx, rows, c.recv)
   254  	}
   255  	_, err := s.Adapter().Exec(ctx, stmt)
   256  	return err
   257  }
   258  
   259  func (c *mutation[T]) withReturning(t sqlbuilder.Table, additions []sqlbuilder.Addition) ([]sqlbuilder.Addition, bool) {
   260  	hasReturning := false
   261  
   262  	if c.returning != nil {
   263  		hasReturning = true
   264  
   265  		if len(c.returning) == 0 {
   266  			additions = append(additions, sqlbuilder.Returning(sqlbuilder.Expr("*")))
   267  		} else {
   268  			additions = append(additions, sqlbuilder.Returning(sqlbuilder.MultiMayAutoAlias(c.returning...)))
   269  		}
   270  	}
   271  
   272  	return additions, hasReturning
   273  }