github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_select.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  
    21  	"github.com/ecodeclub/eorm/internal/merger/batchmerger"
    22  
    23  	"github.com/ecodeclub/eorm/internal/sharding"
    24  
    25  	"github.com/ecodeclub/eorm/internal/errs"
    26  	"github.com/valyala/bytebufferpool"
    27  )
    28  
// ShardingSelector builds SELECT statements for model T against sharded
// tables. Build produces one sharding.Query per destination resolved from the
// WHERE predicates; Get and GetMulti execute those queries through db.
type ShardingSelector[T any] struct {
	shardingSelectorBuilder
	// table stores the value passed to From; no method in this file reads it.
	table *T
	// db is the session used to execute the generated queries.
	db    Session
	// lock is not referenced by any method in this file.
	// NOTE(review): confirm whether it is used elsewhere or can be removed.
	lock  sync.Mutex
}
    35  
    36  func NewShardingSelector[T any](db Session) *ShardingSelector[T] {
    37  	b := shardingSelectorBuilder{}
    38  	b.core = db.getCore()
    39  	b.buffer = bytebufferpool.Get()
    40  	return &ShardingSelector[T]{
    41  		shardingSelectorBuilder: b,
    42  		db:                      db,
    43  	}
    44  }
    45  
    46  func (s *ShardingSelector[T]) Build(ctx context.Context) ([]sharding.Query, error) {
    47  	var err error
    48  	if s.meta == nil {
    49  		s.meta, err = s.metaRegistry.Get(new(T))
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  	}
    54  	shardingRes, err := s.findDst(ctx, s.where...)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	res := make([]sharding.Query, 0, len(shardingRes.Dsts))
    59  	defer bytebufferpool.Put(s.buffer)
    60  	for _, dst := range shardingRes.Dsts {
    61  		q, err := s.buildQuery(dst.DB, dst.Table, dst.Name)
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		res = append(res, q)
    66  		s.args = nil
    67  		s.buffer.Reset()
    68  	}
    69  	return res, nil
    70  }
    71  
    72  func (s *ShardingSelector[T]) buildQuery(db, tbl, ds string) (sharding.Query, error) {
    73  	var err error
    74  	s.writeString("SELECT ")
    75  	if len(s.columns) == 0 {
    76  		if err = s.buildAllColumns(); err != nil {
    77  			return sharding.EmptyQuery, err
    78  		}
    79  	} else {
    80  		err = s.buildSelectedList()
    81  		if err != nil {
    82  			return sharding.EmptyQuery, err
    83  		}
    84  	}
    85  	s.writeString(" FROM ")
    86  	s.quote(db)
    87  	s.writeByte('.')
    88  	s.quote(tbl)
    89  
    90  	if len(s.where) > 0 {
    91  		s.writeString(" WHERE ")
    92  		p := s.where[0]
    93  		for i := 1; i < len(s.where); i++ {
    94  			p = p.And(s.where[i])
    95  		}
    96  		if err = s.buildExpr(p); err != nil {
    97  			return sharding.EmptyQuery, err
    98  		}
    99  	}
   100  
   101  	// group by
   102  	if len(s.groupBy) > 0 {
   103  		err = s.buildGroupBy()
   104  		if err != nil {
   105  			return sharding.EmptyQuery, err
   106  		}
   107  	}
   108  
   109  	// order by
   110  	if len(s.orderBy) > 0 {
   111  		err = s.buildOrderBy()
   112  		if err != nil {
   113  			return sharding.EmptyQuery, err
   114  		}
   115  	}
   116  
   117  	// having
   118  	if len(s.having) > 0 {
   119  		s.writeString(" HAVING ")
   120  		p := s.having[0]
   121  		for i := 1; i < len(s.having); i++ {
   122  			p = p.And(s.having[i])
   123  		}
   124  		if err = s.buildExpr(p); err != nil {
   125  			return sharding.EmptyQuery, err
   126  		}
   127  	}
   128  
   129  	if s.offset > 0 {
   130  		s.writeString(" OFFSET ")
   131  		s.parameter(s.offset)
   132  	}
   133  
   134  	if s.limit > 0 {
   135  		s.writeString(" LIMIT ")
   136  		s.parameter(s.limit)
   137  	}
   138  	s.end()
   139  	return sharding.Query{SQL: s.buffer.String(), Args: s.args, Datasource: ds, DB: db}, nil
   140  }
   141  
   142  func (s *ShardingSelector[T]) buildAllColumns() error {
   143  	for i, cMeta := range s.meta.Columns {
   144  		_ = s.buildColumns(i, cMeta.FieldName)
   145  	}
   146  	return nil
   147  }
   148  
   149  func (s *ShardingSelector[T]) buildSelectedList() error {
   150  	for i, selectable := range s.columns {
   151  		if i > 0 {
   152  			s.comma()
   153  		}
   154  		switch expr := selectable.(type) {
   155  		case Column:
   156  			err := s.builder.buildColumn(expr)
   157  			if err != nil {
   158  				return errs.NewInvalidFieldError(expr.name)
   159  			}
   160  		case columns:
   161  			for j, c := range expr.cs {
   162  				err := s.buildColumns(j, c)
   163  				if err != nil {
   164  					return err
   165  				}
   166  			}
   167  		case Aggregate:
   168  			if err := s.selectAggregate(expr); err != nil {
   169  				return err
   170  			}
   171  		case RawExpr:
   172  			s.buildRawExpr(expr)
   173  		}
   174  	}
   175  	return nil
   176  
   177  }
   178  func (s *ShardingSelector[T]) selectAggregate(aggregate Aggregate) error {
   179  	s.writeString(aggregate.fn)
   180  
   181  	s.writeByte('(')
   182  	if aggregate.distinct {
   183  		s.writeString("DISTINCT ")
   184  	}
   185  	cMeta, ok := s.meta.FieldMap[aggregate.arg]
   186  	if !ok {
   187  		return errs.NewInvalidFieldError(aggregate.arg)
   188  	}
   189  	if aggregate.table != nil {
   190  		if alias := aggregate.table.getAlias(); alias != "" {
   191  			s.quote(alias)
   192  			s.point()
   193  		}
   194  	}
   195  	s.quote(cMeta.ColumnName)
   196  	s.writeByte(')')
   197  	if aggregate.alias != "" {
   198  		s.writeString(" AS ")
   199  		s.quote(aggregate.alias)
   200  	}
   201  	return nil
   202  }
   203  
   204  func (s *ShardingSelector[T]) buildColumns(index int, name string) error {
   205  	if index > 0 {
   206  		s.comma()
   207  	}
   208  	cMeta, ok := s.meta.FieldMap[name]
   209  	if !ok {
   210  		return errs.NewInvalidFieldError(name)
   211  	}
   212  	s.quote(cMeta.ColumnName)
   213  	return nil
   214  }
   215  
   216  func (s *ShardingSelector[T]) buildExpr(expr Expr) error {
   217  	switch exp := expr.(type) {
   218  	case nil:
   219  	case Column:
   220  		exp.alias = ""
   221  		_ = s.buildColumn(exp)
   222  	case valueExpr:
   223  		s.parameter(exp.val)
   224  	case RawExpr:
   225  		s.buildRawExpr(exp)
   226  	case Predicate:
   227  		if err := s.buildBinaryExpr(binaryExpr(exp)); err != nil {
   228  			return err
   229  		}
   230  	default:
   231  		return errs.NewErrUnsupportedExpressionType()
   232  	}
   233  	return nil
   234  }
   235  
   236  func (s *ShardingSelector[T]) buildOrderBy() error {
   237  	s.writeString(" ORDER BY ")
   238  	for i, ob := range s.orderBy {
   239  		if i > 0 {
   240  			s.comma()
   241  		}
   242  		for _, c := range ob.fields {
   243  			cMeta, ok := s.meta.FieldMap[c]
   244  			if !ok {
   245  				return errs.NewInvalidFieldError(c)
   246  			}
   247  			s.quote(cMeta.ColumnName)
   248  		}
   249  		s.space()
   250  		s.writeString(ob.order)
   251  	}
   252  	return nil
   253  }
   254  
   255  func (s *ShardingSelector[T]) buildGroupBy() error {
   256  	s.writeString(" GROUP BY ")
   257  	for i, gb := range s.groupBy {
   258  		cMeta, ok := s.meta.FieldMap[gb]
   259  		if !ok {
   260  			return errs.NewInvalidFieldError(gb)
   261  		}
   262  		if i > 0 {
   263  			s.comma()
   264  		}
   265  		s.quote(cMeta.ColumnName)
   266  	}
   267  	return nil
   268  }
   269  
   270  func (s *ShardingSelector[T]) Get(ctx context.Context) (*T, error) {
   271  	qs, err := s.Limit(1).Build(ctx)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	if len(qs) == 0 {
   276  		return nil, errs.ErrNotGenShardingQuery
   277  	}
   278  	// TODO 要确保前面的改写 SQL 只能生成一个 SQL
   279  	if len(qs) > 1 {
   280  		return nil, errs.ErrOnlyResultOneQuery
   281  	}
   282  	q := qs[0]
   283  	// TODO 利用 ctx 传递 DB name
   284  	row, err := s.db.queryContext(ctx, q)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	if !row.Next() {
   289  		return nil, ErrNoRows
   290  	}
   291  	tp := new(T)
   292  	val := s.valCreator.NewPrimitiveValue(tp, s.meta)
   293  	if err = val.SetColumns(row); err != nil {
   294  		return nil, err
   295  	}
   296  	return tp, nil
   297  }
   298  
   299  func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) {
   300  	qs, err := s.Build(ctx)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	mgr := batchmerger.NewMerger()
   306  	rowsList, err := s.db.queryMulti(ctx, qs)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  	rows, err := mgr.Merge(ctx, rowsList.AsSlice())
   311  	if err != nil {
   312  		return nil, err
   313  	}
   314  	defer rows.Close()
   315  	var res []*T
   316  	for rows.Next() {
   317  		tp := new(T)
   318  		val := s.valCreator.NewPrimitiveValue(tp, s.meta)
   319  		if err = val.SetColumns(rows); err != nil {
   320  			return nil, err
   321  		}
   322  		res = append(res, tp)
   323  	}
   324  	return res, nil
   325  }
   326  
// Select specifies the columns to query, replacing any previous selection.
// A column may be a physical column, an aggregate function, or a RawExpr.
// With no columns, every column of the model is selected.
func (s *ShardingSelector[T]) Select(columns ...Selectable) *ShardingSelector[T] {
	s.columns = columns
	return s
}
   333  
// From specifies the table; tbl must be a pointer to the model struct.
// The value is only stored. The FROM clause of each generated query uses the
// database and table resolved by sharding.
func (s *ShardingSelector[T]) From(tbl *T) *ShardingSelector[T] {
	s.table = tbl
	return s
}
   339  
// Where sets the WHERE predicates, replacing any set earlier. Multiple
// predicates are combined with AND. They are also used to resolve which
// shards the query is sent to.
func (s *ShardingSelector[T]) Where(predicates ...Predicate) *ShardingSelector[T] {
	s.where = predicates
	return s
}
   345  
// Having sets the HAVING predicates, replacing any set earlier. Multiple
// predicates are combined with AND.
func (s *ShardingSelector[T]) Having(predicates ...Predicate) *ShardingSelector[T] {
	s.having = predicates
	return s
}
   351  
// GroupBy sets the GROUP BY clause, replacing any set earlier. columns are
// model field names, which are resolved to column names when the query is
// built. Unknown field names make Build fail.
func (s *ShardingSelector[T]) GroupBy(columns ...string) *ShardingSelector[T] {
	s.groupBy = columns
	return s
}
   357  
// OrderBy sets the ORDER BY clause, replacing any set earlier. Field names in
// each OrderBy are resolved through the model metadata when the query is
// built.
func (s *ShardingSelector[T]) OrderBy(orderBys ...OrderBy) *ShardingSelector[T] {
	s.orderBy = orderBys
	return s
}
   363  
// Limit limits the size of the result set; values <= 0 omit the LIMIT clause.
// The same limit is written into every per-shard query. Get sets it to 1.
func (s *ShardingSelector[T]) Limit(limit int) *ShardingSelector[T] {
	s.limit = limit
	return s
}
   369  
// Offset sets the number of rows to skip; values <= 0 omit the OFFSET clause.
// Like Limit, the same offset is written into every per-shard query.
func (s *ShardingSelector[T]) Offset(offset int) *ShardingSelector[T] {
	s.offset = offset
	return s
}