github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/select.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 20 "github.com/ecodeclub/eorm/internal/errs" 21 "github.com/valyala/bytebufferpool" 22 ) 23 24 var _ QueryBuilder = &Selector[any]{} 25 26 // Selector select 构造器 27 type Selector[T any] struct { 28 Session 29 selectorBuilder 30 table TableReference 31 } 32 33 // NewSelector 创建一个 Selector 34 func NewSelector[T any](sess Session) *Selector[T] { 35 return &Selector[T]{ 36 selectorBuilder: selectorBuilder{ 37 builder: builder{ 38 core: sess.getCore(), 39 buffer: bytebufferpool.Get(), 40 }, 41 }, 42 Session: sess, 43 } 44 } 45 46 // tableOf -> get selector table 47 func (s *Selector[T]) tableOf() any { 48 switch tb := s.table.(type) { 49 case Table: 50 return tb.entity 51 default: 52 // 不使用 new(T) 来规避内存分配 53 return (*T)(nil) 54 } 55 } 56 57 // Build returns Select Query 58 func (s *Selector[T]) Build() (Query, error) { 59 defer bytebufferpool.Put(s.buffer) 60 var err error 61 s.meta, err = s.metaRegistry.Get(s.tableOf()) 62 if err != nil { 63 return EmptyQuery, err 64 } 65 s.writeString("SELECT ") 66 if s.distinct { 67 s.writeString("DISTINCT ") 68 } 69 if len(s.columns) == 0 { 70 switch s.table.(type) { 71 case Table, nil: 72 if err = s.buildAllColumns(); err != nil { 73 return EmptyQuery, err 74 } 75 default: 76 return EmptyQuery, errs.NewMustSpecifyColumnsError() 77 } 78 } else { 79 err = s.buildSelectedList() 80 if err != nil { 81 return EmptyQuery, err 82 } 83 } 84 s.writeString(" FROM ") 85 86 if err = s.buildTable(s.table); err != nil { 87 return EmptyQuery, err 88 } 89 90 if len(s.where) > 0 { 91 s.writeString(" WHERE ") 92 err = s.buildPredicates(s.where) 93 if err != nil { 94 return EmptyQuery, err 95 } 96 } 97 98 // group by 99 if len(s.groupBy) > 0 { 100 err = s.buildGroupBy() 101 if err != nil { 102 return EmptyQuery, err 103 } 104 } 105 106 // order by 107 if len(s.orderBy) > 0 { 108 err = s.buildOrderBy() 109 if err != nil { 110 return EmptyQuery, err 111 } 112 } 113 114 // having 115 if len(s.having) > 0 { 116 s.writeString(" HAVING ") 117 err = s.buildPredicates(s.having) 118 if err != nil { 119 return EmptyQuery, err 120 } 121 } 122 123 if s.offset > 0 { 124 s.writeString(" OFFSET ") 125 s.parameter(s.offset) 126 } 127 128 if s.limit > 0 { 129 s.writeString(" LIMIT ") 130 s.parameter(s.limit) 131 } 132 s.end() 133 return Query{SQL: s.buffer.String(), Args: s.args}, nil 134 } 135 136 func (s *Selector[T]) buildTable(table TableReference) error { 137 switch t := table.(type) { 138 case nil: 139 s.quote(s.meta.TableName) 140 case Table: 141 m, err := s.metaRegistry.Get(t.entity) 142 if err != nil { 143 return err 144 } 145 s.quote(m.TableName) 146 if t.alias != "" { 147 s.writeString(" AS ") 148 s.quote(t.alias) 149 } 150 case Join: 151 if err := s.buildJoin(t); err != nil { 152 return err 153 } 154 case Subquery: 155 return s.buildSubquery(t, true) 156 default: 157 return errs.NewUnsupportedTableReferenceError(table) 158 } 159 return nil 160 } 161 162 func (s *Selector[T]) buildOrderBy() error { 163 s.writeString(" ORDER BY ") 164 for i, ob := range s.orderBy { 165 if i > 0 { 166 s.comma() 167 } 168 for _, c := range ob.fields { 169 cMeta, ok := s.meta.FieldMap[c] 170 if !ok { 171 return errs.NewInvalidFieldError(c) 172 } 173 s.quote(cMeta.ColumnName) 174 } 175 s.space() 176 s.writeString(ob.order) 177 } 178 return nil 179 } 180 181 func (s *Selector[T]) buildGroupBy() error { 182 s.writeString(" GROUP BY ") 183 for i, gb := range s.groupBy { 184 cMeta, ok := s.meta.FieldMap[gb] 185 if !ok { 186 return errs.NewInvalidFieldError(gb) 187 } 188 if i > 0 { 189 s.comma() 190 } 191 s.quote(cMeta.ColumnName) 192 } 193 return nil 194 } 195 196 func (s *Selector[T]) buildAllColumns() error { 197 for i, cMeta := range s.meta.Columns { 198 // 永远不会返回 error 199 _ = s.buildColumns(i, cMeta.FieldName) 200 } 201 return nil 202 } 203 204 // buildSelectedList users specify columns 205 func (s *Selector[T]) buildSelectedList() error { 206 for i, selectable := range s.columns { 207 if i > 0 { 208 s.comma() 209 } 210 switch expr := selectable.(type) { 211 case Column: 212 err := s.builder.buildColumn(expr) 213 if err != nil { 214 return errs.NewInvalidFieldError(expr.name) 215 } 216 case columns: 217 for j, c := range expr.cs { 218 err := s.buildColumns(j, c) 219 if err != nil { 220 return err 221 } 222 } 223 case Aggregate: 224 if err := s.selectAggregate(expr); err != nil { 225 return err 226 } 227 case RawExpr: 228 s.buildRawExpr(expr) 229 } 230 } 231 return nil 232 233 } 234 func (s *Selector[T]) selectAggregate(aggregate Aggregate) error { 235 s.writeString(aggregate.fn) 236 237 s.writeByte('(') 238 if aggregate.distinct { 239 s.writeString("DISTINCT ") 240 } 241 cMeta, ok := s.meta.FieldMap[aggregate.arg] 242 // s.aliases[aggregate.alias] = struct{}{} 243 if !ok { 244 return errs.NewInvalidFieldError(aggregate.arg) 245 } 246 if aggregate.table != nil { 247 if alias := aggregate.table.getAlias(); alias != "" { 248 s.quote(alias) 249 s.point() 250 } 251 } 252 s.quote(cMeta.ColumnName) 253 s.writeByte(')') 254 if aggregate.alias != "" { 255 // if _, ok := s.aliases[aggregate.alias]; ok { 256 // s.writeString(" AS ") 257 // s.quote(aggregate.alias) 258 // } 259 s.writeString(" AS ") 260 s.quote(aggregate.alias) 261 } 262 return nil 263 } 264 265 func (s *Selector[T]) buildColumns(index int, name string) error { 266 if index > 0 { 267 s.comma() 268 } 269 cMeta, ok := s.meta.FieldMap[name] 270 if !ok { 271 return errs.NewInvalidFieldError(name) 272 } 273 s.quote(cMeta.ColumnName) 274 return nil 275 } 276 277 func (s *Selector[T]) buildUsing(using []string) error { 278 s.writeString(" USING (") 279 for i, col := range using { 280 err := s.buildColumns(i, col) 281 if err != nil { 282 return err 283 } 284 } 285 s.writeByte(')') 286 return nil 287 } 288 289 // Select 指定查询的列。 290 // 列可以是物理列,也可以是聚合函数,或者 RawExpr 291 func (s *Selector[T]) Select(columns ...Selectable) *Selector[T] { 292 s.columns = columns 293 return s 294 } 295 296 // From specifies the table which must be pointer of structure 297 func (s *Selector[T]) From(tbl TableReference) *Selector[T] { 298 s.table = tbl 299 return s 300 } 301 302 // Where accepts predicates 303 func (s *Selector[T]) Where(predicates ...Predicate) *Selector[T] { 304 s.where = predicates 305 return s 306 } 307 308 // Distinct indicates using keyword DISTINCT 309 func (s *Selector[T]) Distinct() *Selector[T] { 310 s.distinct = true 311 return s 312 } 313 314 // Having accepts predicates 315 func (s *Selector[T]) Having(predicates ...Predicate) *Selector[T] { 316 s.having = predicates 317 return s 318 } 319 320 // GroupBy means "GROUP BY" 321 func (s *Selector[T]) GroupBy(columns ...string) *Selector[T] { 322 s.groupBy = columns 323 return s 324 } 325 326 // OrderBy means "ORDER BY" 327 func (s *Selector[T]) OrderBy(orderBys ...OrderBy) *Selector[T] { 328 s.orderBy = orderBys 329 return s 330 } 331 332 // Limit limits the size of result set 333 func (s *Selector[T]) Limit(limit int) *Selector[T] { 334 s.limit = limit 335 return s 336 } 337 338 // Offset was used by "LIMIT" 339 func (s *Selector[T]) Offset(offset int) *Selector[T] { 340 s.offset = offset 341 return s 342 } 343 344 func (s *Selector[T]) AsSubquery(alias string) Subquery { 345 var table TableReference 346 if s.table == nil { 347 table = TableOf(new(T), alias) 348 } 349 return Subquery{ 350 entity: table, 351 q: s, 352 alias: alias, 353 columns: s.columns, 354 } 355 } 356 357 // Get 方法会执行查询,并且返回一条数据 358 // 注意,在不同的数据库情况下,第一条数据可能是按照不同的列来排序的 359 // 而且要注意,这个方法会强制设置 Limit 1 360 // 在没有查找到数据的情况下,会返回 ErrNoRows 361 func (s *Selector[T]) Get(ctx context.Context) (*T, error) { 362 query, err := s.Limit(1).Build() 363 if err != nil { 364 return nil, err 365 } 366 return newQuerier[T](s.Session, query, s.meta, SELECT).Get(ctx) 367 } 368 369 // OrderBy specify fields and ASC 370 type OrderBy struct { 371 fields []string 372 order string 373 } 374 375 // ASC means ORDER BY fields ASC 376 func ASC(fields ...string) OrderBy { 377 return OrderBy{ 378 fields: fields, 379 order: "ASC", 380 } 381 } 382 383 // DESC means ORDER BY fields DESC 384 func DESC(fields ...string) OrderBy { 385 return OrderBy{ 386 fields: fields, 387 order: "DESC", 388 } 389 } 390 391 // Selectable is a tag interface which represents SELECT XXX 392 type Selectable interface { 393 selected() 394 } 395 396 func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) { 397 query, err := s.Build() 398 if err != nil { 399 return nil, err 400 } 401 return newQuerier[T](s.Session, query, s.meta, SELECT).GetMulti(ctx) 402 } 403 404 func (s *Selector[T]) buildJoin(t Join) error { 405 s.writeByte('(') 406 if err := s.buildTable(t.left); err != nil { 407 return err 408 } 409 s.space() 410 s.writeString(t.typ) 411 s.space() 412 if err := s.buildTable(t.right); err != nil { 413 return err 414 } 415 if len(t.using) > 0 { 416 if err := s.buildUsing(t.using); err != nil { 417 return err 418 } 419 } 420 if len(t.on) > 0 { 421 s.writeString(" ON ") 422 if err := s.buildPredicates(t.on); err != nil { 423 return err 424 } 425 } 426 s.writeByte(')') 427 return nil 428 }