github.com/goplus/yap@v0.8.1/ydb/query.go (about) 1 /* 2 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ydb 18 19 import ( 20 "context" 21 "database/sql" 22 "log" 23 "reflect" 24 "strconv" 25 "strings" 26 27 "github.com/goplus/yap/reflectutil" 28 ) 29 30 // ----------------------------------------------------------------------------- 31 32 // Query creates a new query. 33 // - query <cond>, <arg1>, <arg2>, ... 34 func (p *Class) Query(cond string, args ...any) { 35 p.query = &query{ 36 cond: cond, args: args, 37 } 38 p.lastErr = nil 39 p.ret = p.queryRet 40 } 41 42 // NoRows checkes there are query result rows or not. 43 func (p *Class) NoRows() bool { 44 return p.lastErr == ErrNoRows 45 } 46 47 // LastErr returns last error. 48 func (p *Class) LastErr() error { 49 return p.lastErr 50 } 51 52 // ----------------------------------------------------------------------------- 53 54 type query struct { 55 cond string // where 56 args []any // one of query argument <argN> can be a slice 57 limit int // 0 means no limit 58 } 59 60 func (q *query) makeSelectExpr(tbl string, exprs []string) string { 61 query := make([]byte, 0, 128) 62 query = append(query, "SELECT "...) 63 query = append(query, strings.Join(exprs, ",")...) 64 query = append(query, " FROM "...) 65 query = append(query, tbl...) 66 query = append(query, " WHERE "...) 67 query = append(query, q.cond...) 68 if q.limit > 0 { 69 query = append(query, " LIMIT "...) 70 query = append(query, strconv.Itoa(q.limit)...) 71 } 72 return string(query) 73 } 74 75 // For checking query result: 76 // - ret <expr1>, &<var1>, <expr2>, &<var2>, ... 77 // - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ... 78 // - ret &<structVar> 79 // - ret &<structSlice> 80 func (p *Class) queryRet(args ...any) (err error) { 81 nArg := len(args) 82 if nArg == 1 { 83 err = p.queryRetPtr(args[0]) 84 } else { 85 err = p.queryRetKvPair(args...) 86 } 87 p.query = nil 88 p.ret = nil 89 return 90 } 91 92 // For checking query result: 93 // - ret &<structVar> 94 // - ret &<structOrPtrSlice> 95 func (p *Class) queryRetPtr(ret any) error { 96 vRet := reflect.ValueOf(ret) 97 if vRet.Kind() != reflect.Pointer { 98 log.Panicln("usage: ret &<structVar>") 99 } 100 101 switch vRet = vRet.Elem(); vRet.Kind() { 102 case reflect.Slice: 103 return p.queryStrucRows(vRet) 104 default: 105 return p.queryStrucRow(vRet) 106 } 107 } 108 109 // For checking query result: 110 // - ret &<structVar> 111 func (p *Class) queryStrucRow(vRet reflect.Value) error { 112 if vRet.Kind() != reflect.Struct { 113 log.Panicln("usage: ret &<structVar>") 114 } 115 116 n := vRet.NumField() 117 names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vRet.Type(), 0) 118 rets := getVals(make([]any, 0, len(cols)), vRet, cols, false) 119 120 q := p.query 121 query := q.makeSelectExpr(p.tbl, names) 122 return p.queryVals(context.TODO(), query, q.args, rets) 123 } 124 125 func (p *Class) queryStrucOne( 126 ctx context.Context, query string, args []any, 127 vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error { 128 vRet := reflect.New(elem).Elem() 129 rets := getVals(make([]any, 0, len(cols)), vRet, cols, false) 130 err := p.queryVals(ctx, query, args, rets) 131 if err != nil { 132 return err 133 } 134 if hasPtr { 135 vRet = vRet.Addr() 136 } 137 vSlice.Set(reflect.Append(vSlice, vRet)) 138 return nil 139 } 140 141 func (p *Class) queryStrucMulti( 142 ctx context.Context, query string, args []any, iArgSlice int, 143 vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error { 144 argSlice := args[iArgSlice] 145 defer func() { 146 args[iArgSlice] = argSlice 147 }() 148 vArgSlice := reflect.ValueOf(argSlice) 149 for i, n := 0, vArgSlice.Len(); i < n; i++ { 150 arg := vArgSlice.Index(i).Interface() 151 args[iArgSlice] = arg 152 if err := p.queryStrucOne(ctx, query, args, vSlice, elem, cols, hasPtr); err != nil { 153 return err 154 } 155 } 156 return nil 157 } 158 159 // For checking query result: 160 // - ret &<structOrPtrSlice> 161 func (p *Class) queryStrucRows(vSlice reflect.Value) error { 162 hasPtr := false 163 elem := vSlice.Type().Elem() 164 kind := elem.Kind() 165 if kind == reflect.Pointer { 166 elem, hasPtr = elem.Elem(), true 167 kind = elem.Kind() 168 } 169 if kind != reflect.Struct { 170 log.Panicln("usage: ret &<structOrPtrSlice>") 171 } 172 173 n := elem.NumField() 174 names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0) 175 176 q := p.query 177 query := q.makeSelectExpr(p.tbl, names) 178 179 args := q.args 180 iArgSlice := checkArgSlice(args) 181 if iArgSlice >= 0 { 182 return p.queryStrucMulti(context.TODO(), query, args, iArgSlice, vSlice, elem, cols, hasPtr) 183 } 184 return p.queryStrucOne(context.TODO(), query, args, vSlice, elem, cols, hasPtr) 185 } 186 187 // queryVals NOTE: 188 // - one of args maybe is a slice 189 func (p *Class) queryVals(ctx context.Context, query string, args, rets []any) error { 190 iArgSlice := checkArgSlice(args) 191 if iArgSlice >= 0 { 192 log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not") 193 } 194 195 if debugExec { 196 log.Println("==>", query, args) 197 } 198 rows, err := p.db.QueryContext(ctx, query, args...) 199 p.lastErr = err 200 if err != nil { 201 p.handleErr("query:", err) 202 return err 203 } 204 defer rows.Close() 205 206 return p.queryRetRow(rows, rets) 207 } 208 209 func (p *Class) queryRetRow(rows *sql.Rows, rets []any) error { 210 if !rows.Next() { 211 err := rows.Err() 212 if err == nil { 213 err = ErrNoRows 214 } 215 p.lastErr = err 216 if err != ErrNoRows { 217 p.handleErr("ret:", err) 218 } 219 return err 220 } 221 err := rows.Scan(rets...) 222 p.lastErr = err 223 if err != nil { 224 p.handleErr("ret:", err) 225 } 226 return err 227 } 228 229 func (p *Class) queryRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bool) error { 230 for rows.Next() { 231 if needInit { 232 for _, ret := range oneRet { 233 reflectutil.SetZero(reflect.ValueOf(ret).Elem()) 234 } 235 } else { 236 needInit = true 237 } 238 err := rows.Scan(oneRet...) 239 p.lastErr = err 240 if err != nil { 241 p.handleErr("ret:", err) 242 return err 243 } 244 for i, vRet := range vRets { 245 v := reflect.ValueOf(oneRet[i]) 246 vRet.Set(reflect.Append(vRet, v.Elem())) 247 } 248 } 249 err := rows.Err() 250 p.lastErr = err 251 if err != nil { 252 p.handleErr("ret:", err) 253 } 254 return err 255 } 256 257 // queryRows NOTE: 258 // - one of args maybe is a slice 259 func (p *Class) queryRows(ctx context.Context, query string, args, rets []any) error { 260 iArgSlice := checkArgSlice(args) 261 if iArgSlice >= 0 { 262 return p.queryMulti(ctx, query, iArgSlice, args, rets) 263 } 264 265 if debugExec { 266 log.Println("==>", query, args) 267 } 268 rows, err := p.db.QueryContext(ctx, query, args...) 269 p.lastErr = err 270 if err != nil { 271 p.handleErr("query:", err) 272 return err 273 } 274 defer rows.Close() 275 276 vRets, oneRet := makeSliceRets(rets) 277 return p.queryRetRows(rows, vRets, oneRet, false) 278 } 279 280 func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { 281 vRets = make([]reflect.Value, len(rets)) 282 oneRet = make([]any, len(rets)) 283 for i, ret := range rets { 284 slice := reflect.ValueOf(ret).Elem() 285 vRets[i] = slice 286 287 elem := slice.Type().Elem() 288 oneRet[i] = reflect.New(elem).Interface() 289 } 290 return 291 } 292 293 func (p *Class) queryMultiOne(ctx context.Context, query string, args, oneRet []any, vRets []reflect.Value) error { 294 if debugExec { 295 log.Println("==>", query, args) 296 } 297 rows, err := p.db.QueryContext(ctx, query, args...) 298 p.lastErr = err 299 if err != nil { 300 p.handleErr("query:", err) 301 return err 302 } 303 defer rows.Close() 304 305 return p.queryRetRows(rows, vRets, oneRet, true) 306 } 307 308 func (p *Class) queryMulti(ctx context.Context, query string, iArgSlice int, args, rets []any) error { 309 argSlice := args[iArgSlice] 310 defer func() { 311 args[iArgSlice] = argSlice 312 }() 313 vRets, oneRet := makeSliceRets(rets) 314 vArgSlice := reflect.ValueOf(argSlice) 315 for i, n := 0, vArgSlice.Len(); i < n; i++ { 316 arg := vArgSlice.Index(i).Interface() 317 args[iArgSlice] = arg 318 if err := p.queryMultiOne(ctx, query, args, oneRet, vRets); err != nil { 319 return err 320 } 321 } 322 return nil 323 } 324 325 // For checking query result: 326 // - ret <expr1>, &<var1>, <expr2>, &<var2>, ... 327 // - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ... 328 func (p *Class) queryRetKvPair(kvPair ...any) error { 329 nPair := len(kvPair) 330 if nPair < 2 || nPair&1 != 0 { 331 log.Panicln("usage: ret <expr1>, &<var1>, <expr2>, &<var2>, ...") 332 } 333 334 q := p.query 335 tbl := p.exprTblname(q.cond) 336 337 n := nPair >> 1 338 exprs := make([]string, n) 339 rets := make([]any, n) 340 kind := 0 341 for i := 0; i < nPair; i += 2 { 342 expr := kvPair[i].(string) 343 if etbl := p.exprTblname(expr); etbl != tbl { 344 log.Panicf( 345 "query currently doesn't support multiple tables: `query` use `%s` but `ret` use `%s`\n", 346 tbl, etbl, 347 ) 348 } 349 ret := kvPair[i+1] 350 kind |= retKind(ret) 351 exprs[i>>1] = expr 352 rets[i>>1] = ret 353 } 354 if kind == valFlagInvalid { 355 log.Panicln(`all ret arguments should be address of slices or address of normal variable: 356 ret <expr1>, &<var1>, <expr2>, &<var2>, ... 357 ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ...`) 358 } 359 360 query := q.makeSelectExpr(tbl, exprs) 361 if kind == valFlagNormal { 362 return p.queryVals(context.TODO(), query, q.args, rets) 363 } 364 return p.queryRows(context.TODO(), query, q.args, rets) 365 } 366 367 func retKind(ret any) int { 368 v := reflect.ValueOf(ret) 369 if v.Kind() != reflect.Pointer { 370 log.Panicln("usage: ret <expr1>, &<var1>, <expr2>, &<var2>, ...") 371 } 372 if v.Elem().Kind() == reflect.Slice { 373 return valFlagSlice 374 } 375 return valFlagNormal 376 } 377 378 // ----------------------------------------------------------------------------- 379 380 // Limit sets query result rows limit. 381 func (p *Class) Limit__0(n int) { 382 if p.query == nil { 383 log.Panicln("please call `limit` after a query statement") 384 } 385 p.query.limit = n 386 } 387 388 // Limit checks if query result rows is < n or not. 389 func (p *Class) Limit__1(n int, cond string, args ...any) error { 390 ret, err := p.Count(cond, args...) 391 if err != nil { 392 return err 393 } 394 if ret >= n { 395 if p.onErr == nil { 396 log.Panicf("limit %s: got %d, expected <%d\n", cond, ret, n) 397 } 398 err = ErrOutOfLimit 399 p.onErr(err) 400 } 401 return err 402 } 403 404 // ----------------------------------------------------------------------------- 405 406 // Count returns rows of a query result. 407 func (p *Class) Count(cond string, args ...any) (n int, err error) { 408 if p.tbl == "" { 409 log.Panicln("please call `use <tableName>` to specified a table name") 410 } 411 row := p.db.QueryRowContext(context.TODO(), "SELECT COUNT(*) FROM "+p.tbl+" WHERE "+cond, args...) 412 if err = row.Scan(&n); err != nil { 413 p.handleErr("query:", err) 414 } 415 return 416 } 417 418 // -----------------------------------------------------------------------------