github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/querier.go (about) 1 package dal 2 3 import ( 4 "context" 5 "reflect" 6 7 "github.com/octohelm/storage/internal/sql/scanner" 8 "github.com/octohelm/storage/pkg/sqlbuilder" 9 "github.com/pkg/errors" 10 ) 11 12 //Intersect(q Querier) Querier 13 //Except(q Querier) Querier 14 15 func InSelect[T any](col sqlbuilder.TypedColumn[T], q Querier) sqlbuilder.ColumnValueExpr[T] { 16 return func(v sqlbuilder.Column) sqlbuilder.SqlExpr { 17 ex := q.Select(col) 18 if ex.IsNil() { 19 return nil 20 } 21 return sqlbuilder.Expr("? IN (?)", v, ex) 22 } 23 } 24 25 func NotInSelect[T any](col sqlbuilder.TypedColumn[T], q Querier) sqlbuilder.ColumnValueExpr[T] { 26 return func(v sqlbuilder.Column) sqlbuilder.SqlExpr { 27 ex := q.Select(col) 28 if ex.IsNil() { 29 return nil 30 } 31 return sqlbuilder.Expr("? NOT IN (?)", v, ex) 32 } 33 } 34 35 type QuerierPatcher interface { 36 Apply(q Querier) Querier 37 } 38 39 type Querier interface { 40 sqlbuilder.SqlExpr 41 42 ExistsTable(table sqlbuilder.Table) bool 43 Apply(patchers ...QuerierPatcher) Querier 44 45 With(t sqlbuilder.Table, build sqlbuilder.BuildSubQuery, modifiers ...string) Querier 46 AsTemporaryTable(tableName string) TemporaryTable 47 48 Join(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier 49 CrossJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier 50 LeftJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier 51 RightJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier 52 FullJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier 53 54 Where(where sqlbuilder.SqlExpr) Querier 55 WhereAnd(where sqlbuilder.SqlExpr) Querier 56 WhereOr(where sqlbuilder.SqlExpr) Querier 57 58 OrderBy(orders ...*sqlbuilder.Order) Querier 59 60 GroupBy(cols ...sqlbuilder.SqlExpr) Querier 61 Having(where sqlbuilder.SqlExpr) Querier 62 63 Limit(v int64) Querier 64 Offset(v int64) Querier 65 66 Distinct(extras ...sqlbuilder.SqlExpr) Querier 67 Select(projects ...sqlbuilder.SqlExpr) Querier 68 69 Scan(v any) Querier 70 71 Find(ctx context.Context) error 72 Count(ctx context.Context) (int, error) 73 } 74 75 func From(from sqlbuilder.Table, fns ...OptionFunc) Querier { 76 q := &querier{ 77 from: from, 78 tables: []sqlbuilder.Table{from}, 79 limit: -1, 80 feature: feature{ 81 softDelete: true, 82 }, 83 } 84 85 for i := range fns { 86 fns[i](q) 87 } 88 89 if tmpT, ok := from.(QuerierPatcher); ok { 90 return q.Apply(tmpT) 91 } 92 93 return q 94 } 95 96 type querier struct { 97 from sqlbuilder.Table 98 tables []sqlbuilder.Table 99 100 withStmt *sqlbuilder.WithStmt 101 102 orders []*sqlbuilder.Order 103 104 distinct []sqlbuilder.SqlExpr 105 groupBy []sqlbuilder.SqlExpr 106 having sqlbuilder.SqlExpr 107 108 limit int64 109 offset int64 110 111 where sqlbuilder.SqlExpr 112 projects []sqlbuilder.SqlExpr 113 114 joins []sqlbuilder.Addition 115 116 feature 117 118 recv any 119 } 120 121 func (q *querier) ExistsTable(table sqlbuilder.Table) bool { 122 for _, t := range q.tables { 123 if t == table || t.TableName() == table.TableName() { 124 return true 125 } 126 } 127 return false 128 } 129 130 func (q *querier) Apply(patchers ...QuerierPatcher) Querier { 131 var applied Querier = q 132 133 for _, p := range patchers { 134 if p != nil { 135 applied = p.Apply(applied) 136 } 137 } 138 139 return applied 140 } 141 142 func (q querier) With(t sqlbuilder.Table, build sqlbuilder.BuildSubQuery, modifiers ...string) Querier { 143 q.tables = append(q.tables, t) 144 if q.withStmt == nil { 145 q.withStmt = sqlbuilder.With(t, build, modifiers...) 146 return &q 147 } 148 q.withStmt = q.withStmt.With(t, build) 149 return &q 150 } 151 152 func (q querier) CrossJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier { 153 q.tables = append(q.tables, t) 154 q.joins = append(q.joins, sqlbuilder.CrossJoin(t).On(sqlbuilder.AsCond(on))) 155 return &q 156 } 157 158 func (q querier) LeftJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier { 159 q.tables = append(q.tables, t) 160 q.joins = append(q.joins, sqlbuilder.LeftJoin(t).On(sqlbuilder.AsCond(on))) 161 return &q 162 } 163 164 func (q querier) RightJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier { 165 q.tables = append(q.tables, t) 166 q.joins = append(q.joins, sqlbuilder.RightJoin(t).On(sqlbuilder.AsCond(on))) 167 return &q 168 } 169 170 func (q querier) FullJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier { 171 q.tables = append(q.tables, t) 172 q.joins = append(q.joins, sqlbuilder.FullJoin(t).On(sqlbuilder.AsCond(on))) 173 return &q 174 } 175 176 func (q querier) Join(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier { 177 q.tables = append(q.tables, t) 178 q.joins = append(q.joins, sqlbuilder.Join(t).On(sqlbuilder.AsCond(on))) 179 return &q 180 } 181 182 func (q *querier) IsNil() bool { 183 if q.whereStmtNotEmpty { 184 return sqlbuilder.IsNilExpr(q.where) || q.from == nil 185 } 186 return q.from == nil 187 } 188 189 func (q *querier) Ex(ctx context.Context) *sqlbuilder.Ex { 190 return q.build().Ex(ctx) 191 } 192 193 func resolveModel(v any) any { 194 if canNew, ok := v.(interface{ New() any }); ok { 195 return canNew.New() 196 } else { 197 tpe := reflect.TypeOf(v) 198 for tpe.Kind() == reflect.Ptr { 199 tpe = tpe.Elem() 200 } 201 if tpe.Kind() == reflect.Struct { 202 return reflect.New(tpe).Interface().(sqlbuilder.Model) 203 } 204 } 205 return nil 206 } 207 208 func (q querier) Scan(v any) Querier { 209 if len(q.projects) == 0 { 210 if m, ok := resolveModel(v).(sqlbuilder.Model); ok { 211 q.projects = []sqlbuilder.SqlExpr{sqlbuilder.ColumnsByStruct(m)} 212 } 213 } 214 q.recv = v 215 return &q 216 } 217 218 func (q querier) Select(projects ...sqlbuilder.SqlExpr) Querier { 219 q.projects = projects 220 return &q 221 } 222 223 func (q querier) Where(where sqlbuilder.SqlExpr) Querier { 224 q.where = where 225 return &q 226 } 227 228 func (q querier) WhereAnd(where sqlbuilder.SqlExpr) Querier { 229 q.where = sqlbuilder.And(q.where, where) 230 return &q 231 } 232 233 func (q querier) WhereOr(where sqlbuilder.SqlExpr) Querier { 234 q.where = sqlbuilder.Or(q.where, where) 235 return &q 236 } 237 238 func (q querier) OrderBy(orders ...*sqlbuilder.Order) Querier { 239 q.orders = orders 240 return &q 241 } 242 243 func (q querier) GroupBy(cols ...sqlbuilder.SqlExpr) Querier { 244 q.groupBy = cols 245 return &q 246 } 247 248 func (q querier) Having(having sqlbuilder.SqlExpr) Querier { 249 q.having = having 250 return &q 251 } 252 253 func (q querier) Limit(v int64) Querier { 254 q.limit = v 255 return &q 256 } 257 258 func (q querier) Offset(v int64) Querier { 259 q.offset = v 260 return &q 261 } 262 263 func (q querier) Distinct(extras ...sqlbuilder.SqlExpr) Querier { 264 q.distinct = extras 265 return &q 266 } 267 268 func (q *querier) buildWhere(t sqlbuilder.Table) sqlbuilder.SqlExpr { 269 if q.feature.softDelete { 270 if newModel, ok := q.from.(interface{ New() sqlbuilder.Model }); ok { 271 m := newModel.New() 272 if soft, ok := m.(ModelWithSoftDelete); ok { 273 f, _ := soft.SoftDeleteFieldAndZeroValue() 274 return sqlbuilder.And( 275 q.where, 276 sqlbuilder.CastCol[int](t.F(f)).V(sqlbuilder.Eq(0)), 277 ) 278 } 279 } 280 } 281 return q.where 282 } 283 284 func (q *querier) build() sqlbuilder.SqlExpr { 285 from := q.from 286 287 modifies := make([]sqlbuilder.SqlExpr, 0) 288 289 if q.distinct != nil { 290 modifies = append(modifies, sqlbuilder.Expr("DISTINCT")) 291 292 if len(q.distinct) > 0 { 293 modifies = append(modifies, q.distinct...) 294 } 295 } 296 297 additions := make([]sqlbuilder.Addition, 0, 10) 298 299 if where := q.buildWhere(from); where != nil { 300 additions = append(additions, sqlbuilder.Where(sqlbuilder.AsCond(where))) 301 } 302 303 if n := len(q.joins); n > 0 { 304 additions = append(additions, q.joins...) 305 } 306 307 if n := len(q.orders); n > 0 { 308 additions = append(additions, sqlbuilder.OrderBy(q.orders...)) 309 } 310 311 if n := len(q.groupBy); n > 0 { 312 additions = append(additions, sqlbuilder.GroupBy(q.groupBy...).Having(sqlbuilder.AsCond(q.having))) 313 } 314 315 if q.limit > 0 { 316 additions = append(additions, sqlbuilder.Limit(q.limit).Offset(q.offset)) 317 } 318 319 var projects sqlbuilder.SqlExpr 320 321 if q.projects != nil { 322 projects = sqlbuilder.MultiMayAutoAlias(q.projects...) 323 } 324 325 if q.withStmt != nil { 326 return q.withStmt.Exec(func(tables ...sqlbuilder.Table) sqlbuilder.SqlExpr { 327 return sqlbuilder.Select(projects, modifies...).From(from, additions...) 328 }) 329 } 330 331 return sqlbuilder.Select(projects, modifies...).From(from, additions...) 332 } 333 334 func (q *querier) Count(ctx context.Context) (int, error) { 335 var c int 336 if err := q.Limit(-1).Select(sqlbuilder.Count()).Scan(&c).Find(ctx); err != nil { 337 return 0, err 338 } 339 return c, nil 340 } 341 342 func (q *querier) Find(ctx context.Context) error { 343 s := SessionFor(ctx, q.from) 344 345 if q.recv == nil { 346 return errors.New("missing receiver. need to use Scan to bind one") 347 } 348 rows, err := s.Adapter().Query(ctx, q.build()) 349 if err != nil { 350 return err 351 } 352 353 done := make(chan error) 354 355 go func() { 356 defer close(done) 357 358 if err := scanner.Scan(ctx, rows, q.recv); err != nil { 359 if errors.Is(err, ErrSkipScan) || errors.Is(err, context.Canceled) { 360 done <- nil 361 return 362 } 363 done <- err 364 } 365 }() 366 367 select { 368 case <-ctx.Done(): 369 return nil 370 default: 371 return <-done 372 } 373 } 374 375 type ScanIterator = scanner.ScanIterator 376 377 var ErrSkipScan = errors.New("scan skip") 378 379 func Recv[T any](next func(v *T) error) ScanIterator { 380 return &typedScanner[T]{next: next} 381 } 382 383 type typedScanner[T any] struct { 384 next func(v *T) error 385 } 386 387 func (*typedScanner[T]) New() any { 388 return new(T) 389 } 390 391 func (t *typedScanner[T]) Next(v any) error { 392 return t.next(v.(*T)) 393 } 394 395 type TemporaryTable interface { 396 sqlbuilder.Table 397 TableWrapper 398 QuerierPatcher 399 } 400 401 func (q *querier) AsTemporaryTable(tableName string) TemporaryTable { 402 projects := q.projects 403 404 cols := make([]sqlbuilder.TableDefinition, 0, len(projects)) 405 406 for _, p := range projects { 407 if col, ok := p.(sqlbuilder.Column); ok { 408 cols = append(cols, col) 409 } 410 } 411 412 tmpT := sqlbuilder.T(tableName, cols...) 413 414 return &tmpTable{ 415 Table: tmpT, 416 origin: q.from, 417 build: func(table sqlbuilder.Table) sqlbuilder.SqlExpr { 418 return q 419 }, 420 } 421 } 422 423 type tmpTable struct { 424 sqlbuilder.Table 425 origin sqlbuilder.Table 426 build func(table sqlbuilder.Table) sqlbuilder.SqlExpr 427 } 428 429 func (t *tmpTable) Unwrap() sqlbuilder.Model { 430 return t.origin 431 } 432 433 func (t *tmpTable) Apply(q Querier) Querier { 434 return q.With(t.Table, t.build) 435 }