github.com/goplus/yap@v0.8.1/ydb/insert.go (about) 1 /* 2 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ydb 18 19 import ( 20 "context" 21 "database/sql" 22 "log" 23 "reflect" 24 "strings" 25 ) 26 27 // ----------------------------------------------------------------------------- 28 29 // Insert inserts new rows. 30 // - insert <colName1>, <val1>, <colName2>, <val2>, ... 31 // - insert <colName1>, <valSlice1>, <colName2>, <valSlice2>, ... 32 // - insert <structValOrPtr> 33 // - insert <structOrPtrSlice> 34 func (p *Class) Insert(args ...any) (sql.Result, error) { 35 if p.tbl == "" { 36 log.Panicln("please call `use <tableName>` to specified current table") 37 } 38 nArg := len(args) 39 if nArg == 1 { 40 return p.insertStruc(args[0]) 41 } 42 return p.insertKvPair(args...) 43 } 44 45 // Insert inserts a new row. 46 // - insert <structValOrPtr> 47 // - insert <structOrPtrSlice> 48 func (p *Class) insertStruc(arg any) (sql.Result, error) { 49 vArg := reflect.ValueOf(arg) 50 switch vArg.Kind() { 51 case reflect.Slice: 52 return p.insertStrucRows(vArg) 53 case reflect.Pointer: 54 vArg = vArg.Elem() 55 fallthrough 56 default: 57 return p.insertStrucRow(vArg) 58 } 59 } 60 61 func (p *Class) insertStrucRows(vSlice reflect.Value) (sql.Result, error) { 62 rows := vSlice.Len() 63 if rows == 0 { 64 return nil, nil 65 } 66 hasPtr := false 67 elem := vSlice.Type().Elem() 68 kind := elem.Kind() 69 if kind == reflect.Pointer { 70 elem, hasPtr = elem.Elem(), true 71 kind = elem.Kind() 72 } 73 if kind != reflect.Struct { 74 log.Panicln("usage: insert <structOrPtrSlice>") 75 } 76 n := elem.NumField() 77 names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0) 78 vals := make([]any, 0, len(names)*rows) 79 for row := 0; row < rows; row++ { 80 vElem := vSlice.Index(row) 81 if hasPtr { 82 vElem = vElem.Elem() 83 } 84 vals = getVals(vals, vElem, cols, true) 85 } 86 return p.insertRowsVals(p.tbl, names, vals, rows) 87 } 88 89 func (p *Class) insertStrucRow(vArg reflect.Value) (sql.Result, error) { 90 if vArg.Kind() != reflect.Struct { 91 log.Panicln("usage: insert <structValOrPtr>") 92 } 93 n := vArg.NumField() 94 names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0) 95 vals := getVals(make([]any, 0, len(cols)), vArg, cols, true) 96 return p.insertRow(p.tbl, names, vals) 97 } 98 99 const ( 100 valFlagNormal = 1 101 valFlagSlice = 2 102 valFlagInvalid = valFlagNormal | valFlagSlice 103 ) 104 105 // Insert inserts a new row. 106 // - insert <colName1>, <val1>, <colName2>, <val2>, ... 107 // - insert <colName1>, <valSlice1>, <colName2>, <valSlice2>, ... 108 func (p *Class) insertKvPair(kvPair ...any) (sql.Result, error) { 109 nPair := len(kvPair) 110 if nPair < 2 || nPair&1 != 0 { 111 log.Panicln("usage: insert <colName1>, <val1>, <colName2>, <val2>, ...") 112 } 113 n := nPair >> 1 114 names := make([]string, n) 115 vals := make([]any, n) 116 rows := -1 // slice length 117 iArgSlice := -1 // -1: no slice, or index of first slice arg 118 kind := 0 119 for iPair := 0; iPair < nPair; iPair += 2 { 120 i := iPair >> 1 121 names[i] = kvPair[iPair].(string) 122 val := kvPair[iPair+1] 123 switch v := reflect.ValueOf(val); v.Kind() { 124 case reflect.Slice: 125 vlen := v.Len() 126 if iArgSlice == -1 { 127 iArgSlice = i 128 rows = vlen 129 } else if rows != vlen { 130 log.Panicf("insert: unexpected slice length. got %d, expected %d\n", vlen, rows) 131 } else { 132 kind |= valFlagSlice 133 } 134 vals[i] = v 135 default: 136 kind |= valFlagNormal 137 vals[i] = val 138 } 139 } 140 if kind == valFlagInvalid { 141 log.Panicln("insert: can't mix multiple slice arguments and normal value") 142 } 143 tbl := p.tblFromNames(names) 144 if kind == valFlagSlice { 145 return p.insertSlice(tbl, names, vals, rows) 146 } 147 if iArgSlice == -1 { 148 return p.insertRow(tbl, names, vals) 149 } 150 return p.insertMulti(tbl, names, iArgSlice, vals) 151 } 152 153 // NOTE: len(args) == len(names) 154 func (p *Class) insertMulti(tbl string, names []string, iArgSlice int, args []any) (sql.Result, error) { 155 argSlice := args[iArgSlice] 156 defer func() { 157 args[iArgSlice] = argSlice 158 }() 159 vArgSlice := argSlice.(reflect.Value) 160 rows := vArgSlice.Len() 161 vals := make([]any, 0, len(names)*rows) 162 for i := 0; i < rows; i++ { 163 args[iArgSlice] = vArgSlice.Index(i).Interface() 164 vals = append(vals, args...) 165 } 166 return p.insertRowsVals(tbl, names, vals, rows) 167 } 168 169 // NOTE: len(args) == len(names) 170 func (p *Class) insertSlice(tbl string, names []string, args []any, rows int) (sql.Result, error) { 171 vals := make([]any, 0, len(names)*rows) 172 for i := 0; i < rows; i++ { 173 for _, arg := range args { 174 v := arg.(reflect.Value) 175 vals = append(vals, v.Index(i).Interface()) 176 } 177 } 178 return p.insertRowsVals(tbl, names, vals, rows) 179 } 180 181 // NOTE: len(vals) == len(names) * rows 182 func (p *Class) insertRowsVals(tbl string, names []string, vals []any, rows int) (sql.Result, error) { 183 n := len(names) 184 query := makeInsertExpr(tbl, names) 185 query = append(query, valParams(n, rows)...) 186 187 q := string(query) 188 if debugExec { 189 log.Println("==>", q, vals) 190 } 191 result, err := p.db.ExecContext(context.TODO(), q, vals...) 192 return p.insertRet(result, err) 193 } 194 195 func (p *Class) insertRow(tbl string, names []string, vals []any) (sql.Result, error) { 196 if len(names) == 0 { 197 log.Panicln("insert: nothing to insert") 198 } 199 query := makeInsertExpr(tbl, names) 200 query = append(query, valParam(len(vals))...) 201 202 q := string(query) 203 if debugExec { 204 log.Println("==>", q, vals) 205 } 206 result, err := p.db.ExecContext(context.TODO(), q, vals...) 207 return p.insertRet(result, err) 208 } 209 210 func (p *Class) insertRet(result sql.Result, err error) (sql.Result, error) { 211 if err != nil { 212 p.handleErr("insert:", err) 213 } 214 return result, err 215 } 216 217 func makeInsertExpr(tbl string, names []string) []byte { 218 query := make([]byte, 0, 128) 219 query = append(query, "INSERT INTO "...) 220 query = append(query, tbl...) 221 query = append(query, ' ', '(') 222 query = append(query, strings.Join(names, ",")...) 223 query = append(query, ") VALUES "...) 224 return query 225 } 226 227 func valParams(n, rows int) string { 228 valparam := valParam(n) 229 valparams := strings.Repeat(valparam+",", rows) 230 valparams = valparams[:len(valparams)-1] 231 return valparams 232 } 233 234 func valParam(n int) string { 235 valparam := strings.Repeat("?,", n) 236 valparam = "(" + valparam[:len(valparam)-1] + ")" 237 return valparam 238 } 239 240 // -----------------------------------------------------------------------------