github.com/goplus/yap@v0.8.1/ydb/class.go (about) 1 /* 2 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ydb 18 19 import ( 20 "database/sql" 21 "errors" 22 "log" 23 "reflect" 24 "runtime/debug" 25 26 "github.com/goplus/yap/test" 27 "github.com/goplus/yap/test/logt" 28 ) 29 30 var ( 31 ErrNoRows = sql.ErrNoRows 32 ErrDuplicated = errors.New("duplicated") 33 ErrOutOfLimit = errors.New("out of limit") 34 ) 35 36 // ----------------------------------------------------------------------------- 37 38 type Class struct { 39 Sql 40 41 self reflect.Value 42 tbl string 43 44 result []reflect.Value // result of an api call 45 46 ret func(args ...any) error 47 onErr func(err error) 48 lastErr error 49 test.Case 50 51 query *query // query 52 } 53 54 func (p *Class) initClass(self any) { 55 p.initSql() 56 p.self = reflect.ValueOf(self) 57 } 58 59 func (p *Class) t() test.CaseT { 60 if p.CaseT == nil { 61 p.CaseT = logt.New() 62 } 63 return p.CaseT 64 } 65 66 // Use sets the default table used in following sql operations. 67 func (p *Class) Use(table string) { 68 _, ok := p.tables[table] 69 if !ok { 70 log.Panicln("table not found:", table) 71 } 72 p.tbl = table 73 } 74 75 // OnErr sets error processing of a sql execution. 76 func (p *Class) OnErr(onErr func(error)) { 77 p.onErr = onErr 78 } 79 80 func (p *Class) handleErr(prompt string, err error) { 81 err = p.wrap(prompt, err) 82 if p.onErr == nil { 83 log.Panicln(prompt, err) 84 } 85 p.onErr(err) 86 } 87 88 // Ret checks a query or call result. 89 // 90 // For checking query result: 91 // - ret <expr1>, &<var1>, <expr2>, &<var2>, ... 92 // - ret <expr1>, &<varSlice1>, <expr2>, &<varSlice2>, ... 93 // - ret &<structVar> 94 // - ret &<structSlice> 95 // 96 // For checking call result: 97 // - ret <expr1>, <expr2>, ... 98 func (p *Class) Ret(args ...any) { 99 if p.ret == nil { 100 log.Panicln("please call `ret` after a `query` or `call` statement") 101 } 102 p.ret(args...) 103 } 104 105 // ----------------------------------------------------------------------------- 106 107 func (p *Class) Gop_Exec(name string, args ...any) { 108 vFn := p.method(name) 109 p.call(name, vFn, args...) 110 } 111 112 func (p *Class) method(name string) reflect.Value { 113 c := name[0] 114 if c >= 'a' && c <= 'z' { 115 name = string(c-('a'-'A')) + name[1:] 116 } 117 name = "API_" + name 118 return p.self.MethodByName(name) 119 } 120 121 func (p *Class) call(name string, vFn reflect.Value, args ...any) { 122 if debugExec { 123 log.Println("==>", name, args) 124 } 125 126 vArgs := make([]reflect.Value, len(args)) 127 for i, arg := range args { 128 vArgs[i] = reflect.ValueOf(arg) 129 } 130 131 var old = p.onErr 132 var errRet error 133 p.onErr = func(err error) { 134 errRet = err 135 panic(err) 136 } 137 defer func() { 138 p.onErr = old 139 if p.result == nil { // set p.result to zero if panic 140 fnt := vFn.Type() 141 n := fnt.NumOut() 142 p.result = make([]reflect.Value, n, n+1) 143 for i := 0; i < n; i++ { 144 p.result[i] = reflect.Zero(fnt.Out(i)) 145 } 146 } 147 if !hasRetErrType(p.result) { 148 p.result = append(p.result, reflect.Zero(tyError)) 149 } 150 if e := recover(); e != nil { 151 if errRet == nil { 152 errRet = recoverErr(e) 153 } 154 p.result[len(p.result)-1] = reflect.ValueOf(errRet) 155 if debugExec { 156 log.Println("PANIC:", e) 157 debug.PrintStack() 158 } 159 } 160 p.ret = p.callRet 161 }() 162 p.result = nil 163 p.result = vFn.Call(vArgs) 164 } 165 166 func (p *Class) callRet(args ...any) error { 167 t := p.t() 168 result := p.result 169 if len(result) != len(args) { 170 if len(result) != len(args)+1 { 171 t.Fatalf( 172 "call ret: unmatched result parameters count - got %d, expected %d\n", 173 len(args), len(result), 174 ) 175 } 176 args = append(args, nil) 177 } 178 for i, arg := range args { 179 ret := result[i].Interface() 180 test.Gopt_Case_MatchAny(t, arg, ret) 181 } 182 p.ret = nil 183 return nil 184 } 185 186 var ( 187 tyError = reflect.TypeOf((*error)(nil)).Elem() 188 ) 189 190 func hasRetErrType(result []reflect.Value) bool { 191 if n := len(result); n > 0 { 192 return result[n-1].Type() == tyError 193 } 194 return false 195 } 196 197 // Out returns the ith reuslt. 198 func (p *Class) Out(i int) any { 199 return p.result[i].Interface() 200 } 201 202 // -----------------------------------------------------------------------------