github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_select.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "sync" 20 21 "github.com/ecodeclub/eorm/internal/merger/batchmerger" 22 23 "github.com/ecodeclub/eorm/internal/sharding" 24 25 "github.com/ecodeclub/eorm/internal/errs" 26 "github.com/valyala/bytebufferpool" 27 ) 28 29 type ShardingSelector[T any] struct { 30 shardingSelectorBuilder 31 table *T 32 db Session 33 lock sync.Mutex 34 } 35 36 func NewShardingSelector[T any](db Session) *ShardingSelector[T] { 37 b := shardingSelectorBuilder{} 38 b.core = db.getCore() 39 b.buffer = bytebufferpool.Get() 40 return &ShardingSelector[T]{ 41 shardingSelectorBuilder: b, 42 db: db, 43 } 44 } 45 46 func (s *ShardingSelector[T]) Build(ctx context.Context) ([]sharding.Query, error) { 47 var err error 48 if s.meta == nil { 49 s.meta, err = s.metaRegistry.Get(new(T)) 50 if err != nil { 51 return nil, err 52 } 53 } 54 shardingRes, err := s.findDst(ctx, s.where...) 55 if err != nil { 56 return nil, err 57 } 58 res := make([]sharding.Query, 0, len(shardingRes.Dsts)) 59 defer bytebufferpool.Put(s.buffer) 60 for _, dst := range shardingRes.Dsts { 61 q, err := s.buildQuery(dst.DB, dst.Table, dst.Name) 62 if err != nil { 63 return nil, err 64 } 65 res = append(res, q) 66 s.args = nil 67 s.buffer.Reset() 68 } 69 return res, nil 70 } 71 72 func (s *ShardingSelector[T]) buildQuery(db, tbl, ds string) (sharding.Query, error) { 73 var err error 74 s.writeString("SELECT ") 75 if len(s.columns) == 0 { 76 if err = s.buildAllColumns(); err != nil { 77 return sharding.EmptyQuery, err 78 } 79 } else { 80 err = s.buildSelectedList() 81 if err != nil { 82 return sharding.EmptyQuery, err 83 } 84 } 85 s.writeString(" FROM ") 86 s.quote(db) 87 s.writeByte('.') 88 s.quote(tbl) 89 90 if len(s.where) > 0 { 91 s.writeString(" WHERE ") 92 p := s.where[0] 93 for i := 1; i < len(s.where); i++ { 94 p = p.And(s.where[i]) 95 } 96 if err = s.buildExpr(p); err != nil { 97 return sharding.EmptyQuery, err 98 } 99 } 100 101 // group by 102 if len(s.groupBy) > 0 { 103 err = s.buildGroupBy() 104 if err != nil { 105 return sharding.EmptyQuery, err 106 } 107 } 108 109 // order by 110 if len(s.orderBy) > 0 { 111 err = s.buildOrderBy() 112 if err != nil { 113 return sharding.EmptyQuery, err 114 } 115 } 116 117 // having 118 if len(s.having) > 0 { 119 s.writeString(" HAVING ") 120 p := s.having[0] 121 for i := 1; i < len(s.having); i++ { 122 p = p.And(s.having[i]) 123 } 124 if err = s.buildExpr(p); err != nil { 125 return sharding.EmptyQuery, err 126 } 127 } 128 129 if s.offset > 0 { 130 s.writeString(" OFFSET ") 131 s.parameter(s.offset) 132 } 133 134 if s.limit > 0 { 135 s.writeString(" LIMIT ") 136 s.parameter(s.limit) 137 } 138 s.end() 139 return sharding.Query{SQL: s.buffer.String(), Args: s.args, Datasource: ds, DB: db}, nil 140 } 141 142 func (s *ShardingSelector[T]) buildAllColumns() error { 143 for i, cMeta := range s.meta.Columns { 144 _ = s.buildColumns(i, cMeta.FieldName) 145 } 146 return nil 147 } 148 149 func (s *ShardingSelector[T]) buildSelectedList() error { 150 for i, selectable := range s.columns { 151 if i > 0 { 152 s.comma() 153 } 154 switch expr := selectable.(type) { 155 case Column: 156 err := s.builder.buildColumn(expr) 157 if err != nil { 158 return errs.NewInvalidFieldError(expr.name) 159 } 160 case columns: 161 for j, c := range expr.cs { 162 err := s.buildColumns(j, c) 163 if err != nil { 164 return err 165 } 166 } 167 case Aggregate: 168 if err := s.selectAggregate(expr); err != nil { 169 return err 170 } 171 case RawExpr: 172 s.buildRawExpr(expr) 173 } 174 } 175 return nil 176 177 } 178 func (s *ShardingSelector[T]) selectAggregate(aggregate Aggregate) error { 179 s.writeString(aggregate.fn) 180 181 s.writeByte('(') 182 if aggregate.distinct { 183 s.writeString("DISTINCT ") 184 } 185 cMeta, ok := s.meta.FieldMap[aggregate.arg] 186 if !ok { 187 return errs.NewInvalidFieldError(aggregate.arg) 188 } 189 if aggregate.table != nil { 190 if alias := aggregate.table.getAlias(); alias != "" { 191 s.quote(alias) 192 s.point() 193 } 194 } 195 s.quote(cMeta.ColumnName) 196 s.writeByte(')') 197 if aggregate.alias != "" { 198 s.writeString(" AS ") 199 s.quote(aggregate.alias) 200 } 201 return nil 202 } 203 204 func (s *ShardingSelector[T]) buildColumns(index int, name string) error { 205 if index > 0 { 206 s.comma() 207 } 208 cMeta, ok := s.meta.FieldMap[name] 209 if !ok { 210 return errs.NewInvalidFieldError(name) 211 } 212 s.quote(cMeta.ColumnName) 213 return nil 214 } 215 216 func (s *ShardingSelector[T]) buildExpr(expr Expr) error { 217 switch exp := expr.(type) { 218 case nil: 219 case Column: 220 exp.alias = "" 221 _ = s.buildColumn(exp) 222 case valueExpr: 223 s.parameter(exp.val) 224 case RawExpr: 225 s.buildRawExpr(exp) 226 case Predicate: 227 if err := s.buildBinaryExpr(binaryExpr(exp)); err != nil { 228 return err 229 } 230 default: 231 return errs.NewErrUnsupportedExpressionType() 232 } 233 return nil 234 } 235 236 func (s *ShardingSelector[T]) buildOrderBy() error { 237 s.writeString(" ORDER BY ") 238 for i, ob := range s.orderBy { 239 if i > 0 { 240 s.comma() 241 } 242 for _, c := range ob.fields { 243 cMeta, ok := s.meta.FieldMap[c] 244 if !ok { 245 return errs.NewInvalidFieldError(c) 246 } 247 s.quote(cMeta.ColumnName) 248 } 249 s.space() 250 s.writeString(ob.order) 251 } 252 return nil 253 } 254 255 func (s *ShardingSelector[T]) buildGroupBy() error { 256 s.writeString(" GROUP BY ") 257 for i, gb := range s.groupBy { 258 cMeta, ok := s.meta.FieldMap[gb] 259 if !ok { 260 return errs.NewInvalidFieldError(gb) 261 } 262 if i > 0 { 263 s.comma() 264 } 265 s.quote(cMeta.ColumnName) 266 } 267 return nil 268 } 269 270 func (s *ShardingSelector[T]) Get(ctx context.Context) (*T, error) { 271 qs, err := s.Limit(1).Build(ctx) 272 if err != nil { 273 return nil, err 274 } 275 if len(qs) == 0 { 276 return nil, errs.ErrNotGenShardingQuery 277 } 278 // TODO 要确保前面的改写 SQL 只能生成一个 SQL 279 if len(qs) > 1 { 280 return nil, errs.ErrOnlyResultOneQuery 281 } 282 q := qs[0] 283 // TODO 利用 ctx 传递 DB name 284 row, err := s.db.queryContext(ctx, q) 285 if err != nil { 286 return nil, err 287 } 288 if !row.Next() { 289 return nil, ErrNoRows 290 } 291 tp := new(T) 292 val := s.valCreator.NewPrimitiveValue(tp, s.meta) 293 if err = val.SetColumns(row); err != nil { 294 return nil, err 295 } 296 return tp, nil 297 } 298 299 func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { 300 qs, err := s.Build(ctx) 301 if err != nil { 302 return nil, err 303 } 304 305 mgr := batchmerger.NewMerger() 306 rowsList, err := s.db.queryMulti(ctx, qs) 307 if err != nil { 308 return nil, err 309 } 310 rows, err := mgr.Merge(ctx, rowsList.AsSlice()) 311 if err != nil { 312 return nil, err 313 } 314 defer rows.Close() 315 var res []*T 316 for rows.Next() { 317 tp := new(T) 318 val := s.valCreator.NewPrimitiveValue(tp, s.meta) 319 if err = val.SetColumns(rows); err != nil { 320 return nil, err 321 } 322 res = append(res, tp) 323 } 324 return res, nil 325 } 326 327 // Select 指定查询的列。 328 // 列可以是物理列,也可以是聚合函数,或者 RawExpr 329 func (s *ShardingSelector[T]) Select(columns ...Selectable) *ShardingSelector[T] { 330 s.columns = columns 331 return s 332 } 333 334 // From specifies the table which must be pointer of structure 335 func (s *ShardingSelector[T]) From(tbl *T) *ShardingSelector[T] { 336 s.table = tbl 337 return s 338 } 339 340 // Where accepts predicates 341 func (s *ShardingSelector[T]) Where(predicates ...Predicate) *ShardingSelector[T] { 342 s.where = predicates 343 return s 344 } 345 346 // Having accepts predicates 347 func (s *ShardingSelector[T]) Having(predicates ...Predicate) *ShardingSelector[T] { 348 s.having = predicates 349 return s 350 } 351 352 // GroupBy means "GROUP BY" 353 func (s *ShardingSelector[T]) GroupBy(columns ...string) *ShardingSelector[T] { 354 s.groupBy = columns 355 return s 356 } 357 358 // OrderBy means "ORDER BY" 359 func (s *ShardingSelector[T]) OrderBy(orderBys ...OrderBy) *ShardingSelector[T] { 360 s.orderBy = orderBys 361 return s 362 } 363 364 // Limit limits the size of result set 365 func (s *ShardingSelector[T]) Limit(limit int) *ShardingSelector[T] { 366 s.limit = limit 367 return s 368 } 369 370 // Offset was used by "LIMIT" 371 func (s *ShardingSelector[T]) Offset(offset int) *ShardingSelector[T] { 372 s.offset = offset 373 return s 374 }