github.com/goplus/yap@v0.8.1/ydb/insert.go (about)

     1  /*
     2   * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ydb
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"log"
    23  	"reflect"
    24  	"strings"
    25  )
    26  
    27  // -----------------------------------------------------------------------------
    28  
    29  // Insert inserts new rows.
    30  //   - insert <colName1>, <val1>, <colName2>, <val2>, ...
    31  //   - insert <colName1>, <valSlice1>, <colName2>, <valSlice2>, ...
    32  //   - insert <structValOrPtr>
    33  //   - insert <structOrPtrSlice>
    34  func (p *Class) Insert(args ...any) (sql.Result, error) {
    35  	if p.tbl == "" {
    36  		log.Panicln("please call `use <tableName>` to specified current table")
    37  	}
    38  	nArg := len(args)
    39  	if nArg == 1 {
    40  		return p.insertStruc(args[0])
    41  	}
    42  	return p.insertKvPair(args...)
    43  }
    44  
    45  // Insert inserts a new row.
    46  //   - insert <structValOrPtr>
    47  //   - insert <structOrPtrSlice>
    48  func (p *Class) insertStruc(arg any) (sql.Result, error) {
    49  	vArg := reflect.ValueOf(arg)
    50  	switch vArg.Kind() {
    51  	case reflect.Slice:
    52  		return p.insertStrucRows(vArg)
    53  	case reflect.Pointer:
    54  		vArg = vArg.Elem()
    55  		fallthrough
    56  	default:
    57  		return p.insertStrucRow(vArg)
    58  	}
    59  }
    60  
    61  func (p *Class) insertStrucRows(vSlice reflect.Value) (sql.Result, error) {
    62  	rows := vSlice.Len()
    63  	if rows == 0 {
    64  		return nil, nil
    65  	}
    66  	hasPtr := false
    67  	elem := vSlice.Type().Elem()
    68  	kind := elem.Kind()
    69  	if kind == reflect.Pointer {
    70  		elem, hasPtr = elem.Elem(), true
    71  		kind = elem.Kind()
    72  	}
    73  	if kind != reflect.Struct {
    74  		log.Panicln("usage: insert <structOrPtrSlice>")
    75  	}
    76  	n := elem.NumField()
    77  	names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0)
    78  	vals := make([]any, 0, len(names)*rows)
    79  	for row := 0; row < rows; row++ {
    80  		vElem := vSlice.Index(row)
    81  		if hasPtr {
    82  			vElem = vElem.Elem()
    83  		}
    84  		vals = getVals(vals, vElem, cols, true)
    85  	}
    86  	return p.insertRowsVals(p.tbl, names, vals, rows)
    87  }
    88  
    89  func (p *Class) insertStrucRow(vArg reflect.Value) (sql.Result, error) {
    90  	if vArg.Kind() != reflect.Struct {
    91  		log.Panicln("usage: insert <structValOrPtr>")
    92  	}
    93  	n := vArg.NumField()
    94  	names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0)
    95  	vals := getVals(make([]any, 0, len(cols)), vArg, cols, true)
    96  	return p.insertRow(p.tbl, names, vals)
    97  }
    98  
    99  const (
   100  	valFlagNormal  = 1
   101  	valFlagSlice   = 2
   102  	valFlagInvalid = valFlagNormal | valFlagSlice
   103  )
   104  
   105  // Insert inserts a new row.
   106  //   - insert <colName1>, <val1>, <colName2>, <val2>, ...
   107  //   - insert <colName1>, <valSlice1>, <colName2>, <valSlice2>, ...
   108  func (p *Class) insertKvPair(kvPair ...any) (sql.Result, error) {
   109  	nPair := len(kvPair)
   110  	if nPair < 2 || nPair&1 != 0 {
   111  		log.Panicln("usage: insert <colName1>, <val1>, <colName2>, <val2>, ...")
   112  	}
   113  	n := nPair >> 1
   114  	names := make([]string, n)
   115  	vals := make([]any, n)
   116  	rows := -1      // slice length
   117  	iArgSlice := -1 // -1: no slice, or index of first slice arg
   118  	kind := 0
   119  	for iPair := 0; iPair < nPair; iPair += 2 {
   120  		i := iPair >> 1
   121  		names[i] = kvPair[iPair].(string)
   122  		val := kvPair[iPair+1]
   123  		switch v := reflect.ValueOf(val); v.Kind() {
   124  		case reflect.Slice:
   125  			vlen := v.Len()
   126  			if iArgSlice == -1 {
   127  				iArgSlice = i
   128  				rows = vlen
   129  			} else if rows != vlen {
   130  				log.Panicf("insert: unexpected slice length. got %d, expected %d\n", vlen, rows)
   131  			} else {
   132  				kind |= valFlagSlice
   133  			}
   134  			vals[i] = v
   135  		default:
   136  			kind |= valFlagNormal
   137  			vals[i] = val
   138  		}
   139  	}
   140  	if kind == valFlagInvalid {
   141  		log.Panicln("insert: can't mix multiple slice arguments and normal value")
   142  	}
   143  	tbl := p.tblFromNames(names)
   144  	if kind == valFlagSlice {
   145  		return p.insertSlice(tbl, names, vals, rows)
   146  	}
   147  	if iArgSlice == -1 {
   148  		return p.insertRow(tbl, names, vals)
   149  	}
   150  	return p.insertMulti(tbl, names, iArgSlice, vals)
   151  }
   152  
   153  // NOTE: len(args) == len(names)
   154  func (p *Class) insertMulti(tbl string, names []string, iArgSlice int, args []any) (sql.Result, error) {
   155  	argSlice := args[iArgSlice]
   156  	defer func() {
   157  		args[iArgSlice] = argSlice
   158  	}()
   159  	vArgSlice := argSlice.(reflect.Value)
   160  	rows := vArgSlice.Len()
   161  	vals := make([]any, 0, len(names)*rows)
   162  	for i := 0; i < rows; i++ {
   163  		args[iArgSlice] = vArgSlice.Index(i).Interface()
   164  		vals = append(vals, args...)
   165  	}
   166  	return p.insertRowsVals(tbl, names, vals, rows)
   167  }
   168  
   169  // NOTE: len(args) == len(names)
   170  func (p *Class) insertSlice(tbl string, names []string, args []any, rows int) (sql.Result, error) {
   171  	vals := make([]any, 0, len(names)*rows)
   172  	for i := 0; i < rows; i++ {
   173  		for _, arg := range args {
   174  			v := arg.(reflect.Value)
   175  			vals = append(vals, v.Index(i).Interface())
   176  		}
   177  	}
   178  	return p.insertRowsVals(tbl, names, vals, rows)
   179  }
   180  
   181  // NOTE: len(vals) == len(names) * rows
   182  func (p *Class) insertRowsVals(tbl string, names []string, vals []any, rows int) (sql.Result, error) {
   183  	n := len(names)
   184  	query := makeInsertExpr(tbl, names)
   185  	query = append(query, valParams(n, rows)...)
   186  
   187  	q := string(query)
   188  	if debugExec {
   189  		log.Println("==>", q, vals)
   190  	}
   191  	result, err := p.db.ExecContext(context.TODO(), q, vals...)
   192  	return p.insertRet(result, err)
   193  }
   194  
   195  func (p *Class) insertRow(tbl string, names []string, vals []any) (sql.Result, error) {
   196  	if len(names) == 0 {
   197  		log.Panicln("insert: nothing to insert")
   198  	}
   199  	query := makeInsertExpr(tbl, names)
   200  	query = append(query, valParam(len(vals))...)
   201  
   202  	q := string(query)
   203  	if debugExec {
   204  		log.Println("==>", q, vals)
   205  	}
   206  	result, err := p.db.ExecContext(context.TODO(), q, vals...)
   207  	return p.insertRet(result, err)
   208  }
   209  
   210  func (p *Class) insertRet(result sql.Result, err error) (sql.Result, error) {
   211  	if err != nil {
   212  		p.handleErr("insert:", err)
   213  	}
   214  	return result, err
   215  }
   216  
   217  func makeInsertExpr(tbl string, names []string) []byte {
   218  	query := make([]byte, 0, 128)
   219  	query = append(query, "INSERT INTO "...)
   220  	query = append(query, tbl...)
   221  	query = append(query, ' ', '(')
   222  	query = append(query, strings.Join(names, ",")...)
   223  	query = append(query, ") VALUES "...)
   224  	return query
   225  }
   226  
   227  func valParams(n, rows int) string {
   228  	valparam := valParam(n)
   229  	valparams := strings.Repeat(valparam+",", rows)
   230  	valparams = valparams[:len(valparams)-1]
   231  	return valparams
   232  }
   233  
   234  func valParam(n int) string {
   235  	valparam := strings.Repeat("?,", n)
   236  	valparam = "(" + valparam[:len(valparam)-1] + ")"
   237  	return valparam
   238  }
   239  
   240  // -----------------------------------------------------------------------------