github.com/goplus/yap@v0.8.1/ydb/class.go (about)

     1  /*
     2   * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ydb
    18  
    19  import (
    20  	"database/sql"
    21  	"errors"
    22  	"log"
    23  	"reflect"
    24  	"runtime/debug"
    25  
    26  	"github.com/goplus/yap/test"
    27  	"github.com/goplus/yap/test/logt"
    28  )
    29  
    30  var (
    31  	ErrNoRows     = sql.ErrNoRows
    32  	ErrDuplicated = errors.New("duplicated")
    33  	ErrOutOfLimit = errors.New("out of limit")
    34  )
    35  
    36  // -----------------------------------------------------------------------------
    37  
// Class embeds Sql and test.Case and adds the state needed to dispatch
// Gop_Exec calls to the API_xxx methods of a user-defined class instance
// and to check their results with Ret.
type Class struct {
	Sql

	self reflect.Value // the user's class instance, recorded by initClass
	tbl  string        // default table chosen by Use

	result []reflect.Value // results of the last api call (set by call, read by callRet and Out)

	ret     func(args ...any) error // checker run by Ret; set by call, reset to nil by callRet
	onErr   func(err error)         // handler installed by OnErr; nil makes handleErr panic
	lastErr error                   // NOTE(review): not referenced in this file — confirm its use elsewhere
	test.Case                       // supplies CaseT, created lazily by t() when unset

	query *query // current query state; type declared elsewhere in ydb — TODO confirm semantics
}
    53  
    54  func (p *Class) initClass(self any) {
    55  	p.initSql()
    56  	p.self = reflect.ValueOf(self)
    57  }
    58  
    59  func (p *Class) t() test.CaseT {
    60  	if p.CaseT == nil {
    61  		p.CaseT = logt.New()
    62  	}
    63  	return p.CaseT
    64  }
    65  
    66  // Use sets the default table used in following sql operations.
    67  func (p *Class) Use(table string) {
    68  	_, ok := p.tables[table]
    69  	if !ok {
    70  		log.Panicln("table not found:", table)
    71  	}
    72  	p.tbl = table
    73  }
    74  
    75  // OnErr sets error processing of a sql execution.
    76  func (p *Class) OnErr(onErr func(error)) {
    77  	p.onErr = onErr
    78  }
    79  
    80  func (p *Class) handleErr(prompt string, err error) {
    81  	err = p.wrap(prompt, err)
    82  	if p.onErr == nil {
    83  		log.Panicln(prompt, err)
    84  	}
    85  	p.onErr(err)
    86  }
    87  
    88  // Ret checks a query or call result.
    89  //
    90  // For checking query result:
    91  //   - ret <expr1>, &<var1>, <expr2>, &<var2>, ...
    92  //   - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ...
    93  //   - ret &<structVar>
    94  //   - ret &<structSlice>
    95  //
    96  // For checking call result:
    97  //   - ret <expr1>, <expr2>, ...
    98  func (p *Class) Ret(args ...any) {
    99  	if p.ret == nil {
   100  		log.Panicln("please call `ret` after a `query` or `call` statement")
   101  	}
   102  	p.ret(args...)
   103  }
   104  
   105  // -----------------------------------------------------------------------------
   106  
   107  func (p *Class) Gop_Exec(name string, args ...any) {
   108  	vFn := p.method(name)
   109  	p.call(name, vFn, args...)
   110  }
   111  
   112  func (p *Class) method(name string) reflect.Value {
   113  	c := name[0]
   114  	if c >= 'a' && c <= 'z' {
   115  		name = string(c-('a'-'A')) + name[1:]
   116  	}
   117  	name = "API_" + name
   118  	return p.self.MethodByName(name)
   119  }
   120  
   121  func (p *Class) call(name string, vFn reflect.Value, args ...any) {
   122  	if debugExec {
   123  		log.Println("==>", name, args)
   124  	}
   125  
   126  	vArgs := make([]reflect.Value, len(args))
   127  	for i, arg := range args {
   128  		vArgs[i] = reflect.ValueOf(arg)
   129  	}
   130  
   131  	var old = p.onErr
   132  	var errRet error
   133  	p.onErr = func(err error) {
   134  		errRet = err
   135  		panic(err)
   136  	}
   137  	defer func() {
   138  		p.onErr = old
   139  		if p.result == nil { // set p.result to zero if panic
   140  			fnt := vFn.Type()
   141  			n := fnt.NumOut()
   142  			p.result = make([]reflect.Value, n, n+1)
   143  			for i := 0; i < n; i++ {
   144  				p.result[i] = reflect.Zero(fnt.Out(i))
   145  			}
   146  		}
   147  		if !hasRetErrType(p.result) {
   148  			p.result = append(p.result, reflect.Zero(tyError))
   149  		}
   150  		if e := recover(); e != nil {
   151  			if errRet == nil {
   152  				errRet = recoverErr(e)
   153  			}
   154  			p.result[len(p.result)-1] = reflect.ValueOf(errRet)
   155  			if debugExec {
   156  				log.Println("PANIC:", e)
   157  				debug.PrintStack()
   158  			}
   159  		}
   160  		p.ret = p.callRet
   161  	}()
   162  	p.result = nil
   163  	p.result = vFn.Call(vArgs)
   164  }
   165  
   166  func (p *Class) callRet(args ...any) error {
   167  	t := p.t()
   168  	result := p.result
   169  	if len(result) != len(args) {
   170  		if len(result) != len(args)+1 {
   171  			t.Fatalf(
   172  				"call ret: unmatched result parameters count - got %d, expected %d\n",
   173  				len(args), len(result),
   174  			)
   175  		}
   176  		args = append(args, nil)
   177  	}
   178  	for i, arg := range args {
   179  		ret := result[i].Interface()
   180  		test.Gopt_Case_MatchAny(t, arg, ret)
   181  	}
   182  	p.ret = nil
   183  	return nil
   184  }
   185  
   186  var (
   187  	tyError = reflect.TypeOf((*error)(nil)).Elem()
   188  )
   189  
   190  func hasRetErrType(result []reflect.Value) bool {
   191  	if n := len(result); n > 0 {
   192  		return result[n-1].Type() == tyError
   193  	}
   194  	return false
   195  }
   196  
   197  // Out returns the ith reuslt.
   198  func (p *Class) Out(i int) any {
   199  	return p.result[i].Interface()
   200  }
   201  
   202  // -----------------------------------------------------------------------------