github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/select.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  
    20  	"github.com/ecodeclub/eorm/internal/errs"
    21  	"github.com/valyala/bytebufferpool"
    22  )
    23  
    24  var _ QueryBuilder = &Selector[any]{}
    25  
    26  // Selector select 构造器
    27  type Selector[T any] struct {
    28  	Session
    29  	selectorBuilder
    30  	table TableReference
    31  }
    32  
    33  // NewSelector 创建一个 Selector
    34  func NewSelector[T any](sess Session) *Selector[T] {
    35  	return &Selector[T]{
    36  		selectorBuilder: selectorBuilder{
    37  			builder: builder{
    38  				core:   sess.getCore(),
    39  				buffer: bytebufferpool.Get(),
    40  			},
    41  		},
    42  		Session: sess,
    43  	}
    44  }
    45  
    46  // tableOf -> get selector table
    47  func (s *Selector[T]) tableOf() any {
    48  	switch tb := s.table.(type) {
    49  	case Table:
    50  		return tb.entity
    51  	default:
    52  		// 不使用 new(T) 来规避内存分配
    53  		return (*T)(nil)
    54  	}
    55  }
    56  
    57  // Build returns Select Query
    58  func (s *Selector[T]) Build() (Query, error) {
    59  	defer bytebufferpool.Put(s.buffer)
    60  	var err error
    61  	s.meta, err = s.metaRegistry.Get(s.tableOf())
    62  	if err != nil {
    63  		return EmptyQuery, err
    64  	}
    65  	s.writeString("SELECT ")
    66  	if s.distinct {
    67  		s.writeString("DISTINCT ")
    68  	}
    69  	if len(s.columns) == 0 {
    70  		switch s.table.(type) {
    71  		case Table, nil:
    72  			if err = s.buildAllColumns(); err != nil {
    73  				return EmptyQuery, err
    74  			}
    75  		default:
    76  			return EmptyQuery, errs.NewMustSpecifyColumnsError()
    77  		}
    78  	} else {
    79  		err = s.buildSelectedList()
    80  		if err != nil {
    81  			return EmptyQuery, err
    82  		}
    83  	}
    84  	s.writeString(" FROM ")
    85  
    86  	if err = s.buildTable(s.table); err != nil {
    87  		return EmptyQuery, err
    88  	}
    89  
    90  	if len(s.where) > 0 {
    91  		s.writeString(" WHERE ")
    92  		err = s.buildPredicates(s.where)
    93  		if err != nil {
    94  			return EmptyQuery, err
    95  		}
    96  	}
    97  
    98  	// group by
    99  	if len(s.groupBy) > 0 {
   100  		err = s.buildGroupBy()
   101  		if err != nil {
   102  			return EmptyQuery, err
   103  		}
   104  	}
   105  
   106  	// order by
   107  	if len(s.orderBy) > 0 {
   108  		err = s.buildOrderBy()
   109  		if err != nil {
   110  			return EmptyQuery, err
   111  		}
   112  	}
   113  
   114  	// having
   115  	if len(s.having) > 0 {
   116  		s.writeString(" HAVING ")
   117  		err = s.buildPredicates(s.having)
   118  		if err != nil {
   119  			return EmptyQuery, err
   120  		}
   121  	}
   122  
   123  	if s.offset > 0 {
   124  		s.writeString(" OFFSET ")
   125  		s.parameter(s.offset)
   126  	}
   127  
   128  	if s.limit > 0 {
   129  		s.writeString(" LIMIT ")
   130  		s.parameter(s.limit)
   131  	}
   132  	s.end()
   133  	return Query{SQL: s.buffer.String(), Args: s.args}, nil
   134  }
   135  
   136  func (s *Selector[T]) buildTable(table TableReference) error {
   137  	switch t := table.(type) {
   138  	case nil:
   139  		s.quote(s.meta.TableName)
   140  	case Table:
   141  		m, err := s.metaRegistry.Get(t.entity)
   142  		if err != nil {
   143  			return err
   144  		}
   145  		s.quote(m.TableName)
   146  		if t.alias != "" {
   147  			s.writeString(" AS ")
   148  			s.quote(t.alias)
   149  		}
   150  	case Join:
   151  		if err := s.buildJoin(t); err != nil {
   152  			return err
   153  		}
   154  	case Subquery:
   155  		return s.buildSubquery(t, true)
   156  	default:
   157  		return errs.NewUnsupportedTableReferenceError(table)
   158  	}
   159  	return nil
   160  }
   161  
   162  func (s *Selector[T]) buildOrderBy() error {
   163  	s.writeString(" ORDER BY ")
   164  	for i, ob := range s.orderBy {
   165  		if i > 0 {
   166  			s.comma()
   167  		}
   168  		for _, c := range ob.fields {
   169  			cMeta, ok := s.meta.FieldMap[c]
   170  			if !ok {
   171  				return errs.NewInvalidFieldError(c)
   172  			}
   173  			s.quote(cMeta.ColumnName)
   174  		}
   175  		s.space()
   176  		s.writeString(ob.order)
   177  	}
   178  	return nil
   179  }
   180  
   181  func (s *Selector[T]) buildGroupBy() error {
   182  	s.writeString(" GROUP BY ")
   183  	for i, gb := range s.groupBy {
   184  		cMeta, ok := s.meta.FieldMap[gb]
   185  		if !ok {
   186  			return errs.NewInvalidFieldError(gb)
   187  		}
   188  		if i > 0 {
   189  			s.comma()
   190  		}
   191  		s.quote(cMeta.ColumnName)
   192  	}
   193  	return nil
   194  }
   195  
   196  func (s *Selector[T]) buildAllColumns() error {
   197  	for i, cMeta := range s.meta.Columns {
   198  		// 永远不会返回 error
   199  		_ = s.buildColumns(i, cMeta.FieldName)
   200  	}
   201  	return nil
   202  }
   203  
   204  // buildSelectedList users specify columns
   205  func (s *Selector[T]) buildSelectedList() error {
   206  	for i, selectable := range s.columns {
   207  		if i > 0 {
   208  			s.comma()
   209  		}
   210  		switch expr := selectable.(type) {
   211  		case Column:
   212  			err := s.builder.buildColumn(expr)
   213  			if err != nil {
   214  				return errs.NewInvalidFieldError(expr.name)
   215  			}
   216  		case columns:
   217  			for j, c := range expr.cs {
   218  				err := s.buildColumns(j, c)
   219  				if err != nil {
   220  					return err
   221  				}
   222  			}
   223  		case Aggregate:
   224  			if err := s.selectAggregate(expr); err != nil {
   225  				return err
   226  			}
   227  		case RawExpr:
   228  			s.buildRawExpr(expr)
   229  		}
   230  	}
   231  	return nil
   232  
   233  }
   234  func (s *Selector[T]) selectAggregate(aggregate Aggregate) error {
   235  	s.writeString(aggregate.fn)
   236  
   237  	s.writeByte('(')
   238  	if aggregate.distinct {
   239  		s.writeString("DISTINCT ")
   240  	}
   241  	cMeta, ok := s.meta.FieldMap[aggregate.arg]
   242  	// s.aliases[aggregate.alias] = struct{}{}
   243  	if !ok {
   244  		return errs.NewInvalidFieldError(aggregate.arg)
   245  	}
   246  	if aggregate.table != nil {
   247  		if alias := aggregate.table.getAlias(); alias != "" {
   248  			s.quote(alias)
   249  			s.point()
   250  		}
   251  	}
   252  	s.quote(cMeta.ColumnName)
   253  	s.writeByte(')')
   254  	if aggregate.alias != "" {
   255  		// if _, ok := s.aliases[aggregate.alias]; ok {
   256  		// 	s.writeString(" AS ")
   257  		// 	s.quote(aggregate.alias)
   258  		// }
   259  		s.writeString(" AS ")
   260  		s.quote(aggregate.alias)
   261  	}
   262  	return nil
   263  }
   264  
   265  func (s *Selector[T]) buildColumns(index int, name string) error {
   266  	if index > 0 {
   267  		s.comma()
   268  	}
   269  	cMeta, ok := s.meta.FieldMap[name]
   270  	if !ok {
   271  		return errs.NewInvalidFieldError(name)
   272  	}
   273  	s.quote(cMeta.ColumnName)
   274  	return nil
   275  }
   276  
   277  func (s *Selector[T]) buildUsing(using []string) error {
   278  	s.writeString(" USING (")
   279  	for i, col := range using {
   280  		err := s.buildColumns(i, col)
   281  		if err != nil {
   282  			return err
   283  		}
   284  	}
   285  	s.writeByte(')')
   286  	return nil
   287  }
   288  
   289  // Select 指定查询的列。
   290  // 列可以是物理列,也可以是聚合函数,或者 RawExpr
   291  func (s *Selector[T]) Select(columns ...Selectable) *Selector[T] {
   292  	s.columns = columns
   293  	return s
   294  }
   295  
   296  // From specifies the table which must be pointer of structure
   297  func (s *Selector[T]) From(tbl TableReference) *Selector[T] {
   298  	s.table = tbl
   299  	return s
   300  }
   301  
   302  // Where accepts predicates
   303  func (s *Selector[T]) Where(predicates ...Predicate) *Selector[T] {
   304  	s.where = predicates
   305  	return s
   306  }
   307  
   308  // Distinct indicates using keyword DISTINCT
   309  func (s *Selector[T]) Distinct() *Selector[T] {
   310  	s.distinct = true
   311  	return s
   312  }
   313  
   314  // Having accepts predicates
   315  func (s *Selector[T]) Having(predicates ...Predicate) *Selector[T] {
   316  	s.having = predicates
   317  	return s
   318  }
   319  
   320  // GroupBy means "GROUP BY"
   321  func (s *Selector[T]) GroupBy(columns ...string) *Selector[T] {
   322  	s.groupBy = columns
   323  	return s
   324  }
   325  
   326  // OrderBy means "ORDER BY"
   327  func (s *Selector[T]) OrderBy(orderBys ...OrderBy) *Selector[T] {
   328  	s.orderBy = orderBys
   329  	return s
   330  }
   331  
   332  // Limit limits the size of result set
   333  func (s *Selector[T]) Limit(limit int) *Selector[T] {
   334  	s.limit = limit
   335  	return s
   336  }
   337  
   338  // Offset was used by "LIMIT"
   339  func (s *Selector[T]) Offset(offset int) *Selector[T] {
   340  	s.offset = offset
   341  	return s
   342  }
   343  
   344  func (s *Selector[T]) AsSubquery(alias string) Subquery {
   345  	var table TableReference
   346  	if s.table == nil {
   347  		table = TableOf(new(T), alias)
   348  	}
   349  	return Subquery{
   350  		entity:  table,
   351  		q:       s,
   352  		alias:   alias,
   353  		columns: s.columns,
   354  	}
   355  }
   356  
   357  // Get 方法会执行查询,并且返回一条数据
   358  // 注意,在不同的数据库情况下,第一条数据可能是按照不同的列来排序的
   359  // 而且要注意,这个方法会强制设置 Limit 1
   360  // 在没有查找到数据的情况下,会返回 ErrNoRows
   361  func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
   362  	query, err := s.Limit(1).Build()
   363  	if err != nil {
   364  		return nil, err
   365  	}
   366  	return newQuerier[T](s.Session, query, s.meta, SELECT).Get(ctx)
   367  }
   368  
   369  // OrderBy specify fields and ASC
   370  type OrderBy struct {
   371  	fields []string
   372  	order  string
   373  }
   374  
   375  // ASC means ORDER BY fields ASC
   376  func ASC(fields ...string) OrderBy {
   377  	return OrderBy{
   378  		fields: fields,
   379  		order:  "ASC",
   380  	}
   381  }
   382  
   383  // DESC means ORDER BY fields DESC
   384  func DESC(fields ...string) OrderBy {
   385  	return OrderBy{
   386  		fields: fields,
   387  		order:  "DESC",
   388  	}
   389  }
   390  
   391  // Selectable is a tag interface which represents SELECT XXX
   392  type Selectable interface {
   393  	selected()
   394  }
   395  
   396  func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) {
   397  	query, err := s.Build()
   398  	if err != nil {
   399  		return nil, err
   400  	}
   401  	return newQuerier[T](s.Session, query, s.meta, SELECT).GetMulti(ctx)
   402  }
   403  
   404  func (s *Selector[T]) buildJoin(t Join) error {
   405  	s.writeByte('(')
   406  	if err := s.buildTable(t.left); err != nil {
   407  		return err
   408  	}
   409  	s.space()
   410  	s.writeString(t.typ)
   411  	s.space()
   412  	if err := s.buildTable(t.right); err != nil {
   413  		return err
   414  	}
   415  	if len(t.using) > 0 {
   416  		if err := s.buildUsing(t.using); err != nil {
   417  			return err
   418  		}
   419  	}
   420  	if len(t.on) > 0 {
   421  		s.writeString(" ON ")
   422  		if err := s.buildPredicates(t.on); err != nil {
   423  			return err
   424  		}
   425  	}
   426  	s.writeByte(')')
   427  	return nil
   428  }