github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/mutation.go (about) 1 package dal 2 3 import ( 4 "context" 5 "database/sql/driver" 6 "time" 7 8 "github.com/octohelm/storage/internal/sql/scanner" 9 "github.com/octohelm/storage/pkg/datatypes" 10 "github.com/octohelm/storage/pkg/sqlbuilder" 11 ) 12 13 func Prepare[T any](v *T) Mutation[T] { 14 if m, ok := any(v).(ModelWithCreationTime); ok { 15 m.MarkCreatedAt() 16 } 17 18 return &mutation[T]{ 19 target: v, 20 feature: feature{ 21 softDelete: true, 22 }, 23 } 24 } 25 26 type Mutation[T any] interface { 27 IncludesZero(zeroFields ...sqlbuilder.Column) Mutation[T] 28 29 ForDelete(opts ...OptionFunc) Mutation[T] 30 ForUpdateSet(assignments ...sqlbuilder.Assignment) Mutation[T] 31 32 Where(where sqlbuilder.SqlExpr) Mutation[T] 33 34 OnConflict(cols sqlbuilder.ColumnCollection) Mutation[T] 35 DoNothing() Mutation[T] 36 DoUpdateSet(cols ...sqlbuilder.Column) Mutation[T] 37 DoWith(func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition) Mutation[T] 38 39 Returning(cols ...sqlbuilder.SqlExpr) Mutation[T] 40 Scan(recv any) Mutation[T] 41 42 Save(ctx context.Context) error 43 } 44 45 type mutation[T any] struct { 46 target *T 47 recv any 48 zeroFieldsIncludes []sqlbuilder.Column 49 50 assignmentsForUpdate sqlbuilder.Assignments 51 where sqlbuilder.SqlExpr 52 53 conflict sqlbuilder.ColumnCollection 54 onConflictDoWith func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition 55 onConflictDoUpdateSet []sqlbuilder.Column 56 57 returning []sqlbuilder.SqlExpr 58 59 forDelete bool 60 61 feature 62 } 63 64 type DeleteFunc func() 65 66 func (c mutation[T]) IncludesZero(zeroFields ...sqlbuilder.Column) Mutation[T] { 67 c.zeroFieldsIncludes = zeroFields 68 return &c 69 } 70 71 func (c mutation[T]) ForDelete(fns ...OptionFunc) Mutation[T] { 72 c.forDelete = true 73 for i := range fns { 74 fns[i](&c) 75 } 76 return &c 77 } 78 79 func (c mutation[T]) ForUpdateSet(assignments ...sqlbuilder.Assignment) Mutation[T] { 80 c.assignmentsForUpdate = assignments 81 return &c 82 } 83 84 func (c mutation[T]) Where(where sqlbuilder.SqlExpr) Mutation[T] { 85 c.where = where 86 return &c 87 } 88 89 func (c mutation[T]) OnConflict(cols sqlbuilder.ColumnCollection) Mutation[T] { 90 c.conflict = cols 91 return &c 92 } 93 94 func (c mutation[T]) DoNothing() Mutation[T] { 95 c.onConflictDoUpdateSet = nil 96 return &c 97 } 98 99 func (c mutation[T]) DoWith(fn func(onConflictAddition sqlbuilder.OnConflictAddition) sqlbuilder.Addition) Mutation[T] { 100 c.onConflictDoWith = fn 101 return &c 102 } 103 104 func (c mutation[T]) DoUpdateSet(cols ...sqlbuilder.Column) Mutation[T] { 105 c.onConflictDoUpdateSet = cols 106 return &c 107 } 108 109 func (c mutation[T]) Returning(cols ...sqlbuilder.SqlExpr) Mutation[T] { 110 if len(cols) != 0 { 111 c.returning = cols 112 } else { 113 c.returning = make([]sqlbuilder.SqlExpr, 0) 114 } 115 return &c 116 } 117 118 func (c mutation[T]) Scan(recv any) Mutation[T] { 119 c.recv = recv 120 return &c 121 } 122 123 func (c *mutation[T]) Save(ctx context.Context) error { 124 s := SessionFor(ctx, c.target) 125 if c.forDelete { 126 return c.del(ctx, s.T(c.target), s) 127 } 128 return c.insertOrUpdate(ctx, s.T(c.target), s) 129 } 130 131 func (c *mutation[T]) buildWhere(t sqlbuilder.Table) sqlbuilder.SqlCondition { 132 if c.where == nil { 133 return nil 134 } 135 where := c.where 136 if c.feature.softDelete { 137 if soft, ok := any(c.target).(ModelWithSoftDelete); ok { 138 f, notDeletedValue := soft.SoftDeleteFieldAndZeroValue() 139 return sqlbuilder.And( 140 where, 141 t.F(f).Expr("# = ?", notDeletedValue), 142 ) 143 } 144 } 145 return sqlbuilder.AsCond(where) 146 } 147 148 func (c *mutation[T]) del(ctx context.Context, t sqlbuilder.Table, s Session) error { 149 where := c.buildWhere(t) 150 if where == nil { 151 // never delete without condition 152 return nil 153 } 154 155 var stmt sqlbuilder.SqlExpr 156 157 additions, hasReturning := c.withReturning(t, nil) 158 159 if c.feature.softDelete { 160 if soft, ok := any(c.target).(ModelWithSoftDelete); ok { 161 soft.MarkDeletedAt() 162 163 f, _ := soft.SoftDeleteFieldAndZeroValue() 164 165 var softDeleteValue driver.Value 166 if v, ok := ctx.(SoftDeleteValueGetter); ok { 167 softDeleteValue = v.GetDeletedAt() 168 } else { 169 softDeleteValue = datatypes.Timestamp(time.Now()) 170 } 171 172 col := t.F(f) 173 stmt = sqlbuilder.Update(t).Where(where, additions...).Set( 174 sqlbuilder.ColumnsAndValues(col, softDeleteValue), 175 ) 176 } 177 } 178 179 if stmt == nil { 180 stmt = sqlbuilder.Delete().From(t, append([]sqlbuilder.Addition{sqlbuilder.Where(where)}, additions...)...) 181 } 182 183 return c.exec(ctx, s, stmt, hasReturning) 184 } 185 186 func (c *mutation[T]) insertOrUpdate(ctx context.Context, t sqlbuilder.Table, s Session) error { 187 additions := make([]sqlbuilder.Addition, 0) 188 189 if c.conflict != nil && c.conflict.Len() > 0 { 190 onConflict := sqlbuilder.OnConflict(c.conflict) 191 192 if onConflictDoWith := c.onConflictDoWith; onConflictDoWith != nil { 193 additions = append(additions, onConflictDoWith(onConflict)) 194 } else { 195 cols := c.onConflictDoUpdateSet 196 if cols == nil { 197 // FIXME ugly hack 198 // sqlite will not RETURNING when ON CONFLICT DO NOTHING 199 c.conflict.RangeCol(func(col sqlbuilder.Column, idx int) bool { 200 cols = append(cols, col) 201 return true 202 }) 203 } 204 205 assignments := make([]sqlbuilder.Assignment, len(cols)) 206 207 for idx, col := range cols { 208 assignments[idx] = sqlbuilder.ColumnsAndValues( 209 col, col.Expr("EXCLUDED.?", sqlbuilder.Expr(col.Name())), 210 ) 211 } 212 213 onConflict = onConflict.DoUpdateSet(assignments...) 214 additions = append(additions, onConflict) 215 } 216 } 217 218 additions, hasReturning := c.withReturning(t, additions) 219 220 zeroFieldsIncludes := make([]string, len(c.zeroFieldsIncludes)) 221 222 for i := range zeroFieldsIncludes { 223 zeroFieldsIncludes[i] = c.zeroFieldsIncludes[i].FieldName() 224 } 225 226 fieldValues := sqlbuilder.FieldValuesFromStructByNonZero(c.target, zeroFieldsIncludes...) 227 228 var stmt sqlbuilder.SqlExpr 229 230 if where := c.buildWhere(t); where != nil { 231 assignmentsForUpdate := c.assignmentsForUpdate 232 if len(assignmentsForUpdate) == 0 { 233 assignmentsForUpdate = sqlbuilder.AssignmentsByFieldValues(t, fieldValues) 234 } 235 stmt = sqlbuilder.Update(t). 236 Where(where, additions...). 237 Set(assignmentsForUpdate...) 238 } else { 239 cols, vals := sqlbuilder.ColumnsAndValuesByFieldValues(t, fieldValues) 240 stmt = sqlbuilder.Insert().Into(t, additions...). 241 Values(cols, vals...) 242 } 243 244 return c.exec(ctx, s, stmt, hasReturning) 245 } 246 247 func (c *mutation[T]) exec(ctx context.Context, s Session, stmt sqlbuilder.SqlExpr, hasReturning bool) error { 248 if hasReturning { 249 rows, err := s.Adapter().Query(ctx, stmt) 250 if err != nil { 251 return err 252 } 253 return scanner.Scan(ctx, rows, c.recv) 254 } 255 _, err := s.Adapter().Exec(ctx, stmt) 256 return err 257 } 258 259 func (c *mutation[T]) withReturning(t sqlbuilder.Table, additions []sqlbuilder.Addition) ([]sqlbuilder.Addition, bool) { 260 hasReturning := false 261 262 if c.returning != nil { 263 hasReturning = true 264 265 if len(c.returning) == 0 { 266 additions = append(additions, sqlbuilder.Returning(sqlbuilder.Expr("*"))) 267 } else { 268 additions = append(additions, sqlbuilder.Returning(sqlbuilder.MultiMayAutoAlias(c.returning...))) 269 } 270 } 271 272 return additions, hasReturning 273 }