github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/builder.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  
    21  	"github.com/ecodeclub/eorm/internal/errs"
    22  	"github.com/ecodeclub/eorm/internal/model"
    23  	"github.com/ecodeclub/eorm/internal/query"
    24  	"github.com/valyala/bytebufferpool"
    25  )
    26  
    27  var _ Executor = &Inserter[any]{}
    28  var _ Executor = &Updater[any]{}
    29  var _ Executor = &Deleter[any]{}
    30  
    31  var EmptyQuery = Query{}
    32  
    33  // Query 代表一个查询
    34  type Query = query.Query
    35  
    36  // Querier 查询器,代表最基本的查询
    37  type Querier[T any] struct {
    38  	core
    39  	Session
    40  	qc *QueryContext
    41  }
    42  
    43  // RawQuery 创建一个 Querier 实例
    44  // 泛型参数 T 是目标类型。
    45  // 例如,如果查询 User 的数据, 那么 T 就是 User
    46  func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
    47  	return Querier[T]{
    48  		core:    sess.getCore(),
    49  		Session: sess,
    50  		qc: &QueryContext{
    51  			q: Query{
    52  				SQL:  sql,
    53  				Args: args,
    54  			},
    55  			Type: RAW,
    56  		},
    57  	}
    58  }
    59  
    60  func newQuerier[T any](sess Session, q Query, meta *model.TableMeta, typ string) Querier[T] {
    61  	return Querier[T]{
    62  		core:    sess.getCore(),
    63  		Session: sess,
    64  		qc: &QueryContext{
    65  			q:    q,
    66  			meta: meta,
    67  			Type: typ,
    68  		},
    69  	}
    70  }
    71  
    72  // Exec 执行 SQL
    73  func (q Querier[T]) Exec(ctx context.Context) Result {
    74  	var handler HandleFunc = func(ctx context.Context, qc *QueryContext) *QueryResult {
    75  		res, err := q.Session.execContext(ctx, qc.q)
    76  		return &QueryResult{Result: res, Err: err}
    77  	}
    78  
    79  	ms := q.ms
    80  	for i := len(ms) - 1; i >= 0; i-- {
    81  		handler = ms[i](handler)
    82  	}
    83  	qr := handler(ctx, q.qc)
    84  	var res sql.Result
    85  	if qr.Result != nil {
    86  		res = qr.Result.(sql.Result)
    87  	}
    88  	return Result{err: qr.Err, res: res}
    89  }
    90  
    91  // Get 执行查询并且返回第一行数据
    92  // 注意在不同的数据库里面,排序可能会不同
    93  // 在没有查找到数据的情况下,会返回 ErrNoRows
    94  func (q Querier[T]) Get(ctx context.Context) (*T, error) {
    95  	res := get[T](ctx, q.Session, q.core, q.qc)
    96  	if res.Err != nil {
    97  		return nil, res.Err
    98  	}
    99  	return res.Result.(*T), nil
   100  }
   101  
   102  type builder struct {
   103  	core
   104  	// 使用 bytebufferpool 以减少内存分配
   105  	// 每次调用 Get 之后不要忘记再调用 Put
   106  	buffer *bytebufferpool.ByteBuffer
   107  	meta   *model.TableMeta
   108  	args   []interface{}
   109  	// aliases map[string]struct{}
   110  }
   111  
   112  func (b *builder) quote(val string) {
   113  	b.writeByte(b.dialect.Quote)
   114  	b.writeString(val)
   115  	b.writeByte(b.dialect.Quote)
   116  }
   117  
   118  func (b *builder) space() {
   119  	b.writeByte(' ')
   120  }
   121  
   122  func (b *builder) point() {
   123  	b.writeByte('.')
   124  }
   125  
   126  func (b *builder) writeString(val string) {
   127  	_, _ = b.buffer.WriteString(val)
   128  }
   129  
   130  func (b *builder) writeByte(c byte) {
   131  	_ = b.buffer.WriteByte(c)
   132  }
   133  
   134  func (b *builder) end() {
   135  	b.writeByte(';')
   136  }
   137  
   138  func (b *builder) comma() {
   139  	b.writeByte(',')
   140  }
   141  
   142  func (b *builder) parameter(arg interface{}) {
   143  	if b.args == nil {
   144  		// TODO 4 may be not a good number
   145  		b.args = make([]interface{}, 0, 4)
   146  	}
   147  	b.writeByte('?')
   148  	b.args = append(b.args, arg)
   149  }
   150  
   151  func (b *builder) buildExpr(expr Expr) error {
   152  	switch e := expr.(type) {
   153  	case nil:
   154  	case RawExpr:
   155  		b.buildRawExpr(e)
   156  	case Column:
   157  		// _, ok := b.aliases[e.name]
   158  		// if ok {
   159  		// 	b.quote(e.name)
   160  		// 	return nil
   161  		// }
   162  		return b.buildColumn(e)
   163  	case Aggregate:
   164  		if err := b.buildHavingAggregate(e); err != nil {
   165  			return err
   166  		}
   167  	case valueExpr:
   168  		b.parameter(e.val)
   169  	case binaryExpr:
   170  		if err := b.buildBinaryExpr(e); err != nil {
   171  			return err
   172  		}
   173  	case Predicate:
   174  		if err := b.buildBinaryExpr(binaryExpr(e)); err != nil {
   175  			return err
   176  		}
   177  	case values:
   178  		if err := b.buildIns(e); err != nil {
   179  			return err
   180  		}
   181  	case Subquery:
   182  		return b.buildSubquery(e, false)
   183  	case SubqueryExpr:
   184  		b.writeString(e.pred)
   185  		b.writeByte(' ')
   186  		return b.buildSubquery(e.s, false)
   187  	default:
   188  		return errs.NewErrUnsupportedExpressionType()
   189  	}
   190  	return nil
   191  }
   192  
   193  func (b *builder) buildPredicates(predicates []Predicate) error {
   194  	p := predicates[0]
   195  	for i := 1; i < len(predicates); i++ {
   196  		p = p.And(predicates[i])
   197  	}
   198  	return b.buildExpr(p)
   199  }
   200  
   201  func (b *builder) buildHavingAggregate(aggregate Aggregate) error {
   202  	b.writeString(aggregate.fn)
   203  
   204  	b.writeByte('(')
   205  	if aggregate.distinct {
   206  		b.writeString("DISTINCT ")
   207  	}
   208  	cMeta, ok := b.meta.FieldMap[aggregate.arg]
   209  	if !ok {
   210  		return errs.NewInvalidFieldError(aggregate.arg)
   211  	}
   212  	b.quote(cMeta.ColumnName)
   213  	b.writeByte(')')
   214  	return nil
   215  }
   216  
   217  func (b *builder) buildBinaryExpr(e binaryExpr) error {
   218  	err := b.buildSubExpr(e.left)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	b.writeString(e.op.Text)
   223  	return b.buildSubExpr(e.right)
   224  }
   225  
   226  func (b *builder) buildRawExpr(e RawExpr) {
   227  	b.writeString(e.raw)
   228  	b.args = append(b.args, e.args...)
   229  }
   230  
   231  func (b *builder) buildSubExpr(subExpr Expr) error {
   232  	switch r := subExpr.(type) {
   233  	case MathExpr:
   234  		b.writeByte('(')
   235  		if err := b.buildBinaryExpr(binaryExpr(r)); err != nil {
   236  			return err
   237  		}
   238  		b.writeByte(')')
   239  	case Predicate:
   240  		b.writeByte('(')
   241  		if err := b.buildBinaryExpr(binaryExpr(r)); err != nil {
   242  			return err
   243  		}
   244  		b.writeByte(')')
   245  	default:
   246  		if err := b.buildExpr(r); err != nil {
   247  			return err
   248  		}
   249  	}
   250  	return nil
   251  }
   252  
   253  func (b *builder) buildIns(is values) error {
   254  	b.writeByte('(')
   255  	for idx, inVal := range is.data {
   256  		if idx > 0 {
   257  			b.writeByte(',')
   258  		}
   259  
   260  		b.args = append(b.args, inVal)
   261  		b.writeByte('?')
   262  
   263  	}
   264  	b.writeByte(')')
   265  	return nil
   266  }
   267  
   268  func (q Querier[T]) GetMulti(ctx context.Context) ([]*T, error) {
   269  	res := getMulti[T](ctx, q.Session, q.core, q.qc)
   270  	if res.Err != nil {
   271  		return nil, res.Err
   272  	}
   273  	return res.Result.([]*T), nil
   274  }
   275  
   276  func (b *builder) buildColumn(c Column) error {
   277  	switch table := c.table.(type) {
   278  	case nil:
   279  		fd, ok := b.meta.FieldMap[c.name]
   280  		// 字段不对,或者说列不对
   281  		if !ok {
   282  			return errs.NewInvalidFieldError(c.name)
   283  		}
   284  		b.quote(fd.ColumnName)
   285  		if c.alias != "" {
   286  			// b.aliases[c.alias] = struct{}{}
   287  			b.writeString(" AS ")
   288  			b.quote(c.alias)
   289  		}
   290  	case Table:
   291  		m, err := b.metaRegistry.Get(table.entity)
   292  		if err != nil {
   293  			return err
   294  		}
   295  		fd, ok := m.FieldMap[c.name]
   296  		if !ok {
   297  			return errs.NewInvalidFieldError(c.name)
   298  		}
   299  		if table.alias != "" {
   300  			b.quote(table.alias)
   301  			b.point()
   302  		}
   303  		b.quote(fd.ColumnName)
   304  		if c.alias != "" {
   305  			b.writeString(" AS ")
   306  			b.quote(c.alias)
   307  		}
   308  	default:
   309  		return errs.NewUnsupportedTableReferenceError(table)
   310  	}
   311  	return nil
   312  }
   313  
   314  // buildSubquery 構建子查詢 SQL,
   315  // useAlias 決定是否顯示別名,即使有別名
   316  func (b *builder) buildSubquery(sub Subquery, useAlias bool) error {
   317  	q, err := sub.q.Build()
   318  	if err != nil {
   319  		return err
   320  	}
   321  	b.writeByte('(')
   322  	// 拿掉最後 ';'
   323  	b.writeString(q.SQL[:len(q.SQL)-1])
   324  	// 因為有 build() ,所以理應 args 也需要跟 SQL 一起處理
   325  	if len(q.Args) > 0 {
   326  		b.addArgs(q.Args...)
   327  	}
   328  	b.writeByte(')')
   329  	if useAlias {
   330  		b.writeString(" AS ")
   331  		b.quote(sub.getAlias())
   332  	}
   333  	return nil
   334  }
   335  
   336  func (b *builder) addArgs(args ...any) {
   337  	if b.args == nil {
   338  		b.args = make([]any, 0, 8)
   339  	}
   340  	b.args = append(b.args, args...)
   341  }