github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/querier.go (about)

     1  package dal
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  
     7  	"github.com/octohelm/storage/internal/sql/scanner"
     8  	"github.com/octohelm/storage/pkg/sqlbuilder"
     9  	"github.com/pkg/errors"
    10  )
    11  
// Set operations that are not (yet) exposed on Querier:
//
//	Intersect(q Querier) Querier
//	Except(q Querier) Querier
    14  
    15  func InSelect[T any](col sqlbuilder.TypedColumn[T], q Querier) sqlbuilder.ColumnValueExpr[T] {
    16  	return func(v sqlbuilder.Column) sqlbuilder.SqlExpr {
    17  		ex := q.Select(col)
    18  		if ex.IsNil() {
    19  			return nil
    20  		}
    21  		return sqlbuilder.Expr("? IN (?)", v, ex)
    22  	}
    23  }
    24  
    25  func NotInSelect[T any](col sqlbuilder.TypedColumn[T], q Querier) sqlbuilder.ColumnValueExpr[T] {
    26  	return func(v sqlbuilder.Column) sqlbuilder.SqlExpr {
    27  		ex := q.Select(col)
    28  		if ex.IsNil() {
    29  			return nil
    30  		}
    31  		return sqlbuilder.Expr("? NOT IN (?)", v, ex)
    32  	}
    33  }
    34  
// QuerierPatcher transforms a Querier and returns the patched result.
//
// Patchers are run by Querier.Apply, and From applies the source table itself
// when that table implements QuerierPatcher.
type QuerierPatcher interface {
	// Apply returns q with this patch applied.
	Apply(q Querier) Querier
}
    38  
// Querier builds and runs a SELECT statement.
//
// A Querier is itself a sqlbuilder.SqlExpr, so it can be embedded in other
// expressions as a sub-query (see InSelect and NotInSelect). In the
// implementation returned by From, every builder method returns a new Querier
// and leaves the receiver unchanged.
type Querier interface {
	sqlbuilder.SqlExpr

	// ExistsTable reports whether table already takes part in the query
	// (as the FROM table, a WITH entry or a joined table).
	ExistsTable(table sqlbuilder.Table) bool
	// Apply runs the non-nil patchers in order and returns the final Querier.
	Apply(patchers ...QuerierPatcher) Querier

	// With registers t as a WITH entry whose body is produced by build.
	With(t sqlbuilder.Table, build sqlbuilder.BuildSubQuery, modifiers ...string) Querier
	// AsTemporaryTable exposes this Querier as a temporary table named tableName.
	AsTemporaryTable(tableName string) TemporaryTable

	// Join, CrossJoin, LeftJoin, RightJoin and FullJoin add a join of t
	// using on as the join condition.
	Join(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier
	CrossJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier
	LeftJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier
	RightJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier
	FullJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier

	// Where replaces the WHERE condition; WhereAnd and WhereOr combine the
	// given condition with the existing one using AND / OR.
	Where(where sqlbuilder.SqlExpr) Querier
	WhereAnd(where sqlbuilder.SqlExpr) Querier
	WhereOr(where sqlbuilder.SqlExpr) Querier

	// OrderBy replaces the ORDER BY clause.
	OrderBy(orders ...*sqlbuilder.Order) Querier

	// GroupBy sets the GROUP BY expressions; Having sets the HAVING
	// condition attached to that GROUP BY clause.
	GroupBy(cols ...sqlbuilder.SqlExpr) Querier
	Having(where sqlbuilder.SqlExpr) Querier

	// Limit and Offset control pagination; a limit <= 0 disables both.
	Limit(v int64) Querier
	Offset(v int64) Querier

	// Distinct requests SELECT DISTINCT, with extras emitted after the
	// DISTINCT keyword.
	Distinct(extras ...sqlbuilder.SqlExpr) Querier
	// Select replaces the projection.
	Select(projects ...sqlbuilder.SqlExpr) Querier

	// Scan binds the receiver that Find fills.
	Scan(v any) Querier

	// Find executes the query and scans the rows into the bound receiver.
	Find(ctx context.Context) error
	// Count returns the number of rows matched by the query.
	Count(ctx context.Context) (int, error)
}
    74  
    75  func From(from sqlbuilder.Table, fns ...OptionFunc) Querier {
    76  	q := &querier{
    77  		from:   from,
    78  		tables: []sqlbuilder.Table{from},
    79  		limit:  -1,
    80  		feature: feature{
    81  			softDelete: true,
    82  		},
    83  	}
    84  
    85  	for i := range fns {
    86  		fns[i](q)
    87  	}
    88  
    89  	if tmpT, ok := from.(QuerierPatcher); ok {
    90  		return q.Apply(tmpT)
    91  	}
    92  
    93  	return q
    94  }
    95  
// querier is the Querier implementation returned by From.
//
// Most builder methods use value receivers, so each call works on, and
// returns a pointer to, a shallow copy of the receiver.
type querier struct {
	// from is the FROM table. tables also tracks WITH entries and joined
	// tables, and is consulted by ExistsTable.
	from   sqlbuilder.Table
	tables []sqlbuilder.Table

	// withStmt collects WITH entries registered via With; nil when there are none.
	withStmt *sqlbuilder.WithStmt

	// orders is the ORDER BY list; empty means no ORDER BY.
	orders []*sqlbuilder.Order

	// distinct: a non-nil slice makes build emit DISTINCT, followed by its
	// elements. groupBy/having form the GROUP BY ... HAVING clause.
	distinct []sqlbuilder.SqlExpr
	groupBy  []sqlbuilder.SqlExpr
	having   sqlbuilder.SqlExpr

	// limit defaults to -1 (see From); build only emits LIMIT/OFFSET when
	// limit > 0.
	limit  int64
	offset int64

	// where is the user WHERE condition (build may AND a soft-delete check);
	// projects is the SELECT list, nil meaning no explicit projection.
	where    sqlbuilder.SqlExpr
	projects []sqlbuilder.SqlExpr

	// joins holds the JOIN clauses added by the join methods.
	joins []sqlbuilder.Addition

	// feature holds option flags such as softDelete and whereStmtNotEmpty.
	feature

	// recv is the receiver bound by Scan and filled by Find.
	recv any
}
   120  
   121  func (q *querier) ExistsTable(table sqlbuilder.Table) bool {
   122  	for _, t := range q.tables {
   123  		if t == table || t.TableName() == table.TableName() {
   124  			return true
   125  		}
   126  	}
   127  	return false
   128  }
   129  
   130  func (q *querier) Apply(patchers ...QuerierPatcher) Querier {
   131  	var applied Querier = q
   132  
   133  	for _, p := range patchers {
   134  		if p != nil {
   135  			applied = p.Apply(applied)
   136  		}
   137  	}
   138  
   139  	return applied
   140  }
   141  
   142  func (q querier) With(t sqlbuilder.Table, build sqlbuilder.BuildSubQuery, modifiers ...string) Querier {
   143  	q.tables = append(q.tables, t)
   144  	if q.withStmt == nil {
   145  		q.withStmt = sqlbuilder.With(t, build, modifiers...)
   146  		return &q
   147  	}
   148  	q.withStmt = q.withStmt.With(t, build)
   149  	return &q
   150  }
   151  
   152  func (q querier) CrossJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier {
   153  	q.tables = append(q.tables, t)
   154  	q.joins = append(q.joins, sqlbuilder.CrossJoin(t).On(sqlbuilder.AsCond(on)))
   155  	return &q
   156  }
   157  
   158  func (q querier) LeftJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier {
   159  	q.tables = append(q.tables, t)
   160  	q.joins = append(q.joins, sqlbuilder.LeftJoin(t).On(sqlbuilder.AsCond(on)))
   161  	return &q
   162  }
   163  
   164  func (q querier) RightJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier {
   165  	q.tables = append(q.tables, t)
   166  	q.joins = append(q.joins, sqlbuilder.RightJoin(t).On(sqlbuilder.AsCond(on)))
   167  	return &q
   168  }
   169  
   170  func (q querier) FullJoin(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier {
   171  	q.tables = append(q.tables, t)
   172  	q.joins = append(q.joins, sqlbuilder.FullJoin(t).On(sqlbuilder.AsCond(on)))
   173  	return &q
   174  }
   175  
   176  func (q querier) Join(t sqlbuilder.Table, on sqlbuilder.SqlExpr) Querier {
   177  	q.tables = append(q.tables, t)
   178  	q.joins = append(q.joins, sqlbuilder.Join(t).On(sqlbuilder.AsCond(on)))
   179  	return &q
   180  }
   181  
   182  func (q *querier) IsNil() bool {
   183  	if q.whereStmtNotEmpty {
   184  		return sqlbuilder.IsNilExpr(q.where) || q.from == nil
   185  	}
   186  	return q.from == nil
   187  }
   188  
   189  func (q *querier) Ex(ctx context.Context) *sqlbuilder.Ex {
   190  	return q.build().Ex(ctx)
   191  }
   192  
   193  func resolveModel(v any) any {
   194  	if canNew, ok := v.(interface{ New() any }); ok {
   195  		return canNew.New()
   196  	} else {
   197  		tpe := reflect.TypeOf(v)
   198  		for tpe.Kind() == reflect.Ptr {
   199  			tpe = tpe.Elem()
   200  		}
   201  		if tpe.Kind() == reflect.Struct {
   202  			return reflect.New(tpe).Interface().(sqlbuilder.Model)
   203  		}
   204  	}
   205  	return nil
   206  }
   207  
   208  func (q querier) Scan(v any) Querier {
   209  	if len(q.projects) == 0 {
   210  		if m, ok := resolveModel(v).(sqlbuilder.Model); ok {
   211  			q.projects = []sqlbuilder.SqlExpr{sqlbuilder.ColumnsByStruct(m)}
   212  		}
   213  	}
   214  	q.recv = v
   215  	return &q
   216  }
   217  
   218  func (q querier) Select(projects ...sqlbuilder.SqlExpr) Querier {
   219  	q.projects = projects
   220  	return &q
   221  }
   222  
   223  func (q querier) Where(where sqlbuilder.SqlExpr) Querier {
   224  	q.where = where
   225  	return &q
   226  }
   227  
   228  func (q querier) WhereAnd(where sqlbuilder.SqlExpr) Querier {
   229  	q.where = sqlbuilder.And(q.where, where)
   230  	return &q
   231  }
   232  
   233  func (q querier) WhereOr(where sqlbuilder.SqlExpr) Querier {
   234  	q.where = sqlbuilder.Or(q.where, where)
   235  	return &q
   236  }
   237  
   238  func (q querier) OrderBy(orders ...*sqlbuilder.Order) Querier {
   239  	q.orders = orders
   240  	return &q
   241  }
   242  
   243  func (q querier) GroupBy(cols ...sqlbuilder.SqlExpr) Querier {
   244  	q.groupBy = cols
   245  	return &q
   246  }
   247  
   248  func (q querier) Having(having sqlbuilder.SqlExpr) Querier {
   249  	q.having = having
   250  	return &q
   251  }
   252  
   253  func (q querier) Limit(v int64) Querier {
   254  	q.limit = v
   255  	return &q
   256  }
   257  
   258  func (q querier) Offset(v int64) Querier {
   259  	q.offset = v
   260  	return &q
   261  }
   262  
   263  func (q querier) Distinct(extras ...sqlbuilder.SqlExpr) Querier {
   264  	q.distinct = extras
   265  	return &q
   266  }
   267  
   268  func (q *querier) buildWhere(t sqlbuilder.Table) sqlbuilder.SqlExpr {
   269  	if q.feature.softDelete {
   270  		if newModel, ok := q.from.(interface{ New() sqlbuilder.Model }); ok {
   271  			m := newModel.New()
   272  			if soft, ok := m.(ModelWithSoftDelete); ok {
   273  				f, _ := soft.SoftDeleteFieldAndZeroValue()
   274  				return sqlbuilder.And(
   275  					q.where,
   276  					sqlbuilder.CastCol[int](t.F(f)).V(sqlbuilder.Eq(0)),
   277  				)
   278  			}
   279  		}
   280  	}
   281  	return q.where
   282  }
   283  
   284  func (q *querier) build() sqlbuilder.SqlExpr {
   285  	from := q.from
   286  
   287  	modifies := make([]sqlbuilder.SqlExpr, 0)
   288  
   289  	if q.distinct != nil {
   290  		modifies = append(modifies, sqlbuilder.Expr("DISTINCT"))
   291  
   292  		if len(q.distinct) > 0 {
   293  			modifies = append(modifies, q.distinct...)
   294  		}
   295  	}
   296  
   297  	additions := make([]sqlbuilder.Addition, 0, 10)
   298  
   299  	if where := q.buildWhere(from); where != nil {
   300  		additions = append(additions, sqlbuilder.Where(sqlbuilder.AsCond(where)))
   301  	}
   302  
   303  	if n := len(q.joins); n > 0 {
   304  		additions = append(additions, q.joins...)
   305  	}
   306  
   307  	if n := len(q.orders); n > 0 {
   308  		additions = append(additions, sqlbuilder.OrderBy(q.orders...))
   309  	}
   310  
   311  	if n := len(q.groupBy); n > 0 {
   312  		additions = append(additions, sqlbuilder.GroupBy(q.groupBy...).Having(sqlbuilder.AsCond(q.having)))
   313  	}
   314  
   315  	if q.limit > 0 {
   316  		additions = append(additions, sqlbuilder.Limit(q.limit).Offset(q.offset))
   317  	}
   318  
   319  	var projects sqlbuilder.SqlExpr
   320  
   321  	if q.projects != nil {
   322  		projects = sqlbuilder.MultiMayAutoAlias(q.projects...)
   323  	}
   324  
   325  	if q.withStmt != nil {
   326  		return q.withStmt.Exec(func(tables ...sqlbuilder.Table) sqlbuilder.SqlExpr {
   327  			return sqlbuilder.Select(projects, modifies...).From(from, additions...)
   328  		})
   329  	}
   330  
   331  	return sqlbuilder.Select(projects, modifies...).From(from, additions...)
   332  }
   333  
   334  func (q *querier) Count(ctx context.Context) (int, error) {
   335  	var c int
   336  	if err := q.Limit(-1).Select(sqlbuilder.Count()).Scan(&c).Find(ctx); err != nil {
   337  		return 0, err
   338  	}
   339  	return c, nil
   340  }
   341  
   342  func (q *querier) Find(ctx context.Context) error {
   343  	s := SessionFor(ctx, q.from)
   344  
   345  	if q.recv == nil {
   346  		return errors.New("missing receiver. need to use Scan to bind one")
   347  	}
   348  	rows, err := s.Adapter().Query(ctx, q.build())
   349  	if err != nil {
   350  		return err
   351  	}
   352  
   353  	done := make(chan error)
   354  
   355  	go func() {
   356  		defer close(done)
   357  
   358  		if err := scanner.Scan(ctx, rows, q.recv); err != nil {
   359  			if errors.Is(err, ErrSkipScan) || errors.Is(err, context.Canceled) {
   360  				done <- nil
   361  				return
   362  			}
   363  			done <- err
   364  		}
   365  	}()
   366  
   367  	select {
   368  	case <-ctx.Done():
   369  		return nil
   370  	default:
   371  		return <-done
   372  	}
   373  }
   374  
// ScanIterator re-exports scanner.ScanIterator so callers outside this package
// can name it (Recv returns one).
type ScanIterator = scanner.ScanIterator

// ErrSkipScan is treated by Find as a successful end of the scan rather than
// an error. Presumably it is returned from a receiver callback (e.g. the one
// passed to Recv) to stop early — confirm against scanner.Scan.
var ErrSkipScan = errors.New("scan skip")
   378  
   379  func Recv[T any](next func(v *T) error) ScanIterator {
   380  	return &typedScanner[T]{next: next}
   381  }
   382  
// typedScanner is the ScanIterator created by Recv; it forwards values to a
// callback typed as *T.
type typedScanner[T any] struct {
	// next is the user callback invoked by Next.
	next func(v *T) error
}
   386  
   387  func (*typedScanner[T]) New() any {
   388  	return new(T)
   389  }
   390  
   391  func (t *typedScanner[T]) Next(v any) error {
   392  	return t.next(v.(*T))
   393  }
   394  
// TemporaryTable is a table backed by a Querier (see AsTemporaryTable).
//
// It can be used wherever a sqlbuilder.Table is expected. It satisfies
// TableWrapper (tmpTable implements this through Unwrap). As a QuerierPatcher,
// it registers its WITH entry on the Querier it is applied to.
type TemporaryTable interface {
	sqlbuilder.Table
	TableWrapper
	QuerierPatcher
}
   400  
   401  func (q *querier) AsTemporaryTable(tableName string) TemporaryTable {
   402  	projects := q.projects
   403  
   404  	cols := make([]sqlbuilder.TableDefinition, 0, len(projects))
   405  
   406  	for _, p := range projects {
   407  		if col, ok := p.(sqlbuilder.Column); ok {
   408  			cols = append(cols, col)
   409  		}
   410  	}
   411  
   412  	tmpT := sqlbuilder.T(tableName, cols...)
   413  
   414  	return &tmpTable{
   415  		Table:  tmpT,
   416  		origin: q.from,
   417  		build: func(table sqlbuilder.Table) sqlbuilder.SqlExpr {
   418  			return q
   419  		},
   420  	}
   421  }
   422  
// tmpTable implements TemporaryTable on top of a sqlbuilder.Table definition.
type tmpTable struct {
	// Table is the temporary table's own definition (name and columns).
	sqlbuilder.Table
	// origin is returned by Unwrap.
	origin sqlbuilder.Table
	// build produces the WITH body registered by Apply.
	build  func(table sqlbuilder.Table) sqlbuilder.SqlExpr
}
   428  
   429  func (t *tmpTable) Unwrap() sqlbuilder.Model {
   430  	return t.origin
   431  }
   432  
   433  func (t *tmpTable) Apply(q Querier) Querier {
   434  	return q.With(t.Table, t.build)
   435  }