github.com/goplus/yap@v0.8.1/ydb/query.go (about)

     1  /*
     2   * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ydb
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"log"
    23  	"reflect"
    24  	"strconv"
    25  	"strings"
    26  
    27  	"github.com/goplus/yap/reflectutil"
    28  )
    29  
    30  // -----------------------------------------------------------------------------
    31  
    32  // Query creates a new query.
    33  //   - query <cond>, <arg1>, <arg2>, ...
    34  func (p *Class) Query(cond string, args ...any) {
    35  	p.query = &query{
    36  		cond: cond, args: args,
    37  	}
    38  	p.lastErr = nil
    39  	p.ret = p.queryRet
    40  }
    41  
    42  // NoRows checkes there are query result rows or not.
    43  func (p *Class) NoRows() bool {
    44  	return p.lastErr == ErrNoRows
    45  }
    46  
    47  // LastErr returns last error.
    48  func (p *Class) LastErr() error {
    49  	return p.lastErr
    50  }
    51  
    52  // -----------------------------------------------------------------------------
    53  
    54  type query struct {
    55  	cond  string // where
    56  	args  []any  // one of query argument <argN> can be a slice
    57  	limit int    // 0 means no limit
    58  }
    59  
    60  func (q *query) makeSelectExpr(tbl string, exprs []string) string {
    61  	query := make([]byte, 0, 128)
    62  	query = append(query, "SELECT "...)
    63  	query = append(query, strings.Join(exprs, ",")...)
    64  	query = append(query, " FROM "...)
    65  	query = append(query, tbl...)
    66  	query = append(query, " WHERE "...)
    67  	query = append(query, q.cond...)
    68  	if q.limit > 0 {
    69  		query = append(query, " LIMIT "...)
    70  		query = append(query, strconv.Itoa(q.limit)...)
    71  	}
    72  	return string(query)
    73  }
    74  
    75  // For checking query result:
    76  //   - ret <expr1>, &<var1>, <expr2>, &<var2>, ...
    77  //   - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ...
    78  //   - ret &<structVar>
    79  //   - ret &<structSlice>
    80  func (p *Class) queryRet(args ...any) (err error) {
    81  	nArg := len(args)
    82  	if nArg == 1 {
    83  		err = p.queryRetPtr(args[0])
    84  	} else {
    85  		err = p.queryRetKvPair(args...)
    86  	}
    87  	p.query = nil
    88  	p.ret = nil
    89  	return
    90  }
    91  
    92  // For checking query result:
    93  //   - ret &<structVar>
    94  //   - ret &<structOrPtrSlice>
    95  func (p *Class) queryRetPtr(ret any) error {
    96  	vRet := reflect.ValueOf(ret)
    97  	if vRet.Kind() != reflect.Pointer {
    98  		log.Panicln("usage: ret &<structVar>")
    99  	}
   100  
   101  	switch vRet = vRet.Elem(); vRet.Kind() {
   102  	case reflect.Slice:
   103  		return p.queryStrucRows(vRet)
   104  	default:
   105  		return p.queryStrucRow(vRet)
   106  	}
   107  }
   108  
   109  // For checking query result:
   110  //   - ret &<structVar>
   111  func (p *Class) queryStrucRow(vRet reflect.Value) error {
   112  	if vRet.Kind() != reflect.Struct {
   113  		log.Panicln("usage: ret &<structVar>")
   114  	}
   115  
   116  	n := vRet.NumField()
   117  	names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vRet.Type(), 0)
   118  	rets := getVals(make([]any, 0, len(cols)), vRet, cols, false)
   119  
   120  	q := p.query
   121  	query := q.makeSelectExpr(p.tbl, names)
   122  	return p.queryVals(context.TODO(), query, q.args, rets)
   123  }
   124  
   125  func (p *Class) queryStrucOne(
   126  	ctx context.Context, query string, args []any,
   127  	vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error {
   128  	vRet := reflect.New(elem).Elem()
   129  	rets := getVals(make([]any, 0, len(cols)), vRet, cols, false)
   130  	err := p.queryVals(ctx, query, args, rets)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	if hasPtr {
   135  		vRet = vRet.Addr()
   136  	}
   137  	vSlice.Set(reflect.Append(vSlice, vRet))
   138  	return nil
   139  }
   140  
   141  func (p *Class) queryStrucMulti(
   142  	ctx context.Context, query string, args []any, iArgSlice int,
   143  	vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error {
   144  	argSlice := args[iArgSlice]
   145  	defer func() {
   146  		args[iArgSlice] = argSlice
   147  	}()
   148  	vArgSlice := reflect.ValueOf(argSlice)
   149  	for i, n := 0, vArgSlice.Len(); i < n; i++ {
   150  		arg := vArgSlice.Index(i).Interface()
   151  		args[iArgSlice] = arg
   152  		if err := p.queryStrucOne(ctx, query, args, vSlice, elem, cols, hasPtr); err != nil {
   153  			return err
   154  		}
   155  	}
   156  	return nil
   157  }
   158  
   159  // For checking query result:
   160  //   - ret &<structOrPtrSlice>
   161  func (p *Class) queryStrucRows(vSlice reflect.Value) error {
   162  	hasPtr := false
   163  	elem := vSlice.Type().Elem()
   164  	kind := elem.Kind()
   165  	if kind == reflect.Pointer {
   166  		elem, hasPtr = elem.Elem(), true
   167  		kind = elem.Kind()
   168  	}
   169  	if kind != reflect.Struct {
   170  		log.Panicln("usage: ret &<structOrPtrSlice>")
   171  	}
   172  
   173  	n := elem.NumField()
   174  	names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0)
   175  
   176  	q := p.query
   177  	query := q.makeSelectExpr(p.tbl, names)
   178  
   179  	args := q.args
   180  	iArgSlice := checkArgSlice(args)
   181  	if iArgSlice >= 0 {
   182  		return p.queryStrucMulti(context.TODO(), query, args, iArgSlice, vSlice, elem, cols, hasPtr)
   183  	}
   184  	return p.queryStrucOne(context.TODO(), query, args, vSlice, elem, cols, hasPtr)
   185  }
   186  
   187  // queryVals NOTE:
   188  //   - one of args maybe is a slice
   189  func (p *Class) queryVals(ctx context.Context, query string, args, rets []any) error {
   190  	iArgSlice := checkArgSlice(args)
   191  	if iArgSlice >= 0 {
   192  		log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not")
   193  	}
   194  
   195  	if debugExec {
   196  		log.Println("==>", query, args)
   197  	}
   198  	rows, err := p.db.QueryContext(ctx, query, args...)
   199  	p.lastErr = err
   200  	if err != nil {
   201  		p.handleErr("query:", err)
   202  		return err
   203  	}
   204  	defer rows.Close()
   205  
   206  	return p.queryRetRow(rows, rets)
   207  }
   208  
   209  func (p *Class) queryRetRow(rows *sql.Rows, rets []any) error {
   210  	if !rows.Next() {
   211  		err := rows.Err()
   212  		if err == nil {
   213  			err = ErrNoRows
   214  		}
   215  		p.lastErr = err
   216  		if err != ErrNoRows {
   217  			p.handleErr("ret:", err)
   218  		}
   219  		return err
   220  	}
   221  	err := rows.Scan(rets...)
   222  	p.lastErr = err
   223  	if err != nil {
   224  		p.handleErr("ret:", err)
   225  	}
   226  	return err
   227  }
   228  
   229  func (p *Class) queryRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bool) error {
   230  	for rows.Next() {
   231  		if needInit {
   232  			for _, ret := range oneRet {
   233  				reflectutil.SetZero(reflect.ValueOf(ret).Elem())
   234  			}
   235  		} else {
   236  			needInit = true
   237  		}
   238  		err := rows.Scan(oneRet...)
   239  		p.lastErr = err
   240  		if err != nil {
   241  			p.handleErr("ret:", err)
   242  			return err
   243  		}
   244  		for i, vRet := range vRets {
   245  			v := reflect.ValueOf(oneRet[i])
   246  			vRet.Set(reflect.Append(vRet, v.Elem()))
   247  		}
   248  	}
   249  	err := rows.Err()
   250  	p.lastErr = err
   251  	if err != nil {
   252  		p.handleErr("ret:", err)
   253  	}
   254  	return err
   255  }
   256  
   257  // queryRows NOTE:
   258  //   - one of args maybe is a slice
   259  func (p *Class) queryRows(ctx context.Context, query string, args, rets []any) error {
   260  	iArgSlice := checkArgSlice(args)
   261  	if iArgSlice >= 0 {
   262  		return p.queryMulti(ctx, query, iArgSlice, args, rets)
   263  	}
   264  
   265  	if debugExec {
   266  		log.Println("==>", query, args)
   267  	}
   268  	rows, err := p.db.QueryContext(ctx, query, args...)
   269  	p.lastErr = err
   270  	if err != nil {
   271  		p.handleErr("query:", err)
   272  		return err
   273  	}
   274  	defer rows.Close()
   275  
   276  	vRets, oneRet := makeSliceRets(rets)
   277  	return p.queryRetRows(rows, vRets, oneRet, false)
   278  }
   279  
   280  func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) {
   281  	vRets = make([]reflect.Value, len(rets))
   282  	oneRet = make([]any, len(rets))
   283  	for i, ret := range rets {
   284  		slice := reflect.ValueOf(ret).Elem()
   285  		vRets[i] = slice
   286  
   287  		elem := slice.Type().Elem()
   288  		oneRet[i] = reflect.New(elem).Interface()
   289  	}
   290  	return
   291  }
   292  
   293  func (p *Class) queryMultiOne(ctx context.Context, query string, args, oneRet []any, vRets []reflect.Value) error {
   294  	if debugExec {
   295  		log.Println("==>", query, args)
   296  	}
   297  	rows, err := p.db.QueryContext(ctx, query, args...)
   298  	p.lastErr = err
   299  	if err != nil {
   300  		p.handleErr("query:", err)
   301  		return err
   302  	}
   303  	defer rows.Close()
   304  
   305  	return p.queryRetRows(rows, vRets, oneRet, true)
   306  }
   307  
   308  func (p *Class) queryMulti(ctx context.Context, query string, iArgSlice int, args, rets []any) error {
   309  	argSlice := args[iArgSlice]
   310  	defer func() {
   311  		args[iArgSlice] = argSlice
   312  	}()
   313  	vRets, oneRet := makeSliceRets(rets)
   314  	vArgSlice := reflect.ValueOf(argSlice)
   315  	for i, n := 0, vArgSlice.Len(); i < n; i++ {
   316  		arg := vArgSlice.Index(i).Interface()
   317  		args[iArgSlice] = arg
   318  		if err := p.queryMultiOne(ctx, query, args, oneRet, vRets); err != nil {
   319  			return err
   320  		}
   321  	}
   322  	return nil
   323  }
   324  
   325  // For checking query result:
   326  //   - ret <expr1>, &<var1>, <expr2>, &<var2>, ...
   327  //   - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ...
   328  func (p *Class) queryRetKvPair(kvPair ...any) error {
   329  	nPair := len(kvPair)
   330  	if nPair < 2 || nPair&1 != 0 {
   331  		log.Panicln("usage: ret <expr1>, &<var1>, <expr2>, &<var2>, ...")
   332  	}
   333  
   334  	q := p.query
   335  	tbl := p.exprTblname(q.cond)
   336  
   337  	n := nPair >> 1
   338  	exprs := make([]string, n)
   339  	rets := make([]any, n)
   340  	kind := 0
   341  	for i := 0; i < nPair; i += 2 {
   342  		expr := kvPair[i].(string)
   343  		if etbl := p.exprTblname(expr); etbl != tbl {
   344  			log.Panicf(
   345  				"query currently doesn't support multiple tables: `query` use `%s` but `ret` use `%s`\n",
   346  				tbl, etbl,
   347  			)
   348  		}
   349  		ret := kvPair[i+1]
   350  		kind |= retKind(ret)
   351  		exprs[i>>1] = expr
   352  		rets[i>>1] = ret
   353  	}
   354  	if kind == valFlagInvalid {
   355  		log.Panicln(`all ret arguments should be address of slices or address of normal variable:
   356  	ret <expr1>, &<var1>, <expr2>, &<var2>, ...
   357  	ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ...`)
   358  	}
   359  
   360  	query := q.makeSelectExpr(tbl, exprs)
   361  	if kind == valFlagNormal {
   362  		return p.queryVals(context.TODO(), query, q.args, rets)
   363  	}
   364  	return p.queryRows(context.TODO(), query, q.args, rets)
   365  }
   366  
   367  func retKind(ret any) int {
   368  	v := reflect.ValueOf(ret)
   369  	if v.Kind() != reflect.Pointer {
   370  		log.Panicln("usage: ret <expr1>, &<var1>, <expr2>, &<var2>, ...")
   371  	}
   372  	if v.Elem().Kind() == reflect.Slice {
   373  		return valFlagSlice
   374  	}
   375  	return valFlagNormal
   376  }
   377  
   378  // -----------------------------------------------------------------------------
   379  
   380  // Limit sets query result rows limit.
   381  func (p *Class) Limit__0(n int) {
   382  	if p.query == nil {
   383  		log.Panicln("please call `limit` after a query statement")
   384  	}
   385  	p.query.limit = n
   386  }
   387  
   388  // Limit checks if query result rows is < n or not.
   389  func (p *Class) Limit__1(n int, cond string, args ...any) error {
   390  	ret, err := p.Count(cond, args...)
   391  	if err != nil {
   392  		return err
   393  	}
   394  	if ret >= n {
   395  		if p.onErr == nil {
   396  			log.Panicf("limit %s: got %d, expected <%d\n", cond, ret, n)
   397  		}
   398  		err = ErrOutOfLimit
   399  		p.onErr(err)
   400  	}
   401  	return err
   402  }
   403  
   404  // -----------------------------------------------------------------------------
   405  
   406  // Count returns rows of a query result.
   407  func (p *Class) Count(cond string, args ...any) (n int, err error) {
   408  	if p.tbl == "" {
   409  		log.Panicln("please call `use <tableName>` to specified a table name")
   410  	}
   411  	row := p.db.QueryRowContext(context.TODO(), "SELECT COUNT(*) FROM "+p.tbl+" WHERE "+cond, args...)
   412  	if err = row.Scan(&n); err != nil {
   413  		p.handleErr("query:", err)
   414  	}
   415  	return
   416  }
   417  
   418  // -----------------------------------------------------------------------------